use ndarray::{
Array1, Array2, Array3, Array4, Array5, ArrayView1, ArrayView2, ArrayView3, ArrayView4, s,
};
use std::sync::Arc;
use crate::solver::arrow_schur::{
ArrowProximalCorrectionOptions, ArrowRowBlock, ArrowSchurError, ArrowSchurSystem,
ArrowSolveOptions, BetaPenaltyOp, CompositePenaltyOp, DensePenaltyOp, FactoredFrameGBlock,
FactoredFrameKroneckerOp, IdentityRightKroneckerPenaltyOp, SparseBlockKroneckerPenaltyOp,
SparseGBlock, StreamingArrowSchur, solve_arrow_newton_step_with_proximal_correction,
solve_streaming_reduced_beta,
};
use crate::terms::analytic_penalties::{
AnalyticPenalty, AnalyticPenaltyKind, AnalyticPenaltyRegistry, DecoderIncoherencePenalty,
IBPAssignmentPenalty, IbpHessianDiagThirdChannels, IsometryPenalty, MechanismSparsityPenalty,
NuclearNormPenalty, PenaltyTier, PsiSlice, SoftmaxAssignmentSparsityPenalty, WeightField,
resolve_learnable_weight,
};
use crate::terms::latent_coord::{LatentCoordValues, LatentIdMode, LatentManifold};
use crate::terms::sae_criterion_atoms::SaeCriterion;
use crate::terms::sae_optimality_certificate::{
CriterionCertificate, DirectionalSamples, certificate_from_samples,
deterministic_probe_direction, probe_step,
};
use crate::linalg::faer_ndarray::{FaerEigh, FaerSvd, fast_ab, fast_abt, fast_atb};
use crate::solver::arrow_schur::{
ArrowFactorCache, arrow_factor_max_pivot, arrow_factor_min_pivot,
solve_arrow_newton_step_with_options,
};
use crate::solver::estimate::EstimationError;
use crate::solver::evidence::arrow_log_det_from_cache;
use crate::solver::outer_strategy::{
DeclaredHessianForm, Derivative, EfsEval, HessianResult, OuterCapability, OuterEval,
OuterObjective, SeedOutcome,
};
use crate::solver::structure_search::{CollapseAction, CollapseEvent};
use faer::Side;
// Solver tuning constants. NOTE(review): none of these are read in the visible part of
// this file, so the per-constant notes below are inferred from the names only — confirm
// each one against the code that consumes it.
// Presumably the Armijo sufficient-decrease coefficient used by the manifold line search.
const SAE_MANIFOLD_ARMIJO_C1: f64 = 1.0e-4;
// Presumably the maximum number of step halvings the line search attempts.
const SAE_MANIFOLD_MAX_LINESEARCH_HALVINGS: usize = 12;
// Presumably a lower bound on a pivot ratio used when forming the outer gradient.
const SAE_OUTER_GRADIENT_PIVOT_RATIO_FLOOR: f64 = 1.0e-12;
// Presumably the initial / smallest eta increment and the corrector budget of a
// curvature continuation walk (eta appears elsewhere in this file as a value in [0, 1]).
const CURVATURE_WALK_INITIAL_ETA_STEP: f64 = 0.2;
const CURVATURE_WALK_MIN_ETA_STEP: f64 = 1.0 / 256.0;
const CURVATURE_WALK_MAX_CORRECTORS: usize = 32;
// Presumably relative tolerances for the inner manifold solve (decrease, step, gradient).
const SAE_MANIFOLD_DIRECTIONAL_DECREASE_REL_FLOOR: f64 = 1.0e-14;
const SAE_MANIFOLD_INNER_STEP_REL_TOL: f64 = 1.0e-4;
const SAE_MANIFOLD_INNER_GRAD_REL_TOL: f64 = 1.0e-5;
// Presumably the largest beta dimension for which a dense penalty probe is attempted.
const SAE_DENSE_BETA_PENALTY_PROBE_MAX_DIM: usize = 4096;
// Presumably the relative singular-value cutoff used to decide numerical rank.
const SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF: f64 = 1.0e-9;
// Presumably the starting ridge, its growth factor, and the retry budget for a
// per-row ridge escalation.
const SAE_MANIFOLD_ROW_RIDGE_FLOOR: f64 = 1.0e-12;
const SAE_MANIFOLD_ROW_RIDGE_GROWTH: f64 = 10.0;
const SAE_MANIFOLD_ROW_RIDGE_MAX_ATTEMPTS: usize = 12;
/// Records which kinds of beta-penalty curvature have been seen so far.
///
/// Both flags start out `false` (via `Default`) and are only ever set, never cleared.
#[derive(Clone, Copy, Debug, Default)]
struct SaeBetaPenaltyAssembly {
    /// Set once `record_curvature(true)` has been called.
    dense_written: bool,
    /// Set once `record_curvature(false)` has been called.
    deferred_factored: bool,
}

impl SaeBetaPenaltyAssembly {
    /// Latches the flag that matches the curvature kind: `dense_written` when
    /// `dense_beta_curvature` is `true`, `deferred_factored` otherwise.
    fn record_curvature(&mut self, dense_beta_curvature: bool) {
        let flag = if dense_beta_curvature {
            &mut self.dense_written
        } else {
            &mut self.deferred_factored
        };
        *flag = true;
    }
}
/// Per-atom layout snapshot used to move vectors and blocks between the full
/// "beta" coordinates (`basis_size * p` entries per atom) and the reduced
/// "border" coordinates (`basis_size * rank` entries per atom).
///
/// All per-atom vectors are indexed by atom and are filled by `FrameProjection::new`.
#[derive(Clone, Debug)]
struct FrameProjection {
/// Output dimension, taken from `term.output_dim()`.
p: usize,
/// Start of each atom's segment in the beta vector (`term.beta_offsets()`).
beta_offsets: Vec<usize>,
/// Start of each atom's segment in the border vector (`term.factored_border_offsets()`).
border_offsets: Vec<usize>,
/// Number of basis columns of each atom.
basis_sizes: Vec<usize>,
/// Border frame rank of each atom (`atom.border_frame_rank()`).
ranks: Vec<usize>,
/// Owned copy of each atom's decoder frame, or `None` when the atom has none.
/// The methods index a frame as `[output, frame_col]`.
frames: Vec<Option<Array2<f64>>>,
}
impl FrameProjection {
/// Builds the projection by copying the layout data it needs out of `term`:
/// the output dimension, both offset tables, and per-atom basis sizes, border
/// ranks and (optional) decoder frames.
fn new(term: &SaeManifoldTerm) -> Self {
    let p = term.output_dim();
    let beta_offsets = term.beta_offsets();
    let border_offsets = term.factored_border_offsets();
    let atoms = &term.atoms;
    let basis_sizes = atoms.iter().map(|atom| atom.basis_size()).collect();
    let ranks = atoms.iter().map(|atom| atom.border_frame_rank()).collect();
    // Frames are cloned so the projection does not borrow from `term`.
    let frames = atoms
        .iter()
        .map(|atom| {
            let frame = atom.decoder_frame.as_ref();
            frame.map(|frame| frame.frame().to_owned())
        })
        .collect();
    Self {
        p,
        beta_offsets,
        border_offsets,
        basis_sizes,
        ranks,
        frames,
    }
}
/// Length of the full beta vector: total basis columns over all atoms times `p`.
fn beta_dim(&self) -> usize {
    let total_basis: usize = self.basis_sizes.iter().sum();
    total_basis * self.p
}
/// Length of the border vector: sum over atoms of `basis_size * rank`.
fn border_dim(&self) -> usize {
    let per_atom = self
        .basis_sizes
        .iter()
        .zip(self.ranks.iter())
        .map(|(&basis_size, &rank)| basis_size * rank);
    per_atom.sum()
}
/// Expands a border-coordinate vector into a freshly allocated beta-coordinate
/// vector by lifting every atom's segment in turn.
fn lift_border_vec(&self, border: ArrayView1<'_, f64>) -> Array1<f64> {
    let n_atoms = self.basis_sizes.len();
    let mut lifted = Array1::<f64>::zeros(self.beta_dim());
    for atom_index in 0..n_atoms {
        self.lift_atom_vec_into(atom_index, border, lifted.view_mut());
    }
    lifted
}
/// Projects a beta-coordinate vector into a freshly allocated border-coordinate
/// vector, accumulating every atom's contribution with unit scale.
fn project_border_vec(&self, beta: ArrayView1<'_, f64>) -> Array1<f64> {
    let n_atoms = self.basis_sizes.len();
    let mut projected = Array1::<f64>::zeros(self.border_dim());
    for atom_index in 0..n_atoms {
        self.project_atom_vec_into(atom_index, beta, projected.view_mut(), 1.0);
    }
    projected
}
/// Lifts one atom's border-coordinate block to beta coordinates.
///
/// For an atom without a frame the block is returned as an owned copy. Otherwise
/// each `(basis, basis)` sub-block `B` of size `rank x rank` becomes `U B U^T`,
/// with `U` the atom's frame, giving an `(m * p) x (m * p)` matrix.
fn lift_block(&self, atom: usize, block: ArrayView2<'_, f64>) -> Array2<f64> {
    let m = self.basis_sizes[atom];
    let r = self.ranks[atom];
    let Some(frame) = self.frames[atom].as_ref() else {
        return block.to_owned();
    };
    let p = self.p;
    let side = m * p;
    let mut lifted = Array2::<f64>::zeros((side, side));
    // `side` is zero whenever `p` is zero, so the divisions below never see p == 0.
    for row in 0..side {
        let (row_basis, row_out) = (row / p, row % p);
        for col in 0..side {
            let (col_basis, col_out) = (col / p, col % p);
            let mut acc = 0.0;
            for j1 in 0..r {
                for j2 in 0..r {
                    acc += frame[[row_out, j1]]
                        * block[[row_basis * r + j1, col_basis * r + j2]]
                        * frame[[col_out, j2]];
                }
            }
            lifted[[row, col]] = acc;
        }
    }
    lifted
}
/// Projects a beta-by-beta block to border-by-border coordinates: the columns are
/// projected first (row by row), then every atom's left projection is accumulated.
fn project_block(&self, hbb: ArrayView2<'_, f64>) -> Array2<f64> {
    let columns_projected = self.project_rows(hbb);
    let border = self.border_dim();
    let mut projected = Array2::<f64>::zeros((border, border));
    let n_atoms = self.basis_sizes.len();
    for atom_index in 0..n_atoms {
        self.project_block_left_atom(
            atom_index,
            columns_projected.view(),
            projected.view_mut(),
        );
    }
    projected
}
/// Applies `project_border_vec` to every row of `block`, producing a matrix with
/// the same number of rows and `border_dim()` columns.
fn project_rows(&self, block: ArrayView2<'_, f64>) -> Array2<f64> {
    let n_rows = block.nrows();
    let mut projected = Array2::<f64>::zeros((n_rows, self.border_dim()));
    for row_index in 0..n_rows {
        let row_projection = self.project_border_vec(block.row(row_index));
        projected.row_mut(row_index).assign(&row_projection);
    }
    projected
}
/// Index range of `atom`'s segment inside the border vector.
fn atom_border_range(&self, atom: usize) -> std::ops::Range<usize> {
    let first = self.border_offsets[atom];
    let segment_len = self.basis_sizes[atom] * self.ranks[atom];
    first..first + segment_len
}
/// Writes the beta-coordinate image of one border axis `(atom, basis_col, frame_col)`
/// into `out`, addressed with the atom's global beta offset.
///
/// Without a frame this sets a single entry to `1.0`; with a frame it copies
/// column `frame_col` of the frame into the `p` entries of that basis column.
/// Entries are overwritten, not accumulated.
fn lift_axis_into(
    &self,
    out: &mut Array1<f64>,
    atom: usize,
    basis_col: usize,
    frame_col: usize,
) {
    let base = self.beta_offsets[atom] + basis_col * self.p;
    let Some(frame) = self.frames[atom].as_ref() else {
        out[base + frame_col] = 1.0;
        return;
    };
    for output in 0..self.p {
        out[base + output] = frame[[output, frame_col]];
    }
}
/// Same as `lift_axis_into`, but `out` is addressed relative to the atom's own
/// segment (no global beta offset is added).
fn lift_local_axis_into(
    &self,
    out: &mut Array1<f64>,
    atom: usize,
    basis_col: usize,
    frame_col: usize,
) {
    let base = basis_col * self.p;
    let Some(frame) = self.frames[atom].as_ref() else {
        out[base + frame_col] = 1.0;
        return;
    };
    for output in 0..self.p {
        out[base + output] = frame[[output, frame_col]];
    }
}
/// Accumulates `scale` times the border projection of `atom`'s slice of the
/// global beta vector `beta` into `out`.
///
/// This is the globally-addressed case of `project_atom_vec_into_with_base`:
/// the atom's segment of `beta` starts at `beta_offsets[atom]`.
fn project_atom_vec_into(
    &self,
    atom: usize,
    beta: ArrayView1<'_, f64>,
    out: ndarray::ArrayViewMut1<'_, f64>,
    scale: f64,
) {
    let global_base = self.beta_offsets[atom];
    self.project_atom_vec_into_with_base(atom, beta, out, scale, global_base);
}
/// Accumulates `scale` times the border projection of an atom-local beta vector
/// (one whose first entry belongs to the atom's first basis column) into `out`.
fn project_local_atom_vec_into(
    &self,
    atom: usize,
    beta: ArrayView1<'_, f64>,
    out: ndarray::ArrayViewMut1<'_, f64>,
    scale: f64,
) {
    // A local vector starts at index zero rather than at the atom's beta offset.
    const LOCAL_BASE: usize = 0;
    self.project_atom_vec_into_with_base(atom, beta, out, scale, LOCAL_BASE);
}
/// Accumulates `scale` times the border projection of one atom's beta segment
/// into `out`, reading the segment from `beta` starting at `beta_base_offset`.
///
/// For each basis column, a frameless atom copies the first `rank` entries; a
/// framed atom contracts the `p` entries with each frame column (`U^T beta`).
/// The result is added to `out` at the atom's border offset.
fn project_atom_vec_into_with_base(
    &self,
    atom: usize,
    beta: ArrayView1<'_, f64>,
    mut out: ndarray::ArrayViewMut1<'_, f64>,
    scale: f64,
    beta_base_offset: usize,
) {
    let m = self.basis_sizes[atom];
    let r = self.ranks[atom];
    let border_start = self.border_offsets[atom];
    for basis_col in 0..m {
        let src = beta_base_offset + basis_col * self.p;
        let dst = border_start + basis_col * r;
        for j in 0..r {
            let coefficient = match self.frames[atom].as_ref() {
                None => beta[src + j],
                Some(frame) => {
                    (0..self.p).fold(0.0, |acc, i| acc + frame[[i, j]] * beta[src + i])
                }
            };
            out[dst + j] += scale * coefficient;
        }
    }
}
/// Lifts one atom's border segment into its beta segment of `out`.
///
/// Entries of `out` are overwritten. A frameless atom copies `p` consecutive
/// border entries per basis column; a framed atom writes `U * border_segment`.
fn lift_atom_vec_into(
    &self,
    atom: usize,
    border: ArrayView1<'_, f64>,
    mut out: ndarray::ArrayViewMut1<'_, f64>,
) {
    let m = self.basis_sizes[atom];
    let r = self.ranks[atom];
    let beta_start = self.beta_offsets[atom];
    let border_start = self.border_offsets[atom];
    for basis_col in 0..m {
        let dst = beta_start + basis_col * self.p;
        let src = border_start + basis_col * r;
        for i in 0..self.p {
            out[dst + i] = match self.frames[atom].as_ref() {
                None => border[src + i],
                Some(frame) => (0..r).fold(0.0, |acc, j| acc + frame[[i, j]] * border[src + j]),
            };
        }
    }
}
/// Adds `phi` times the lifted border segment starting at `x[c_base]` to the
/// first `p` entries of `out`.
///
/// Without a frame the segment is used directly; with a frame it is first
/// multiplied by the frame (`U * segment`).
fn accumulate_row_lift(
    &self,
    atom: usize,
    c_base: usize,
    phi: f64,
    x: &[f64],
    out: &mut [f64],
) {
    if let Some(frame) = self.frames[atom].as_ref() {
        for i in 0..self.p {
            let rank = self.ranks[atom];
            let lifted = (0..rank).fold(0.0, |acc, j| acc + frame[[i, j]] * x[c_base + j]);
            out[i] += phi * lifted;
        }
    } else {
        for i in 0..self.p {
            out[i] += phi * x[c_base + i];
        }
    }
}
/// Adds `phi` times the projection of the output-space vector `u` to the border
/// segment of `out` starting at `c_base`.
///
/// Without a frame the first `p` entries of `u` are used directly; with a frame
/// each frame column is contracted with `u` (`U^T u`).
fn accumulate_row_project(
    &self,
    atom: usize,
    c_base: usize,
    phi: f64,
    u: &[f64],
    out: &mut [f64],
) {
    if let Some(frame) = self.frames[atom].as_ref() {
        let rank = self.ranks[atom];
        for j in 0..rank {
            let projected = (0..self.p).fold(0.0, |acc, i| acc + frame[[i, j]] * u[i]);
            out[c_base + j] += phi * projected;
        }
    } else {
        for i in 0..self.p {
            out[c_base + i] += phi * u[i];
        }
    }
}
/// Adds the projection of `value * e_output` (a single output coordinate) to the
/// border segment of `out` starting at `c_base`.
///
/// Without a frame only entry `c_base + output` changes; with a frame row
/// `output` of the frame, scaled by `value`, is added across the `rank` entries.
fn accumulate_output_project(
    &self,
    atom: usize,
    c_base: usize,
    output: usize,
    value: f64,
    out: &mut [f64],
) {
    if let Some(frame) = self.frames[atom].as_ref() {
        let rank = self.ranks[atom];
        for j in 0..rank {
            out[c_base + j] += value * frame[[output, j]];
        }
    } else {
        out[c_base + output] += value;
    }
}
/// Quadratic form giving the variance of output coordinate `output` for one row
/// with basis values `basis`, from a border-coordinate covariance `cov_c`.
///
/// Frameless atoms defer to `full_output_variance`. For framed atoms the sum runs
/// over basis pairs and frame-column pairs; basis entries equal to zero are skipped.
fn output_variance(
    &self,
    atom: usize,
    cov_c: ArrayView2<'_, f64>,
    basis: ArrayView1<'_, f64>,
    output: usize,
) -> f64 {
    let frame = match self.frames[atom].as_ref() {
        Some(frame) => frame,
        None => return self.full_output_variance(atom, cov_c, basis, output),
    };
    let m = self.basis_sizes[atom];
    let r = self.ranks[atom];
    let mut total = 0.0;
    for b1 in 0..m {
        let phi1 = basis[b1];
        if phi1 == 0.0 {
            continue;
        }
        for b2 in 0..m {
            let phi2 = basis[b2];
            if phi2 == 0.0 {
                continue;
            }
            for j1 in 0..r {
                let left = frame[[output, j1]];
                for j2 in 0..r {
                    let covariance = cov_c[[b1 * r + j1, b2 * r + j2]];
                    total += phi1 * phi2 * left * covariance * frame[[output, j2]];
                }
            }
        }
    }
    total
}
/// Variance of output coordinate `output` from a beta-coordinate covariance `cov`:
/// the sum over basis pairs of `basis[b1] * basis[b2] * cov[b1*p + output, b2*p + output]`.
/// Rows whose `basis[b1]` is zero are skipped.
fn full_output_variance(
    &self,
    atom: usize,
    cov: ArrayView2<'_, f64>,
    basis: ArrayView1<'_, f64>,
    output: usize,
) -> f64 {
    let m = self.basis_sizes[atom];
    let p = self.p;
    let mut total = 0.0;
    for b1 in 0..m {
        let phi1 = basis[b1];
        if phi1 == 0.0 {
            continue;
        }
        let row = b1 * p + output;
        for b2 in 0..m {
            total += phi1 * basis[b2] * cov[[row, b2 * p + output]];
        }
    }
    total
}
/// Accumulates into `out` the left (row-side) projection of `atom`'s beta rows of `t`.
///
/// For each basis column and each column of `out`, a frameless atom adds the
/// first `rank` rows of the segment; a framed atom adds `U^T` applied to the
/// segment's `p` rows.
fn project_block_left_atom(
    &self,
    atom: usize,
    t: ArrayView2<'_, f64>,
    mut out: ndarray::ArrayViewMut2<'_, f64>,
) {
    let m = self.basis_sizes[atom];
    let r = self.ranks[atom];
    let beta_start = self.beta_offsets[atom];
    let border_start = self.border_offsets[atom];
    for basis_col in 0..m {
        let src = beta_start + basis_col * self.p;
        let dst = border_start + basis_col * r;
        for j in 0..r {
            for col in 0..out.ncols() {
                let contribution = match self.frames[atom].as_ref() {
                    None => t[[src + j, col]],
                    Some(frame) => {
                        (0..self.p).fold(0.0, |acc, i| acc + frame[[i, j]] * t[[src + i, col]])
                    }
                };
                out[[dst + j, col]] += contribution;
            }
        }
    }
}
}
// NOTE(review): the first three constants are not read in the visible part of this
// file; the notes on them are inferred from the names — confirm at the use sites.
// Presumably the largest assignment-logit step, measured in temperatures.
const SAE_ASSIGNMENT_LOGIT_STEP_CAP_TAUS: f64 = 4.0;
// Presumably the active-mass level below which an atom counts as collapsed.
const SAE_ATOM_ACTIVE_MASS_FLOOR: f64 = 1.0e-3;
// Presumably how many times a collapsed atom may be reseeded.
const SAE_ATOM_COLLAPSE_RESEED_BUDGET: usize = 1;
// Width, in temperatures, of the margin below the JumpReLU threshold inside which
// `jumprelu_in_optimization_band` still reports a logit as in-band.
const JUMPRELU_REACTIVATION_MARGIN: f64 = 4.0;
/// Returns `true` when `logit` lies strictly above
/// `threshold - JUMPRELU_REACTIVATION_MARGIN * temperature`.
#[inline]
fn jumprelu_in_optimization_band(logit: f64, threshold: f64, temperature: f64) -> bool {
    let band_floor = threshold - JUMPRELU_REACTIVATION_MARGIN * temperature;
    logit > band_floor
}
/// Decision on whether a fit is processed in one batch or streamed in row chunks,
/// as produced by `sae_streaming_plan_from_budget`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SaeStreamingPlan {
/// `true` when the estimated full-batch size exceeds the in-core budget.
pub streaming: bool,
/// Rows per chunk; always at least 1. Equals the row count when not streaming.
pub chunk_size: usize,
/// Estimated bytes needed to hold every row at once (saturating arithmetic).
pub estimated_full_batch_bytes: usize,
/// The in-core budget the estimate was compared against, in bytes.
pub in_core_budget_bytes: usize,
}
/// Chooses between in-core and streamed processing from a byte budget.
///
/// Each row is estimated at `total_basis * (1 + d_max) + k_atoms` `f64` words
/// (at least one word), using saturating arithmetic. If all `n_obs` rows fit in
/// `in_core_budget_bytes` the plan is non-streaming with one chunk covering every
/// row. Otherwise the chunk holds as many rows as fit in `chunk_window_bytes`,
/// but never fewer than 256 and never more than `n_obs`. The chunk size is
/// always at least 1.
fn sae_streaming_plan_from_budget(
    n_obs: usize,
    total_basis: usize,
    k_atoms: usize,
    d_max: usize,
    in_core_budget_bytes: usize,
    chunk_window_bytes: usize,
) -> SaeStreamingPlan {
    const BYTES_PER_F64: usize = 8;
    const MIN_CHUNK_ROWS: usize = 256;

    let words_per_row = total_basis
        .saturating_mul(1 + d_max)
        .saturating_add(k_atoms)
        .max(1);
    let bytes_per_row = words_per_row.saturating_mul(BYTES_PER_F64);
    let estimated_full_batch_bytes = n_obs.saturating_mul(bytes_per_row);

    let streaming = estimated_full_batch_bytes > in_core_budget_bytes;
    let chunk_size = if streaming {
        let window_rows = (chunk_window_bytes / bytes_per_row).max(MIN_CHUNK_ROWS);
        window_rows.min(n_obs).max(1)
    } else {
        n_obs.max(1)
    };

    SaeStreamingPlan {
        streaming,
        chunk_size,
        estimated_full_batch_bytes,
        in_core_budget_bytes,
    }
}
/// Decay law used by `GumbelTemperatureSchedule` to move the temperature from
/// `tau_start` toward `tau_min`.
#[derive(Debug, Clone)]
pub enum ScheduleKind {
/// `tau_start * rate^iter`; the schedule's validation requires `rate` in (0, 1).
Geometric { rate: f64 },
/// Linear interpolation from `tau_start` to `tau_min` over `steps` iterations;
/// validation requires `steps > 0`.
Linear { steps: usize },
/// `tau_start / (1 + iter)`.
ReciprocalIter,
}
/// Temperature schedule that decays from `tau_start` and is floored at `tau_min`.
#[derive(Debug, Clone)]
pub struct GumbelTemperatureSchedule {
    /// Temperature at iteration zero; must be finite and positive.
    pub tau_start: f64,
    /// Lower bound on every reported temperature; must be finite, positive and
    /// no larger than `tau_start`.
    pub tau_min: f64,
    /// Decay law applied between `tau_start` and `tau_min`.
    pub decay: ScheduleKind,
    /// Number of times `step` has been called.
    pub iter_count: usize,
}

impl GumbelTemperatureSchedule {
    /// Creates a schedule at iteration zero.
    ///
    /// # Errors
    /// Returns the message produced by [`Self::validate`] when the parameters
    /// are not acceptable.
    #[must_use = "build error must be handled"]
    pub fn new(tau_start: f64, tau_min: f64, decay: ScheduleKind) -> Result<Self, String> {
        let schedule = Self {
            tau_start,
            tau_min,
            decay,
            iter_count: 0,
        };
        schedule.validate().map(|()| schedule)
    }

    /// Checks the schedule parameters.
    ///
    /// # Errors
    /// Fails when either temperature is not finite and positive, when `tau_min`
    /// exceeds `tau_start`, when a geometric rate is outside `(0, 1)`, or when a
    /// linear schedule has zero steps.
    pub fn validate(&self) -> Result<(), String> {
        let finite_positive = |value: f64| value.is_finite() && value > 0.0;
        if !finite_positive(self.tau_start) {
            return Err(format!(
                "GumbelTemperatureSchedule: tau_start must be finite and positive; got {}",
                self.tau_start
            ));
        }
        if !finite_positive(self.tau_min) {
            return Err(format!(
                "GumbelTemperatureSchedule: tau_min must be finite and positive; got {}",
                self.tau_min
            ));
        }
        if self.tau_min > self.tau_start {
            return Err(format!(
                "GumbelTemperatureSchedule: tau_min ({}) cannot exceed tau_start ({})",
                self.tau_min, self.tau_start
            ));
        }
        match &self.decay {
            ScheduleKind::Geometric { rate } => {
                let rate = *rate;
                let in_open_unit_interval = rate.is_finite() && rate > 0.0 && rate < 1.0;
                if !in_open_unit_interval {
                    return Err(format!(
                        "GumbelTemperatureSchedule::Geometric: rate must be in (0, 1); got {rate}"
                    ));
                }
            }
            ScheduleKind::Linear { steps } if *steps == 0 => {
                return Err("GumbelTemperatureSchedule::Linear: steps must be positive".into());
            }
            ScheduleKind::Linear { .. } | ScheduleKind::ReciprocalIter => {}
        }
        Ok(())
    }

    /// Temperature at iteration `iter`, never below `tau_min`.
    pub fn current_tau(&self, iter: usize) -> f64 {
        let progress = iter as f64;
        let unclamped = match &self.decay {
            ScheduleKind::Geometric { rate } => self.tau_start * rate.powf(progress),
            // Past the final step the linear ramp stays at its end point.
            ScheduleKind::Linear { steps } if iter >= *steps => self.tau_min,
            ScheduleKind::Linear { steps } => {
                let frac = progress / *steps as f64;
                self.tau_start + frac * (self.tau_min - self.tau_start)
            }
            ScheduleKind::ReciprocalIter => self.tau_start / (1.0 + progress),
        };
        unclamped.max(self.tau_min)
    }

    /// Returns the temperature for the current iteration and advances the counter.
    pub fn step(&mut self) -> f64 {
        let iter = self.iter_count;
        self.iter_count += 1;
        self.current_tau(iter)
    }
}
/// How a scalar setting is chosen: kept fixed, or swept over a list of values.
#[derive(Debug, Clone, PartialEq)]
pub enum SearchStrategy {
    /// No search; the setting is used as given.
    Fixed,
    /// Try each of the listed values.
    ExponentialSweep { values: Vec<f64> },
}

impl SearchStrategy {
    /// `true` for [`SearchStrategy::Fixed`], `false` for a sweep.
    #[must_use]
    pub fn is_fixed(&self) -> bool {
        match self {
            Self::Fixed => true,
            Self::ExponentialSweep { .. } => false,
        }
    }

    /// The sweep values, or `None` for [`SearchStrategy::Fixed`].
    #[must_use]
    pub fn sweep_values(&self) -> Option<&[f64]> {
        if let Self::ExponentialSweep { values } = self {
            Some(values.as_slice())
        } else {
            None
        }
    }
}
/// Family of basis functions an SAE atom uses over its latent coordinates.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SaeAtomBasisKind {
    Duchon,
    Periodic,
    Sphere,
    Torus,
    EuclideanPatch,
    Precomputed(String),
}

impl SaeAtomBasisKind {
    /// Latent manifold associated with this basis kind.
    ///
    /// `Periodic` and `Torus` give a unit-period circle for `latent_dim == 1` and
    /// a product of `latent_dim` unit-period circles otherwise. `Sphere` always
    /// gives interval `[-pi/2, pi/2]` times a circle of period `2*pi`, regardless
    /// of `latent_dim`. Every other kind is Euclidean.
    fn latent_manifold(&self, latent_dim: usize) -> LatentManifold {
        let unit_circle = || LatentManifold::Circle { period: 1.0 };
        match self {
            Self::Duchon | Self::EuclideanPatch | Self::Precomputed(_) => LatentManifold::Euclidean,
            Self::Sphere => {
                let latitude = LatentManifold::Interval {
                    lo: -std::f64::consts::FRAC_PI_2,
                    hi: std::f64::consts::FRAC_PI_2,
                };
                let longitude = LatentManifold::Circle {
                    period: std::f64::consts::TAU,
                };
                LatentManifold::Product(vec![latitude, longitude])
            }
            Self::Periodic | Self::Torus if latent_dim == 1 => unit_circle(),
            Self::Periodic | Self::Torus => {
                LatentManifold::Product((0..latent_dim).map(|_| unit_circle()).collect())
            }
        }
    }

    /// Seed grid of latent coordinates for projection, when the kind has one.
    ///
    /// `Periodic` and `Torus` use the torus grid; `Sphere` uses the sphere grid
    /// only for `latent_dim == 2`; all other cases return `None`.
    fn projection_seed_grid(&self, latent_dim: usize, resolution: usize) -> Option<Array2<f64>> {
        match self {
            Self::Periodic | Self::Torus => torus_projection_seed_grid(latent_dim, resolution),
            Self::Sphere => {
                if latent_dim == 2 {
                    sphere_projection_seed_grid(resolution)
                } else {
                    None
                }
            }
            Self::Duchon | Self::EuclideanPatch | Self::Precomputed(_) => None,
        }
    }
}
/// Builds an `r * r` grid of `(latitude, longitude)` seeds with `r = max(resolution, 2)`.
///
/// Latitudes are cell midpoints in `(-pi/2, pi/2)`; longitudes start at `-pi` and
/// advance by `2*pi / r`. Row `i * r + j` holds latitude `i` and longitude `j`.
/// Always returns `Some`.
fn sphere_projection_seed_grid(resolution: usize) -> Option<Array2<f64>> {
    use std::f64::consts::PI;
    let per_axis = resolution.max(2);
    let mut seeds = Array2::<f64>::zeros((per_axis * per_axis, 2));
    for lat_index in 0..per_axis {
        let latitude = -PI / 2.0 + PI * (lat_index as f64 + 0.5) / per_axis as f64;
        for lon_index in 0..per_axis {
            let longitude = -PI + 2.0 * PI * (lon_index as f64) / per_axis as f64;
            let row = lat_index * per_axis + lon_index;
            seeds[[row, 0]] = latitude;
            seeds[[row, 1]] = longitude;
        }
    }
    Some(seeds)
}
/// Builds a regular grid of seeds on the unit torus `[0, 1)^latent_dim`.
///
/// The grid uses the largest per-axis resolution that is at most
/// `max(resolution, 2)` and keeps the total point count at or below 4096.
/// Row `flat` holds the coordinates `index / per_axis` of the multi-index obtained
/// by counting with the last axis varying fastest.
///
/// Returns `None` when `latent_dim` is zero or when even two points per axis
/// would exceed the 4096-point cap.
fn torus_projection_seed_grid(latent_dim: usize, resolution: usize) -> Option<Array2<f64>> {
    const MAX_GRID_POINTS: usize = 4096;
    if latent_dim == 0 || latent_dim >= usize::BITS as usize {
        return None;
    }
    let min_points = 1usize << latent_dim;
    if min_points > MAX_GRID_POINTS {
        return None;
    }
    // A per-axis count above MAX_GRID_POINTS can never satisfy the cap (latent_dim >= 1),
    // so start the search there. This yields the same result as starting from the raw
    // request but bounds the shrink loop instead of walking down from a huge `resolution`.
    let mut per_axis = resolution.max(2).min(MAX_GRID_POINTS);
    while per_axis.saturating_pow(latent_dim as u32) > MAX_GRID_POINTS {
        per_axis -= 1;
        if per_axis < 2 {
            // Defensive: the `min_points` check above already guarantees per_axis == 2 fits.
            return None;
        }
    }
    let total = per_axis.saturating_pow(latent_dim as u32);
    let mut grid = Array2::<f64>::zeros((total, latent_dim));
    // Odometer over the multi-index, last axis fastest.
    let mut idx = vec![0usize; latent_dim];
    for flat in 0..total {
        for (axis, &position) in idx.iter().enumerate() {
            grid[[flat, axis]] = position as f64 / per_axis as f64;
        }
        for digit in idx.iter_mut().rev() {
            *digit += 1;
            if *digit < per_axis {
                break;
            }
            *digit = 0;
        }
    }
    Some(grid)
}
/// Value, first and second derivative of a one-axis ARD prior term at a point,
/// plus the squared-distance equivalent of that point.
#[derive(Clone, Copy, Debug)]
struct ArdAxisPrior {
    /// Prior value at `t`.
    value: f64,
    /// Derivative of `value` with respect to `t`.
    grad: f64,
    /// Second derivative of `value` with respect to `t`.
    hess: f64,
    /// Quantity `s` such that `value == 0.5 * alpha * s`.
    sq_equiv: f64,
}

impl ArdAxisPrior {
    /// Evaluates the prior with precision `alpha` at coordinate `t`.
    ///
    /// With `period == None` the prior is the quadratic `0.5 * alpha * t^2`.
    /// With `period == Some(p)` it is the periodic analogue
    /// `(alpha / kappa^2) * (1 - cos(kappa * t))` with `kappa = 2*pi / p`.
    fn eval(alpha: f64, t: f64, period: Option<f64>) -> Self {
        let Some(p) = period else {
            let t_sq = t * t;
            return Self {
                value: 0.5 * alpha * t * t,
                grad: alpha * t,
                hess: alpha,
                sq_equiv: t_sq,
            };
        };
        let kappa = std::f64::consts::TAU / p;
        let kappa_sq = kappa * kappa;
        let (sin, cos) = (kappa * t).sin_cos();
        let one_minus_cos = 1.0 - cos;
        Self {
            value: (alpha / kappa_sq) * one_minus_cos,
            grad: (alpha / kappa) * sin,
            hess: alpha * cos,
            sq_equiv: (2.0 / kappa_sq) * one_minus_cos,
        }
    }
}
/// Modified Bessel function of the first kind, order zero, via the two-range
/// polynomial approximation (a series in `(x / 3.75)^2` for `|x| < 3.75`, and
/// `exp(|x|) / sqrt(|x|)` times a polynomial in `3.75 / |x|` otherwise).
fn bessel_i0(x: f64) -> f64 {
    // Coefficients in ascending powers; both polynomials are evaluated by Horner's rule.
    const SMALL: [f64; 7] = [
        1.0, 3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.0360768, 0.0045813,
    ];
    const LARGE: [f64; 9] = [
        0.39894228,
        0.01328592,
        0.00225319,
        -0.00157565,
        0.00916281,
        -0.02057706,
        0.02635537,
        -0.01647633,
        0.00392377,
    ];
    let horner = |coeffs: &[f64], z: f64| coeffs.iter().rev().fold(0.0, |acc, &c| c + z * acc);

    let ax = x.abs();
    if ax < 3.75 {
        let t = x / 3.75;
        horner(&SMALL, t * t)
    } else {
        let y = 3.75 / ax;
        (ax.exp() / ax.sqrt()) * horner(&LARGE, y)
    }
}
/// Modified Bessel function of the first kind, order one, via the two-range
/// polynomial approximation. The magnitude is computed from `|x|` and the sign
/// of `x` is applied at the end, so the function is odd.
fn bessel_i1(x: f64) -> f64 {
    // Coefficients in ascending powers; both polynomials are evaluated by Horner's rule.
    const SMALL: [f64; 7] = [
        0.5, 0.87890594, 0.51498869, 0.15084934, 0.02658733, 0.00301532, 0.00032411,
    ];
    const LARGE: [f64; 9] = [
        0.39894228,
        -0.03988024,
        -0.00362018,
        0.00163801,
        -0.01031555,
        0.02282967,
        -0.02895312,
        0.01787654,
        -0.00420059,
    ];
    let horner = |coeffs: &[f64], z: f64| coeffs.iter().rev().fold(0.0, |acc, &c| c + z * acc);

    let ax = x.abs();
    let magnitude = if ax < 3.75 {
        let t = x / 3.75;
        ax * horner(&SMALL, t * t)
    } else {
        let y = 3.75 / ax;
        (ax.exp() / ax.sqrt()) * horner(&LARGE, y)
    };
    if x < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
/// Evaluates an atom's basis functions and their derivatives at latent coordinates.
pub trait SaeBasisEvaluator: Send + Sync + std::fmt::Debug {
    /// Returns the basis values and first derivatives at `coords`.
    ///
    /// The implementations in this file return `phi` with one row per coordinate
    /// row and one column per basis function, and a jet indexed
    /// `[row, basis column, latent axis]`.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String>;

    /// Optionally returns an evaluator for an affinely transformed latent space.
    ///
    /// The default has no such evaluator and returns `Ok(None)`. It rejects only
    /// the impossible case of a length or width equal to `usize::MAX`.
    fn affine_transformed_evaluator(
        &self,
        shift: &[f64],
        scale: &[f64],
        n_basis: usize,
    ) -> Result<Option<Arc<dyn SaeBasisSecondJet>>, String> {
        let widths = [shift.len(), scale.len(), n_basis];
        if widths.contains(&usize::MAX) {
            return Err("SaeBasisEvaluator::affine_transformed_evaluator: unreachable affine metadata width".to_string());
        }
        Ok(None)
    }

    /// Splits the basis columns into linear and curved groups.
    /// The default treats every column as linear.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        Ok(PhiEtaSplit::all_linear(n_basis))
    }

    /// Evaluates the basis with every curved column scaled by `eta`.
    ///
    /// Curved columns of `phi` and `jet` are multiplied by `eta`; `dphi_deta` and
    /// `djet_deta` hold the unscaled curved columns (the derivative with respect
    /// to `eta`) and zeros elsewhere.
    ///
    /// # Errors
    /// Fails when `eta` is not a finite value in `[0, 1]`, when `evaluate` or
    /// `phi_eta_split` fails, or when the split names a column outside the basis.
    fn evaluate_phi_eta(
        &self,
        coords: ArrayView2<'_, f64>,
        eta: f64,
    ) -> Result<PhiEtaEvaluation, String> {
        let eta_is_valid = eta.is_finite() && (0.0..=1.0).contains(&eta);
        if !eta_is_valid {
            return Err(format!(
                "SaeBasisEvaluator::evaluate_phi_eta: eta must be finite in [0, 1]; got {eta}"
            ));
        }
        let (mut phi, mut jet) = self.evaluate(coords)?;
        let split = self.phi_eta_split(phi.ncols())?;
        let mut dphi_deta = Array2::<f64>::zeros(phi.dim());
        let mut djet_deta = Array3::<f64>::zeros(jet.dim());
        let n_rows = phi.nrows();
        let n_axes = jet.shape()[2];
        // At eta == 1 the values are left untouched instead of multiplied by one.
        let rescale = eta != 1.0;
        for &col in &split.curved_cols {
            if col >= phi.ncols() {
                return Err(format!(
                    "SaeBasisEvaluator::evaluate_phi_eta: curved column {col} exceeds basis width {}",
                    phi.ncols()
                ));
            }
            for row in 0..n_rows {
                dphi_deta[[row, col]] = phi[[row, col]];
                if rescale {
                    phi[[row, col]] *= eta;
                }
                for axis in 0..n_axes {
                    djet_deta[[row, col, axis]] = jet[[row, col, axis]];
                    if rescale {
                        jet[[row, col, axis]] *= eta;
                    }
                }
            }
        }
        Ok(PhiEtaEvaluation {
            phi,
            jet,
            dphi_deta,
            djet_deta,
            split,
        })
    }

    /// Second derivatives when the evaluator provides them, `None` otherwise.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>>;

    /// Third derivatives when the evaluator provides them, `None` otherwise.
    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>>;
}
/// Partition of basis column indices into "linear" and "curved" groups.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PhiEtaSplit {
    /// Indices of the linear columns, ascending.
    pub linear_cols: Vec<usize>,
    /// Indices of the curved columns, ascending.
    pub curved_cols: Vec<usize>,
}

impl PhiEtaSplit {
    /// Split in which all `n_basis` columns are linear.
    pub fn all_linear(n_basis: usize) -> Self {
        let linear_cols: Vec<usize> = (0..n_basis).collect();
        Self {
            linear_cols,
            curved_cols: Vec::new(),
        }
    }

    /// Builds the split from a per-column mask where `true` marks a curved column.
    fn from_curved_mask(mask: Vec<bool>) -> Self {
        let (curved_cols, linear_cols): (Vec<usize>, Vec<usize>) =
            (0..mask.len()).partition(|&col| mask[col]);
        Self {
            linear_cols,
            curved_cols,
        }
    }
}
/// Result of `SaeBasisEvaluator::evaluate_phi_eta`: the basis evaluated with its
/// curved columns scaled by `eta`, together with the derivative in `eta`.
#[derive(Debug, Clone)]
pub struct PhiEtaEvaluation {
/// Basis values; in the default implementation curved columns are multiplied by `eta`.
pub phi: Array2<f64>,
/// First derivatives of the basis, scaled the same way as `phi`.
pub jet: Array3<f64>,
/// Derivative of `phi` with respect to `eta`; in the default implementation this is
/// the unscaled value on curved columns and zero on linear columns.
pub dphi_deta: Array2<f64>,
/// Derivative of `jet` with respect to `eta`, laid out like `jet`.
pub djet_deta: Array3<f64>,
/// The linear/curved column split that was applied.
pub split: PhiEtaSplit,
}
/// For each exponent vector returned by `crate::basis::monomial_exponents`,
/// reports whether its total degree is at most one.
fn monomial_linear_mask(dimension: usize, max_total_degree: usize) -> Vec<bool> {
    let exponents = crate::basis::monomial_exponents(dimension, max_total_degree);
    let mut mask = Vec::with_capacity(exponents.len());
    for alpha in exponents.iter() {
        let total_degree = alpha.iter().sum::<usize>();
        mask.push(total_degree <= 1);
    }
    mask
}
/// Lowers a requested Duchon nullspace order until the number of centers exceeds
/// the number of polynomial columns that order needs, or the order reaches `Zero`.
///
/// Each step lowers the order by one level: `Degree(k)` to `Degree(k - 1)`,
/// `Degree(2)` to `Linear`, and `Linear` to `Zero`. `Degree(0)` has no lower
/// degree, so it falls back to `Zero` instead of underflowing `k - 1`.
fn duchon_effective_order_for_eta(
    centers: ArrayView2<'_, f64>,
    order: crate::basis::DuchonNullspaceOrder,
) -> crate::basis::DuchonNullspaceOrder {
    let mut effective = order;
    while effective != crate::basis::DuchonNullspaceOrder::Zero
        && centers.nrows() <= duchon_polynomial_column_count(centers.ncols(), effective)
    {
        effective = match effective {
            crate::basis::DuchonNullspaceOrder::Zero
            | crate::basis::DuchonNullspaceOrder::Linear => {
                crate::basis::DuchonNullspaceOrder::Zero
            }
            crate::basis::DuchonNullspaceOrder::Degree(2) => {
                crate::basis::DuchonNullspaceOrder::Linear
            }
            // `checked_sub` keeps `Degree(0)` from wrapping (release) or panicking (debug).
            crate::basis::DuchonNullspaceOrder::Degree(k) => k.checked_sub(1).map_or(
                crate::basis::DuchonNullspaceOrder::Zero,
                crate::basis::DuchonNullspaceOrder::Degree,
            ),
        };
    }
    effective
}
/// Number of polynomial columns for a Duchon nullspace order in `dimension`
/// dimensions: one for `Zero`, `dimension + 1` for `Linear`, and the number of
/// monomial exponent vectors up to `degree` for `Degree(degree)`.
fn duchon_polynomial_column_count(
    dimension: usize,
    order: crate::basis::DuchonNullspaceOrder,
) -> usize {
    match order {
        crate::basis::DuchonNullspaceOrder::Degree(degree) => {
            let exponents = crate::basis::monomial_exponents(dimension, degree);
            exponents.len()
        }
        crate::basis::DuchonNullspaceOrder::Linear => dimension + 1,
        crate::basis::DuchonNullspaceOrder::Zero => 1,
    }
}
/// A basis evaluator that can also produce second derivatives.
pub trait SaeBasisSecondJet: SaeBasisEvaluator {
/// Second derivatives of every basis function at `coords`. The implementations in
/// this file index the result as `[row, basis column, latent axis, latent axis]`.
fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String>;
}
/// A basis evaluator that can also produce third derivatives.
pub trait SaeBasisThirdJet: SaeBasisSecondJet {
/// Third derivatives of every basis function at `coords`. The implementations in
/// this file index the result as `[row, basis column, axis, axis, axis]`.
fn third_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array5<f64>, String>;
}
/// Fourier basis on a one-dimensional periodic coordinate: a constant column
/// followed by sine/cosine pairs.
#[derive(Debug, Clone)]
pub struct PeriodicHarmonicEvaluator {
    /// Total number of basis columns (`1 + 2 * harmonics`), hence odd.
    pub num_basis: usize,
}

impl PeriodicHarmonicEvaluator {
    /// Creates an evaluator with `num_basis` columns.
    ///
    /// # Errors
    /// Fails when `num_basis` is zero or even.
    pub fn new(num_basis: usize) -> Result<Self, String> {
        // Zero is even, so a single parity test covers both rejected cases.
        let is_odd = num_basis % 2 == 1;
        if is_odd {
            Ok(Self { num_basis })
        } else {
            Err(format!(
                "PeriodicHarmonicEvaluator requires odd num_basis >= 1; got {num_basis}"
            ))
        }
    }
}
impl SaeBasisEvaluator for PeriodicHarmonicEvaluator {
    /// Evaluates the basis `[1, sin(2*pi*t), cos(2*pi*t), sin(4*pi*t), ...]` and its
    /// first derivative at each row of the single-column `coords`.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        let d = coords.ncols();
        if d != 1 {
            return Err(format!(
                "PeriodicHarmonicEvaluator: expected latent_dim == 1, got {d}"
            ));
        }
        let n_rows = coords.nrows();
        let width = self.num_basis;
        let harmonics = (width - 1) / 2;
        let mut phi = Array2::<f64>::zeros((n_rows, width));
        let mut jet = Array3::<f64>::zeros((n_rows, width, 1));
        for row in 0..n_rows {
            let t = coords[[row, 0]];
            phi[[row, 0]] = 1.0;
            for harmonic in 1..=harmonics {
                let freq = std::f64::consts::TAU * (harmonic as f64);
                let (sin, cos) = (freq * t).sin_cos();
                // Harmonic h occupies columns 2h - 1 (sine) and 2h (cosine).
                let (sin_col, cos_col) = (2 * harmonic - 1, 2 * harmonic);
                phi[[row, sin_col]] = sin;
                phi[[row, cos_col]] = cos;
                jet[[row, sin_col, 0]] = freq * cos;
                jet[[row, cos_col, 0]] = -freq * sin;
            }
        }
        Ok((phi, jet))
    }

    /// Marks the columns of harmonics two and above as curved; the constant and
    /// the first harmonic pair stay linear.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        if n_basis != self.num_basis {
            return Err(format!(
                "PeriodicHarmonicEvaluator::phi_eta_split: n_basis {n_basis} != evaluator width {}",
                self.num_basis
            ));
        }
        let highest_harmonic = (n_basis - 1) / 2;
        let mut curved_mask = vec![false; n_basis];
        for harmonic in 2..=highest_harmonic {
            for col in [2 * harmonic - 1, 2 * harmonic] {
                curved_mask[col] = true;
            }
        }
        Ok(PhiEtaSplit::from_curved_mask(curved_mask))
    }

    /// Always available: forwards to the `SaeBasisSecondJet` implementation.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        let hessians = <Self as SaeBasisSecondJet>::second_jet(self, coords);
        Some(hessians)
    }

    /// Always available: forwards to the `SaeBasisThirdJet` implementation.
    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        let thirds = <Self as SaeBasisThirdJet>::third_jet(self, coords);
        Some(thirds)
    }
}
impl SaeBasisSecondJet for PeriodicHarmonicEvaluator {
    /// Second derivative of each harmonic column: `-freq^2` times the column's own
    /// value, with `freq = 2*pi*k`. The constant column stays zero.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        let d = coords.ncols();
        if d != 1 {
            return Err(format!(
                "PeriodicHarmonicEvaluator::second_jet: expected latent_dim == 1, got {d}"
            ));
        }
        let n_rows = coords.nrows();
        let width = self.num_basis;
        let harmonics = (width - 1) / 2;
        let mut hessians = Array4::<f64>::zeros((n_rows, width, 1, 1));
        for row in 0..n_rows {
            let t = coords[[row, 0]];
            for harmonic in 1..=harmonics {
                let freq = std::f64::consts::TAU * (harmonic as f64);
                let freq_sq = freq * freq;
                let (sin, cos) = (freq * t).sin_cos();
                hessians[[row, 2 * harmonic - 1, 0, 0]] = -freq_sq * sin;
                hessians[[row, 2 * harmonic, 0, 0]] = -freq_sq * cos;
            }
        }
        Ok(hessians)
    }
}
impl SaeBasisThirdJet for PeriodicHarmonicEvaluator {
    /// Third derivative of each harmonic column with `freq = 2*pi*k`:
    /// `-freq^3 * cos` for the sine column and `freq^3 * sin` for the cosine column.
    /// The constant column stays zero.
    fn third_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array5<f64>, String> {
        let d = coords.ncols();
        if d != 1 {
            return Err(format!(
                "PeriodicHarmonicEvaluator::third_jet: expected latent_dim == 1, got {d}"
            ));
        }
        let n_rows = coords.nrows();
        let width = self.num_basis;
        let harmonics = (width - 1) / 2;
        let mut thirds = Array5::<f64>::zeros((n_rows, width, 1, 1, 1));
        for row in 0..n_rows {
            let t = coords[[row, 0]];
            for harmonic in 1..=harmonics {
                let freq = std::f64::consts::TAU * (harmonic as f64);
                let freq_cubed = freq * freq * freq;
                let (sin, cos) = (freq * t).sin_cos();
                thirds[[row, 2 * harmonic - 1, 0, 0, 0]] = -freq_cubed * cos;
                thirds[[row, 2 * harmonic, 0, 0, 0]] = freq_cubed * sin;
            }
        }
        Ok(thirds)
    }
}
/// Two-column `(cos t, sin t)` basis evaluated on the first latent coordinate.
#[derive(Debug, Clone)]
pub struct RawPeriodicCircleEvaluator {
    /// Number of latent coordinates each input row must have; at least one.
    pub latent_dim: usize,
}

impl RawPeriodicCircleEvaluator {
    /// Creates an evaluator for `latent_dim` latent coordinates.
    ///
    /// # Errors
    /// Fails when `latent_dim` is zero.
    pub fn new(latent_dim: usize) -> Result<Self, String> {
        match latent_dim {
            0 => Err("RawPeriodicCircleEvaluator requires latent_dim >= 1".to_string()),
            _ => Ok(Self { latent_dim }),
        }
    }
}
impl SaeBasisEvaluator for RawPeriodicCircleEvaluator {
    /// Returns `phi = (cos t, sin t)` and its derivative along the first latent
    /// axis, where `t` is the first coordinate of each row. Jet entries for any
    /// further axes stay zero.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        let actual_dim = coords.ncols();
        if actual_dim != self.latent_dim {
            return Err(format!(
                "RawPeriodicCircleEvaluator: expected latent_dim {}, got {}",
                self.latent_dim,
                coords.ncols()
            ));
        }
        let n_rows = coords.nrows();
        let mut phi = Array2::<f64>::zeros((n_rows, 2));
        let mut jet = Array3::<f64>::zeros((n_rows, 2, self.latent_dim));
        for row in 0..n_rows {
            let (sin, cos) = coords[[row, 0]].sin_cos();
            phi[[row, 0]] = cos;
            phi[[row, 1]] = sin;
            jet[[row, 0, 0]] = -sin;
            jet[[row, 1, 0]] = cos;
        }
        Ok((phi, jet))
    }

    /// Both columns are linear; any width other than two is an error.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        match n_basis {
            2 => Ok(PhiEtaSplit::all_linear(n_basis)),
            _ => Err(format!(
                "RawPeriodicCircleEvaluator::phi_eta_split: n_basis {n_basis} != 2"
            )),
        }
    }

    /// No second derivatives are provided: `None` for matching input width, and
    /// `Some(Err(..))` when the coordinate width is wrong.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        if coords.ncols() == self.latent_dim {
            return None;
        }
        Some(Err(format!(
            "RawPeriodicCircleEvaluator::second_jet_dyn: expected latent_dim {}, got {}",
            self.latent_dim,
            coords.ncols()
        )))
    }

    /// No third derivatives are provided: `None` for matching input width, and
    /// `Some(Err(..))` when the coordinate width is wrong.
    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        if coords.ncols() == self.latent_dim {
            return None;
        }
        Some(Err(format!(
            "RawPeriodicCircleEvaluator::third_jet_dyn: expected latent_dim {}, got {}",
            self.latent_dim,
            coords.ncols()
        )))
    }
}
/// Diagonal penalty weights with seven entries, matching the seven columns produced by
/// `sphere_chart_basis_jet` (constant; x, y, z; xy, yz, xz).
/// NOTE(review): the per-column pairing and the meaning of the values 1e-8 / 1 / 4 are
/// inferred from the name and length only — confirm where the constant is consumed.
pub const SPHERE_CHART_PENALTY_DIAGONAL: [f64; 7] = [1e-8, 1.0, 1.0, 1.0, 4.0, 4.0, 4.0];
/// Evaluates the seven-column sphere chart basis and its gradient.
///
/// Each row of `coords` is `(latitude, longitude)`. With
/// `x = cos(lat) cos(lon)`, `y = cos(lat) sin(lon)` and `z = sin(lat)`, the
/// columns are `[1, x, y, z, x*y, y*z, x*z]`. The jet is indexed
/// `[row, column, axis]` with axis 0 = latitude and axis 1 = longitude.
///
/// # Errors
/// Fails when `coords` does not have exactly two columns.
pub fn sphere_chart_basis_jet(
    coords: ArrayView2<'_, f64>,
) -> Result<(Array2<f64>, Array3<f64>), String> {
    if coords.ncols() != 2 {
        return Err(format!(
            "sphere_chart_basis_jet expects latent_dim == 2, got {}",
            coords.ncols()
        ));
    }
    let n_rows = coords.nrows();
    let mut phi = Array2::<f64>::zeros((n_rows, 7));
    let mut jet = Array3::<f64>::zeros((n_rows, 7, 2));
    for row in 0..n_rows {
        let (slat, clat) = coords[[row, 0]].sin_cos();
        let (slon, clon) = coords[[row, 1]].sin_cos();
        let (x, y, z) = (clat * clon, clat * slon, slat);
        // Gradients as [d/dlat, d/dlon]; z does not depend on longitude.
        let grad_x = [-slat * clon, -clat * slon];
        let grad_y = [-slat * slon, clat * clon];
        let dz_dlat = clat;
        let values = [1.0, x, y, z, x * y, y * z, x * z];
        let gradients = [
            [0.0, 0.0],
            grad_x,
            grad_y,
            [dz_dlat, 0.0],
            [
                grad_x[0] * y + x * grad_y[0],
                grad_x[1] * y + x * grad_y[1],
            ],
            [grad_y[0] * z + y * dz_dlat, grad_y[1] * z],
            [grad_x[0] * z + x * dz_dlat, grad_x[1] * z],
        ];
        for col in 0..7 {
            phi[[row, col]] = values[col];
            jet[[row, col, 0]] = gradients[col][0];
            jet[[row, col, 1]] = gradients[col][1];
        }
    }
    Ok((phi, jet))
}
/// Evaluator for the seven-column sphere chart basis of `sphere_chart_basis_jet`.
#[derive(Debug, Clone)]
pub struct SphereChartEvaluator;

impl SaeBasisEvaluator for SphereChartEvaluator {
    /// Values and gradients come straight from `sphere_chart_basis_jet`.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        sphere_chart_basis_jet(coords)
    }

    /// Columns 0..=3 (constant, x, y, z) are linear; columns 4..=6 (the pairwise
    /// products) are curved. Any width other than seven is an error.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        if n_basis != 7 {
            return Err(format!(
                "SphereChartEvaluator::phi_eta_split: n_basis {n_basis} != 7"
            ));
        }
        let curved_mask: Vec<bool> = (0..n_basis).map(|col| (4..7).contains(&col)).collect();
        Ok(PhiEtaSplit::from_curved_mask(curved_mask))
    }

    /// Always available: forwards to the `SaeBasisSecondJet` implementation.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        let hessians = <Self as SaeBasisSecondJet>::second_jet(self, coords);
        Some(hessians)
    }

    /// Always available: forwards to the `SaeBasisThirdJet` implementation.
    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        let thirds = <Self as SaeBasisThirdJet>::third_jet(self, coords);
        Some(thirds)
    }
}
impl SaeBasisSecondJet for SphereChartEvaluator {
    /// Hessians of the seven sphere chart columns with respect to
    /// `(latitude, longitude)`, indexed `[row, column, axis, axis]`.
    ///
    /// The coordinate columns x, y, z use closed-form Hessians; the product
    /// columns use the product rule `H(fg) = H(f) g + df dg^T + dg df^T + f H(g)`.
    /// The constant column stays zero.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        type Grad = [f64; 2];
        type Hess = [[f64; 2]; 2];

        /// Hessian of the product `f * g` from the values, gradients and Hessians
        /// of both factors.
        fn product_hessian(hf: Hess, df: Grad, f: f64, hg: Hess, dg: Grad, g: f64) -> Hess {
            let mut out = [[0.0; 2]; 2];
            for (a, out_row) in out.iter_mut().enumerate() {
                for (b, entry) in out_row.iter_mut().enumerate() {
                    *entry = hf[a][b] * g + df[a] * dg[b] + df[b] * dg[a] + f * hg[a][b];
                }
            }
            out
        }

        if coords.ncols() != 2 {
            return Err(format!(
                "SphereChartEvaluator::second_jet expects latent_dim == 2, got {}",
                coords.ncols()
            ));
        }
        let n_rows = coords.nrows();
        let mut hessians = Array4::<f64>::zeros((n_rows, 7, 2, 2));
        for row in 0..n_rows {
            let (slat, clat) = coords[[row, 0]].sin_cos();
            let (slon, clon) = coords[[row, 1]].sin_cos();
            let (x, y, z) = (clat * clon, clat * slon, slat);
            let dx: Grad = [-slat * clon, -clat * slon];
            let dy: Grad = [-slat * slon, clat * clon];
            let dz: Grad = [clat, 0.0];
            let hx: Hess = [[-x, slat * slon], [slat * slon, -x]];
            let hy: Hess = [[-y, -slat * clon], [-slat * clon, -y]];
            let hz: Hess = [[-z, 0.0], [0.0, 0.0]];
            // Columns 1..=6 in basis order: x, y, z, x*y, y*z, x*z.
            let per_column: [Hess; 6] = [
                hx,
                hy,
                hz,
                product_hessian(hx, dx, x, hy, dy, y),
                product_hessian(hy, dy, y, hz, dz, z),
                product_hessian(hx, dx, x, hz, dz, z),
            ];
            for (offset, hessian) in per_column.iter().enumerate() {
                let col = offset + 1;
                for axis_a in 0..2 {
                    for axis_b in 0..2 {
                        hessians[[row, col, axis_a, axis_b]] = hessian[axis_a][axis_b];
                    }
                }
            }
        }
        Ok(hessians)
    }
}
impl SaeBasisThirdJet for SphereChartEvaluator {
/// Third derivatives of the seven sphere chart columns with respect to
/// `(latitude, longitude)`, indexed `[row, column, axis, axis, axis]`.
///
/// Every coordinate function is written as a product of a latitude factor and a
/// longitude factor, each represented by the array of its derivatives of order 0..=3.
/// The constant column (index 0) stays zero.
///
/// Returns an error when `coords` does not have exactly two columns.
fn third_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array5<f64>, String> {
if coords.ncols() != 2 {
return Err(format!(
"SphereChartEvaluator::third_jet expects latent_dim == 2, got {}",
coords.ncols()
));
}
let n = coords.nrows();
let mut t3 = Array5::<f64>::zeros((n, 7, 2, 2, 2));
// Third derivative of a separable function lat(.) * lon(.): count how many of the three
// axes are latitude (axis 0) and take that derivative order of the latitude factor and
// the complementary order of the longitude factor.
let single = |lat: &[f64; 4], lon: &[f64; 4], ax: [usize; 3]| -> f64 {
let n_lat = ax.iter().filter(|&&q| q == 0).count();
lat[n_lat] * lon[3 - n_lat]
};
// Third derivative of a product f * g of two separable functions (Leibniz rule): each
// of the 8 masks assigns every one of the three derivative axes to either f or g.
let product = |f_lat: &[f64; 4],
f_lon: &[f64; 4],
g_lat: &[f64; 4],
g_lon: &[f64; 4],
ax: [usize; 3]|
-> f64 {
let mut acc = 0.0;
for mask in 0u8..8 {
let (mut f_lat_n, mut f_lon_n, mut g_lat_n, mut g_lon_n) = (0, 0, 0, 0);
for (i, &axis) in ax.iter().enumerate() {
// Bit i of the mask decides whether derivative i acts on f (set) or g (clear).
let to_f = (mask >> i) & 1 == 1;
match (to_f, axis == 0) {
(true, true) => f_lat_n += 1,
(true, false) => f_lon_n += 1,
(false, true) => g_lat_n += 1,
(false, false) => g_lon_n += 1,
}
}
acc += f_lat[f_lat_n] * f_lon[f_lon_n] * g_lat[g_lat_n] * g_lon[g_lon_n];
}
acc
};
for row in 0..n {
let lat = coords[[row, 0]];
let lon = coords[[row, 1]];
let clat = lat.cos();
let slat = lat.sin();
let clon = lon.cos();
let slon = lon.sin();
// Derivatives of order 0..=3 of cos and sin in each angle.
let cos_lat = [clat, -slat, -clat, slat];
let sin_lat = [slat, clat, -slat, -clat];
let cos_lon = [clon, -slon, -clon, slon];
let sin_lon = [slon, clon, -slon, -clon];
// z does not depend on longitude: its longitude factor is the constant 1.
let const_lon = [1.0, 0.0, 0.0, 0.0];
// x = cos(lat) cos(lon), y = cos(lat) sin(lon), z = sin(lat).
let (x_lat, x_lon) = (&cos_lat, &cos_lon);
let (y_lat, y_lon) = (&cos_lat, &sin_lon);
let (z_lat, z_lon) = (&sin_lat, &const_lon);
for axis_a in 0..2 {
for axis_b in 0..2 {
for axis_c in 0..2 {
let ax = [axis_a, axis_b, axis_c];
// Columns 1..=3 are x, y, z; columns 4..=6 are x*y, y*z, x*z.
t3[[row, 1, axis_a, axis_b, axis_c]] = single(x_lat, x_lon, ax);
t3[[row, 2, axis_a, axis_b, axis_c]] = single(y_lat, y_lon, ax);
t3[[row, 3, axis_a, axis_b, axis_c]] = single(z_lat, z_lon, ax);
t3[[row, 4, axis_a, axis_b, axis_c]] =
product(x_lat, x_lon, y_lat, y_lon, ax);
t3[[row, 5, axis_a, axis_b, axis_c]] =
product(y_lat, y_lon, z_lat, z_lon, ax);
t3[[row, 6, axis_a, axis_b, axis_c]] =
product(x_lat, x_lon, z_lat, z_lon, ax);
}
}
}
}
Ok(t3)
}
}
/// Tensor-product harmonic basis: every latent axis contributes
/// `2 * num_harmonics + 1` one-dimensional columns, and the full basis is the
/// product over all `latent_dim` axes.
#[derive(Debug, Clone)]
pub struct TorusHarmonicEvaluator {
    /// Number of latent axes.
    pub latent_dim: usize,
    /// Highest harmonic index used on each axis.
    pub num_harmonics: usize,
}
impl TorusHarmonicEvaluator {
    /// Builds an evaluator after validating its configuration.
    ///
    /// # Errors
    /// Returns an error when `latent_dim` or `num_harmonics` is zero, or when the
    /// resulting basis width `(2 * num_harmonics + 1)^latent_dim` does not fit in
    /// `usize`. Rejecting the overflow here means [`Self::basis_size`] cannot
    /// panic for evaluators obtained through this constructor.
    pub fn new(latent_dim: usize, num_harmonics: usize) -> Result<Self, String> {
        if latent_dim == 0 {
            return Err("TorusHarmonicEvaluator requires latent_dim >= 1".to_string());
        }
        if num_harmonics == 0 {
            return Err("TorusHarmonicEvaluator requires num_harmonics >= 1".to_string());
        }
        if Self::checked_basis_size(latent_dim, num_harmonics).is_none() {
            return Err(format!(
                "TorusHarmonicEvaluator: basis size (2 * {num_harmonics} + 1)^{latent_dim} overflows usize"
            ));
        }
        Ok(Self {
            latent_dim,
            num_harmonics,
        })
    }
    /// Number of one-dimensional basis columns per latent axis.
    pub fn axis_basis_size(&self) -> usize {
        2 * self.num_harmonics + 1
    }
    /// Total number of tensor-product basis columns.
    ///
    /// # Panics
    /// Panics when the width overflows `usize`. This is only reachable for values
    /// built by writing the public fields directly, because [`Self::new`] rejects
    /// such configurations.
    pub fn basis_size(&self) -> usize {
        Self::checked_basis_size(self.latent_dim, self.num_harmonics)
            .expect("TorusHarmonicEvaluator: basis size overflowed usize")
    }
    /// Overflow-checked `(2 * num_harmonics + 1)^latent_dim`.
    fn checked_basis_size(latent_dim: usize, num_harmonics: usize) -> Option<usize> {
        let axis_m = num_harmonics.checked_mul(2)?.checked_add(1)?;
        (0..latent_dim).try_fold(1usize, |total, _| total.checked_mul(axis_m))
    }
}
impl SaeBasisEvaluator for TorusHarmonicEvaluator {
    /// Classifies every tensor-product column as curved or not.
    ///
    /// A column is curved when any axis uses a per-axis column index above 2, or
    /// when more than one axis uses a non-constant (index > 0) column. Columns are
    /// enumerated with the last latent axis varying fastest.
    ///
    /// # Errors
    /// Returns an error when `n_basis` differs from the evaluator's basis width.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        let expected = self.basis_size();
        if n_basis != expected {
            return Err(format!(
                "TorusHarmonicEvaluator::phi_eta_split: n_basis {n_basis} != evaluator width {expected}"
            ));
        }
        let per_axis = self.axis_basis_size();
        let mut multi_index = vec![0usize; self.latent_dim];
        let mut curved_mask = Vec::with_capacity(n_basis);
        for _ in 0..n_basis {
            let active_axes = multi_index.iter().filter(|&&col| col > 0).count();
            let beyond_first = multi_index.iter().any(|&col| col > 2);
            curved_mask.push(beyond_first || active_axes > 1);
            // Odometer step: bump the last axis, carrying into earlier ones.
            for digit in multi_index.iter_mut().rev() {
                *digit += 1;
                if *digit < per_axis {
                    break;
                }
                *digit = 0;
            }
        }
        Ok(PhiEtaSplit::from_curved_mask(curved_mask))
    }
    /// Dynamic entry point forwarding to the `SaeBasisSecondJet` implementation.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        let jet = <Self as SaeBasisSecondJet>::second_jet(self, coords);
        Some(jet)
    }
    /// Dynamic entry point forwarding to the `SaeBasisThirdJet` implementation.
    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        let jet = <Self as SaeBasisThirdJet>::third_jet(self, coords);
        Some(jet)
    }
    /// Evaluates the basis and its first latent derivatives.
    ///
    /// Per axis the columns are `[1, sin(2πt), cos(2πt), sin(4πt), cos(4πt), …]`;
    /// the returned `phi` is `(n, m)` and the jet is `(n, m, latent_dim)`, where
    /// each column is the product of one per-axis column per latent axis.
    ///
    /// # Errors
    /// Returns an error when `coords` does not have `latent_dim` columns.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        let d = self.latent_dim;
        if coords.ncols() != d {
            return Err(format!(
                "TorusHarmonicEvaluator: expected latent_dim {d}, got {}",
                coords.ncols()
            ));
        }
        let n_rows = coords.nrows();
        let per_axis = self.axis_basis_size();
        let width = self.basis_size();
        let tau = 2.0 * std::f64::consts::PI;
        let mut phi = Array2::<f64>::zeros((n_rows, width));
        let mut jet = Array3::<f64>::zeros((n_rows, width, d));
        // Per-axis value and first-derivative tables, refilled for every row.
        let mut value_tab = vec![vec![0.0_f64; per_axis]; d];
        let mut slope_tab = vec![vec![0.0_f64; per_axis]; d];
        let mut multi_index = vec![0usize; d];
        for row in 0..n_rows {
            for (axis, (values, slopes)) in
                value_tab.iter_mut().zip(slope_tab.iter_mut()).enumerate()
            {
                let t = coords[[row, axis]];
                values[0] = 1.0;
                slopes[0] = 0.0;
                for harmonic in 1..=self.num_harmonics {
                    let freq = tau * (harmonic as f64);
                    let angle = freq * t;
                    let sin_v = angle.sin();
                    let cos_v = angle.cos();
                    let (sin_col, cos_col) = (2 * harmonic - 1, 2 * harmonic);
                    values[sin_col] = sin_v;
                    values[cos_col] = cos_v;
                    slopes[sin_col] = freq * cos_v;
                    slopes[cos_col] = -freq * sin_v;
                }
            }
            multi_index.fill(0);
            for flat in 0..width {
                phi[[row, flat]] = (0..d)
                    .fold(1.0_f64, |acc, axis| acc * value_tab[axis][multi_index[axis]]);
                for axis_target in 0..d {
                    // Differentiate exactly one factor; leave the rest as values.
                    jet[[row, flat, axis_target]] = (0..d).fold(1.0_f64, |acc, axis| {
                        let table = if axis == axis_target {
                            &slope_tab
                        } else {
                            &value_tab
                        };
                        acc * table[axis][multi_index[axis]]
                    });
                }
                for digit in multi_index.iter_mut().rev() {
                    *digit += 1;
                    if *digit < per_axis {
                        break;
                    }
                    *digit = 0;
                }
            }
        }
        Ok((phi, jet))
    }
}
impl SaeBasisSecondJet for TorusHarmonicEvaluator {
    /// Second latent derivatives of the tensor-product harmonic basis.
    ///
    /// The result has shape `(n, m, d, d)`. For each column the entry `(a, b)` is
    /// the product over axes of the per-axis factor differentiated once for every
    /// occurrence of that axis in `(a, b)`.
    ///
    /// # Errors
    /// Returns an error when `coords` does not have `latent_dim` columns.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        let d = self.latent_dim;
        if coords.ncols() != d {
            return Err(format!(
                "TorusHarmonicEvaluator::second_jet expects latent_dim == {d}, got {}",
                coords.ncols()
            ));
        }
        let n_rows = coords.nrows();
        let per_axis = self.axis_basis_size();
        let width = self.basis_size();
        let tau = 2.0 * std::f64::consts::PI;
        let mut hess = Array4::<f64>::zeros((n_rows, width, d, d));
        // `by_order[k][axis][col]` is the k-th derivative of per-axis column `col`.
        let mut by_order: [Vec<Vec<f64>>; 3] =
            std::array::from_fn(|_| vec![vec![0.0_f64; per_axis]; d]);
        let mut multi_index = vec![0usize; d];
        for row in 0..n_rows {
            for axis in 0..d {
                let t = coords[[row, axis]];
                by_order[0][axis][0] = 1.0;
                by_order[1][axis][0] = 0.0;
                by_order[2][axis][0] = 0.0;
                for harmonic in 1..=self.num_harmonics {
                    let freq = tau * (harmonic as f64);
                    let freq_sq = freq * freq;
                    let angle = freq * t;
                    let sin_v = angle.sin();
                    let cos_v = angle.cos();
                    let (sin_col, cos_col) = (2 * harmonic - 1, 2 * harmonic);
                    by_order[0][axis][sin_col] = sin_v;
                    by_order[0][axis][cos_col] = cos_v;
                    by_order[1][axis][sin_col] = freq * cos_v;
                    by_order[1][axis][cos_col] = -freq * sin_v;
                    by_order[2][axis][sin_col] = -freq_sq * sin_v;
                    by_order[2][axis][cos_col] = -freq_sq * cos_v;
                }
            }
            multi_index.fill(0);
            for flat in 0..width {
                for axis_a in 0..d {
                    for axis_b in 0..d {
                        hess[[row, flat, axis_a, axis_b]] = (0..d).fold(1.0_f64, |acc, axis| {
                            let order =
                                usize::from(axis == axis_a) + usize::from(axis == axis_b);
                            acc * by_order[order][axis][multi_index[axis]]
                        });
                    }
                }
                // Odometer step with the last axis varying fastest.
                for digit in multi_index.iter_mut().rev() {
                    *digit += 1;
                    if *digit < per_axis {
                        break;
                    }
                    *digit = 0;
                }
            }
        }
        Ok(hess)
    }
}
impl SaeBasisThirdJet for TorusHarmonicEvaluator {
/// Third latent derivatives of the tensor-product harmonic basis.
///
/// The result has shape `(n, m, d, d, d)`. Entry `(a, b, c)` of a column is the
/// product over axes of the per-axis factor differentiated as many times as that
/// axis appears among `a`, `b`, `c`.
///
/// # Errors
/// Returns an error when `coords` does not have `latent_dim` columns.
fn third_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array5<f64>, String> {
let d = self.latent_dim;
if coords.ncols() != d {
return Err(format!(
"TorusHarmonicEvaluator::third_jet expects latent_dim == {d}, got {}",
coords.ncols()
));
}
let n = coords.nrows();
let axis_m = self.axis_basis_size();
let m = self.basis_size();
let h_max = self.num_harmonics;
let two_pi = 2.0 * std::f64::consts::PI;
let mut t3 = Array5::<f64>::zeros((n, m, d, d, d));
// deriv_axis[axis][order][col]: derivative of order 0..=3 of per-axis column `col`.
let mut deriv_axis = vec![vec![vec![0.0_f64; axis_m]; 4]; d];
for row in 0..n {
for axis in 0..d {
let t = coords[[row, axis]];
// Constant column: value 1, every derivative 0.
for order in 0..4 {
deriv_axis[axis][order][0] = 0.0;
}
deriv_axis[axis][0][0] = 1.0;
for k in 1..=h_max {
let freq = two_pi * (k as f64);
let freq2 = freq * freq;
let freq3 = freq2 * freq;
let angle = freq * t;
let s = angle.sin();
let c = angle.cos();
// Harmonic k occupies the sine column 2k-1 and the cosine column 2k.
let s_idx = 2 * k - 1;
let c_idx = 2 * k;
deriv_axis[axis][0][s_idx] = s;
deriv_axis[axis][0][c_idx] = c;
deriv_axis[axis][1][s_idx] = freq * c;
deriv_axis[axis][1][c_idx] = -freq * s;
deriv_axis[axis][2][s_idx] = -freq2 * s;
deriv_axis[axis][2][c_idx] = -freq2 * c;
deriv_axis[axis][3][s_idx] = -freq3 * c;
deriv_axis[axis][3][c_idx] = freq3 * s;
}
}
// Multi-index over per-axis columns, restarted at zero for every row.
let mut idx = vec![0usize; d];
for flat in 0..m {
for axis_a in 0..d {
for axis_b in 0..d {
for axis_c in 0..d {
let mut prod = 1.0_f64;
for axis in 0..d {
// Number of derivative slots that land on this axis.
let order = (axis == axis_a) as usize
+ (axis == axis_b) as usize
+ (axis == axis_c) as usize;
prod *= deriv_axis[axis][order][idx[axis]];
}
t3[[row, flat, axis_a, axis_b, axis_c]] = prod;
}
}
}
// Odometer step: the last axis varies fastest, carrying into earlier axes.
for axis in (0..d).rev() {
idx[axis] += 1;
if idx[axis] < axis_m {
break;
}
idx[axis] = 0;
}
}
}
Ok(t3)
}
}
/// Basis made of a constant column followed by one column per latent coordinate,
/// giving `latent_dim + 1` columns in total.
#[derive(Debug, Clone)]
pub struct AffineCoordinateEvaluator {
/// Number of latent axes (and of coordinate columns after the constant).
pub latent_dim: usize,
}
impl AffineCoordinateEvaluator {
/// Creates an evaluator for `latent_dim` latent axes. No validation is performed.
pub fn new(latent_dim: usize) -> Self {
Self { latent_dim }
}
}
impl SaeBasisEvaluator for AffineCoordinateEvaluator {
    /// Reports every column as linear.
    ///
    /// # Errors
    /// Returns an error when `n_basis` is not `latent_dim + 1`.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        let expected = self.latent_dim + 1;
        if n_basis == expected {
            Ok(PhiEtaSplit::all_linear(n_basis))
        } else {
            Err(format!(
                "AffineCoordinateEvaluator::phi_eta_split: n_basis {n_basis} != {expected}"
            ))
        }
    }
    /// Dynamic entry point forwarding to the `SaeBasisSecondJet` implementation.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        let jet = <Self as SaeBasisSecondJet>::second_jet(self, coords);
        Some(jet)
    }
    /// Dynamic entry point forwarding to the `SaeBasisThirdJet` implementation.
    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        let jet = <Self as SaeBasisThirdJet>::third_jet(self, coords);
        Some(jet)
    }
    /// Evaluates `[1, t_0, …, t_{d-1}]` and its constant first derivatives.
    ///
    /// # Errors
    /// Returns an error when `coords` does not have `latent_dim` columns.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        let dim = self.latent_dim;
        if coords.ncols() != dim {
            return Err(format!(
                "AffineCoordinateEvaluator: expected latent_dim {}, got {}",
                self.latent_dim,
                coords.ncols()
            ));
        }
        let n_rows = coords.nrows();
        let width = dim + 1;
        let mut phi = Array2::<f64>::zeros((n_rows, width));
        let mut jet = Array3::<f64>::zeros((n_rows, width, dim));
        for row in 0..n_rows {
            phi[[row, 0]] = 1.0;
            for axis in 0..dim {
                let col = axis + 1;
                phi[[row, col]] = coords[[row, axis]];
                // d(t_axis)/d(t_axis) = 1; every other derivative stays zero.
                jet[[row, col, axis]] = 1.0;
            }
        }
        Ok((phi, jet))
    }
}
impl SaeBasisSecondJet for AffineCoordinateEvaluator {
    /// Second derivatives of an affine basis: an all-zero `(n, d + 1, d, d)` array.
    ///
    /// # Errors
    /// Returns an error when `coords` does not have `latent_dim` columns.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        let dim = self.latent_dim;
        if coords.ncols() == dim {
            Ok(Array4::<f64>::zeros((coords.nrows(), dim + 1, dim, dim)))
        } else {
            Err(format!(
                "AffineCoordinateEvaluator::second_jet: expected latent_dim {}, got {}",
                self.latent_dim,
                coords.ncols()
            ))
        }
    }
}
impl SaeBasisThirdJet for AffineCoordinateEvaluator {
    /// Third derivatives of an affine basis: an all-zero `(n, d + 1, d, d, d)` array.
    ///
    /// # Errors
    /// Returns an error when `coords` does not have `latent_dim` columns.
    fn third_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array5<f64>, String> {
        let dim = self.latent_dim;
        if coords.ncols() == dim {
            Ok(Array5::<f64>::zeros((coords.nrows(), dim + 1, dim, dim, dim)))
        } else {
            Err(format!(
                "AffineCoordinateEvaluator::third_jet: expected latent_dim {}, got {}",
                self.latent_dim,
                coords.ncols()
            ))
        }
    }
}
/// Evaluator that delegates to the crate's Duchon SAE atom basis routines for a
/// fixed set of centers and a fixed nullspace order.
#[derive(Debug, Clone)]
pub struct DuchonCoordinateEvaluator {
/// Kernel centers, one per row; the column count is the latent dimension that
/// evaluation inputs must match.
pub centers: Array2<f64>,
/// Polynomial nullspace order passed to the Duchon basis routines.
pub order: crate::basis::DuchonNullspaceOrder,
}
impl DuchonCoordinateEvaluator {
pub fn new(centers: Array2<f64>, m: usize) -> Result<Self, String> {
if centers.ncols() == 0 {
return Err("DuchonCoordinateEvaluator: centers must have at least one column".into());
}
if m == 0 {
return Err("DuchonCoordinateEvaluator: Duchon m must be at least 1".into());
}
let order = match m {
1 => crate::basis::DuchonNullspaceOrder::Zero,
2 => crate::basis::DuchonNullspaceOrder::Linear,
other => crate::basis::DuchonNullspaceOrder::Degree(other - 1),
};
Ok(Self { centers, order })
}
}
impl SaeBasisEvaluator for DuchonCoordinateEvaluator {
    /// Returns an evaluator whose centers are mapped through
    /// `(center - shift) / scale`, when that is supported.
    ///
    /// Only one-dimensional centers with a finite shift and a finite, strictly
    /// positive scale are transformed; every other valid input yields `Ok(None)`.
    ///
    /// # Errors
    /// Returns an error when `shift` or `scale` does not have one entry per
    /// latent axis, or when `n_basis` is `usize::MAX`.
    fn affine_transformed_evaluator(
        &self,
        shift: &[f64],
        scale: &[f64],
        n_basis: usize,
    ) -> Result<Option<Arc<dyn SaeBasisSecondJet>>, String> {
        let dim = self.centers.ncols();
        if shift.len() != dim || scale.len() != dim {
            return Err(format!(
                "DuchonCoordinateEvaluator::affine_transformed_evaluator: affine vectors must have length {dim}; got shift={} scale={}",
                shift.len(),
                scale.len()
            ));
        }
        if n_basis == usize::MAX {
            return Err(
                "DuchonCoordinateEvaluator::affine_transformed_evaluator: unreachable basis width"
                    .to_string(),
            );
        }
        if dim != 1 {
            return Ok(None);
        }
        let (offset, stretch) = (shift[0], scale[0]);
        let usable = stretch.is_finite() && stretch > 0.0 && offset.is_finite();
        if !usable {
            return Ok(None);
        }
        let mut moved = self.centers.clone();
        for row in 0..moved.nrows() {
            let center = moved[[row, 0]];
            moved[[row, 0]] = (center - offset) / stretch;
        }
        Ok(Some(Arc::new(Self {
            centers: moved,
            order: self.order,
        })))
    }
    /// Marks the leading kernel columns as curved and classifies the trailing
    /// polynomial columns.
    ///
    /// The polynomial block width comes from `duchon_polynomial_column_count`
    /// for the effective order. Polynomial columns are only marked curved when
    /// the effective order is `Degree(_)` and `monomial_linear_mask` reports
    /// them as non-linear.
    ///
    /// # Errors
    /// Returns an error when `n_basis` is smaller than the polynomial block or
    /// when the monomial mask width disagrees with that block.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        let dim = self.centers.ncols();
        let effective = duchon_effective_order_for_eta(self.centers.view(), self.order);
        let n_poly = duchon_polynomial_column_count(dim, effective);
        if n_basis < n_poly {
            return Err(format!(
                "DuchonCoordinateEvaluator::phi_eta_split: n_basis {n_basis} smaller than polynomial block {n_poly}"
            ));
        }
        let n_kernel = n_basis - n_poly;
        let mut curved = vec![false; n_basis];
        curved[..n_kernel].fill(true);
        if let crate::basis::DuchonNullspaceOrder::Degree(degree) = effective {
            let linear_mask = monomial_linear_mask(dim, degree);
            if linear_mask.len() != n_poly {
                return Err(format!(
                    "DuchonCoordinateEvaluator::phi_eta_split: polynomial mask width {} != {n_poly}",
                    linear_mask.len()
                ));
            }
            for (slot, linear) in curved[n_kernel..].iter_mut().zip(linear_mask) {
                if !linear {
                    *slot = true;
                }
            }
        }
        Ok(PhiEtaSplit::from_curved_mask(curved))
    }
    /// Dynamic entry point forwarding to the `SaeBasisSecondJet` implementation.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        let jet = <Self as SaeBasisSecondJet>::second_jet(self, coords);
        Some(jet)
    }
    /// Dynamic entry point forwarding to the `SaeBasisThirdJet` implementation.
    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        let jet = <Self as SaeBasisThirdJet>::third_jet(self, coords);
        Some(jet)
    }
    /// Evaluates the Duchon atom basis and its first latent derivatives.
    ///
    /// # Errors
    /// Returns an error when `coords` does not have as many columns as the
    /// centers, or when the underlying basis routine fails.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        let centers = self.centers.view();
        if coords.ncols() != centers.ncols() {
            return Err(format!(
                "DuchonCoordinateEvaluator: expected latent_dim {}, got {}",
                self.centers.ncols(),
                coords.ncols()
            ));
        }
        match crate::basis::duchon_sae_atom_basis_with_jet(coords, centers, self.order) {
            Ok(evaluated) => Ok(evaluated),
            Err(err) => Err(err.to_string()),
        }
    }
}
impl SaeBasisSecondJet for DuchonCoordinateEvaluator {
    /// Second latent derivatives of the Duchon atom basis.
    ///
    /// # Errors
    /// Returns an error when `coords` does not have as many columns as the
    /// centers, or when the underlying basis routine fails.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        let centers = self.centers.view();
        if coords.ncols() != centers.ncols() {
            return Err(format!(
                "DuchonCoordinateEvaluator::second_jet: expected latent_dim {}, got {}",
                self.centers.ncols(),
                coords.ncols()
            ));
        }
        match crate::basis::duchon_sae_atom_second_jet(coords, centers, self.order) {
            Ok(jet) => Ok(jet),
            Err(err) => Err(err.to_string()),
        }
    }
}
impl SaeBasisThirdJet for DuchonCoordinateEvaluator {
    /// Third latent derivatives of the Duchon atom basis.
    ///
    /// # Errors
    /// Returns an error when `coords` does not have as many columns as the
    /// centers, or when the underlying basis routine fails.
    fn third_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array5<f64>, String> {
        let centers = self.centers.view();
        if coords.ncols() != centers.ncols() {
            return Err(format!(
                "DuchonCoordinateEvaluator::third_jet: expected latent_dim {}, got {}",
                self.centers.ncols(),
                coords.ncols()
            ));
        }
        match crate::basis::duchon_sae_atom_third_jet(coords, centers, self.order) {
            Ok(jet) => Ok(jet),
            Err(err) => Err(err.to_string()),
        }
    }
}
/// Monomial basis whose columns are the exponent tuples returned by
/// `crate::basis::monomial_exponents(latent_dim, max_degree)`.
#[derive(Debug, Clone)]
pub struct EuclideanPatchEvaluator {
/// Number of latent axes.
pub latent_dim: usize,
/// Degree bound passed to `monomial_exponents`.
pub max_degree: usize,
}
impl EuclideanPatchEvaluator {
pub fn new(latent_dim: usize, max_degree: usize) -> Result<Self, String> {
if latent_dim == 0 {
return Err("EuclideanPatchEvaluator: latent_dim must be positive".into());
}
Ok(Self {
latent_dim,
max_degree,
})
}
pub fn basis_size(&self) -> usize {
crate::basis::monomial_exponents(self.latent_dim, self.max_degree).len()
}
fn order(&self) -> crate::basis::DuchonNullspaceOrder {
match self.max_degree {
0 => crate::basis::DuchonNullspaceOrder::Zero,
1 => crate::basis::DuchonNullspaceOrder::Linear,
k => crate::basis::DuchonNullspaceOrder::Degree(k),
}
}
}
impl SaeBasisEvaluator for EuclideanPatchEvaluator {
    /// Returns a copy of this evaluator for a valid affine change of coordinates.
    ///
    /// `Ok(None)` is returned when any shift or scale entry is non-finite or any
    /// scale entry is not strictly positive.
    ///
    /// # Errors
    /// Returns an error when `shift`/`scale` do not have `latent_dim` entries or
    /// when `n_basis` differs from the evaluator's basis width.
    fn affine_transformed_evaluator(
        &self,
        shift: &[f64],
        scale: &[f64],
        n_basis: usize,
    ) -> Result<Option<Arc<dyn SaeBasisSecondJet>>, String> {
        if shift.len() != self.latent_dim || scale.len() != self.latent_dim {
            return Err(format!(
                "EuclideanPatchEvaluator::affine_transformed_evaluator: affine vectors must have length {}; got shift={} scale={}",
                self.latent_dim,
                shift.len(),
                scale.len()
            ));
        }
        if n_basis != self.basis_size() {
            return Err(format!(
                "EuclideanPatchEvaluator::affine_transformed_evaluator: n_basis {n_basis} != evaluator width {}",
                self.basis_size()
            ));
        }
        let shift_ok = shift.iter().all(|v| v.is_finite());
        let scale_ok = scale.iter().all(|&v| v.is_finite() && v > 0.0);
        if !(shift_ok && scale_ok) {
            return Ok(None);
        }
        let copy = Self {
            latent_dim: self.latent_dim,
            max_degree: self.max_degree,
        };
        Ok(Some(Arc::new(copy)))
    }
    /// Marks as curved every monomial that `monomial_linear_mask` reports as non-linear.
    ///
    /// # Errors
    /// Returns an error when the mask width differs from `n_basis`.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        let linear_mask = monomial_linear_mask(self.latent_dim, self.max_degree);
        if linear_mask.len() != n_basis {
            return Err(format!(
                "EuclideanPatchEvaluator::phi_eta_split: polynomial mask width {} != n_basis {n_basis}",
                linear_mask.len()
            ));
        }
        let mut curved = Vec::with_capacity(n_basis);
        for linear in linear_mask {
            curved.push(!linear);
        }
        Ok(PhiEtaSplit::from_curved_mask(curved))
    }
    /// Dynamic entry point forwarding to the `SaeBasisSecondJet` implementation.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        let jet = <Self as SaeBasisSecondJet>::second_jet(self, coords);
        Some(jet)
    }
    /// Dynamic entry point forwarding to the `SaeBasisThirdJet` implementation.
    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        let jet = <Self as SaeBasisThirdJet>::third_jet(self, coords);
        Some(jet)
    }
    /// Evaluates every monomial and obtains the first derivatives from
    /// `duchon_polynomial_first_derivative_nd`.
    ///
    /// # Errors
    /// Returns an error when `coords` does not have `latent_dim` columns or when
    /// the derivative helper returns an array of unexpected shape.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        if coords.ncols() != self.latent_dim {
            return Err(format!(
                "EuclideanPatchEvaluator: expected latent_dim {}, got {}",
                self.latent_dim,
                coords.ncols()
            ));
        }
        let exponents = crate::basis::monomial_exponents(self.latent_dim, self.max_degree);
        let n = coords.nrows();
        let m = exponents.len();
        let mut phi = Array2::<f64>::zeros((n, m));
        for row in 0..n {
            for (col, alpha) in exponents.iter().enumerate() {
                // Multiply the non-trivial powers in axis order, starting from 1.
                phi[[row, col]] = alpha
                    .iter()
                    .enumerate()
                    .filter(|(_, exp)| **exp != 0)
                    .fold(1.0_f64, |acc, (axis, &exp)| {
                        acc * coords[[row, axis]].powi(exp as i32)
                    });
            }
        }
        let jet = crate::basis::duchon_polynomial_first_derivative_nd(coords, self.order());
        if jet.shape() != [n, m, self.latent_dim] {
            return Err(format!(
                "EuclideanPatchEvaluator: monomial jet shape {:?} disagrees with ({n}, {m}, {})",
                jet.shape(),
                self.latent_dim
            ));
        }
        Ok((phi, jet))
    }
}
impl SaeBasisSecondJet for EuclideanPatchEvaluator {
    /// Second latent derivatives of every monomial, shape `(n, m, d, d)`.
    ///
    /// Entry `(a, c)` is the falling-factorial coefficient times the monomial
    /// with the exponents of axes `a` and `c` each lowered by one; entries whose
    /// coefficient vanishes are left at zero.
    ///
    /// # Errors
    /// Returns an error when `coords` does not have `latent_dim` columns.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        if coords.ncols() != self.latent_dim {
            return Err(format!(
                "EuclideanPatchEvaluator::second_jet: expected latent_dim {}, got {}",
                self.latent_dim,
                coords.ncols()
            ));
        }
        let exponents = crate::basis::monomial_exponents(self.latent_dim, self.max_degree);
        let d = self.latent_dim;
        let n_rows = coords.nrows();
        let mut hess = Array4::<f64>::zeros((n_rows, exponents.len(), d, d));
        for (col, alpha) in exponents.iter().enumerate() {
            for a in 0..d {
                for c in 0..d {
                    // A zero exponent on either differentiated axis makes the
                    // coefficient zero, so one check covers every skipped case.
                    let lead = if a == c {
                        (alpha[a] as f64) * (alpha[a].saturating_sub(1) as f64)
                    } else {
                        (alpha[a] as f64) * (alpha[c] as f64)
                    };
                    if lead == 0.0 {
                        continue;
                    }
                    let lowered = |axis: usize| {
                        let drop = usize::from(axis == a) + usize::from(axis == c);
                        alpha[axis].saturating_sub(drop)
                    };
                    for row in 0..n_rows {
                        hess[[row, col, a, c]] = (0..d).fold(lead, |acc, axis| {
                            match lowered(axis) {
                                0 => acc,
                                exp => acc * coords[[row, axis]].powi(exp as i32),
                            }
                        });
                    }
                }
            }
        }
        Ok(hess)
    }
}
impl SaeBasisThirdJet for EuclideanPatchEvaluator {
/// Third latent derivatives of every monomial, shape `(n, m, d, d, d)`.
///
/// Entry `(a, b, c)` is the product of per-axis falling factorials times the
/// monomial with each axis exponent lowered by the number of times that axis
/// appears among `a`, `b`, `c`. Entries that vanish are left at zero.
///
/// # Errors
/// Returns an error when `coords` does not have `latent_dim` columns.
fn third_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array5<f64>, String> {
if coords.ncols() != self.latent_dim {
return Err(format!(
"EuclideanPatchEvaluator::third_jet: expected latent_dim {}, got {}",
self.latent_dim,
coords.ncols()
));
}
let exponents = crate::basis::monomial_exponents(self.latent_dim, self.max_degree);
let n = coords.nrows();
let m = exponents.len();
let d = self.latent_dim;
let mut t3 = Array5::<f64>::zeros((n, m, d, d, d));
// Falling factorial alpha * (alpha - 1) * … with `k` factors (1 when k == 0).
let falling = |alpha: usize, k: usize| -> f64 {
let mut acc = 1.0_f64;
for j in 0..k {
acc *= (alpha as f64) - (j as f64);
}
acc
};
for (col, alpha) in exponents.iter().enumerate() {
for a in 0..d {
// A monomial that does not involve axis `a` has zero derivative along it.
if alpha[a] == 0 {
continue;
}
for b in 0..d {
for c in 0..d {
// order[axis]: how many of the three derivative slots hit `axis`.
let mut order = vec![0usize; d];
order[a] += 1;
order[b] += 1;
order[c] += 1;
// Differentiating more often than the exponent allows gives zero.
if (0..d).any(|axis| order[axis] > alpha[axis]) {
continue;
}
let mut lead = 1.0_f64;
for axis in 0..d {
lead *= falling(alpha[axis], order[axis]);
}
if lead == 0.0 {
continue;
}
for row in 0..n {
let mut value = lead;
for axis in 0..d {
// Cannot underflow: order[axis] <= alpha[axis] was checked above.
let exp = alpha[axis] - order[axis];
if exp != 0 {
value *= coords[[row, axis]].powi(exp as i32);
}
}
t3[[row, col, a, b, c]] = value;
}
}
}
}
}
Ok(t3)
}
}
/// Relative singular-value threshold: a decoder singular value counts toward the
/// numerical rank only when it exceeds this fraction of the largest one.
const SAE_FRAME_RANK_CUTOFF: f64 = 1.0e-7;
/// A decoder frame is only activated when the rank `r` satisfies
/// `r <= p * (1 - SAE_FRAME_ACTIVATION_MARGIN)` for output dimension `p`.
const SAE_FRAME_ACTIVATION_MARGIN: f64 = 0.25;
/// A `p x r` frame matrix together with the singular values recorded when it was built.
// NOTE(review): the columns are presumably orthonormal (the frame comes from SVD
// factors), but nothing in this type enforces it — confirm against the SVD wrapper.
#[derive(Debug, Clone)]
pub struct GrassmannFrame {
// Frame matrix: rows index the output dimension `p`, columns the rank `r`.
frame: Array2<f64>,
// Singular values supplied by whichever constructor path built the frame.
gauge_singular_values: Array1<f64>,
}
impl GrassmannFrame {
    /// Output dimension `p`: the number of rows of the frame.
    pub fn output_dim(&self) -> usize {
        self.frame.nrows()
    }
    /// Frame rank `r`: the number of columns of the frame.
    pub fn rank(&self) -> usize {
        self.frame.ncols()
    }
    /// Singular values recorded when the frame was constructed.
    pub fn gauge_singular_values(&self) -> &Array1<f64> {
        &self.gauge_singular_values
    }
    /// Read-only view of the `p x r` frame matrix.
    pub fn frame(&self) -> ArrayView2<'_, f64> {
        self.frame.view()
    }
    /// Returns `r * (p - r)` for rank `r` and output dimension `p`.
    pub fn manifold_dimension(&self) -> usize {
        let rank = self.rank();
        rank * (self.output_dim() - rank)
    }
    /// Wraps `frame` after fixing the sign of every column so that its entry of
    /// largest magnitude (the first one on ties) is non-negative.
    fn from_oriented(mut frame: Array2<f64>, gauge_singular_values: Array1<f64>) -> Self {
        for col in 0..frame.ncols() {
            let mut column = frame.column_mut(col);
            let mut pivot = 0.0_f64;
            for &entry in column.iter() {
                if entry.abs() > pivot.abs() {
                    pivot = entry;
                }
            }
            if pivot < 0.0 {
                column.mapv_inplace(|entry| -entry);
            }
        }
        Self {
            frame,
            gauge_singular_values,
        }
    }
    /// Builds a frame as the product of the left and right SVD factors of
    /// `cross_moment`, keeping the singular values as gauge information.
    ///
    /// # Errors
    /// Returns an error when `cross_moment` is empty, when it has more columns
    /// than rows, or when the SVD fails or omits a requested factor.
    pub fn polar_update(cross_moment: ArrayView2<'_, f64>) -> Result<Self, String> {
        let (p, r) = cross_moment.dim();
        if p == 0 || r == 0 {
            return Err("GrassmannFrame::polar_update: cross-moment must be non-empty".into());
        }
        if r > p {
            return Err(format!(
                "GrassmannFrame::polar_update: frame rank r={r} cannot exceed output dim p={p}"
            ));
        }
        let owned = cross_moment.to_owned();
        let (left, singular_values, right_t) = match owned.svd(true, true) {
            Ok(factors) => factors,
            Err(e) => return Err(format!("GrassmannFrame::polar_update: SVD failed: {e}")),
        };
        let Some(left) = left else {
            return Err(
                "GrassmannFrame::polar_update: thin SVD returned no left factor".to_string(),
            );
        };
        let Some(right_t) = right_t else {
            return Err(
                "GrassmannFrame::polar_update: thin SVD returned no right factor".to_string(),
            );
        };
        let polar = fast_ab(&left, &right_t);
        Ok(Self::from_oriented(polar, singular_values))
    }
    /// Maps frame coordinates back to the output space: `coords * frame^T`.
    ///
    /// # Errors
    /// Returns an error when `coords` does not have `rank()` columns.
    pub fn reconstruct_decoder(&self, coords: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
        if coords.ncols() == self.rank() {
            let owned = coords.to_owned();
            Ok(fast_abt(&owned, &self.frame))
        } else {
            Err(format!(
                "GrassmannFrame::reconstruct_decoder: coord cols {} must equal frame rank {}",
                coords.ncols(),
                self.rank()
            ))
        }
    }
    /// Expresses a decoder in frame coordinates: `decoder * frame`.
    ///
    /// # Errors
    /// Returns an error when `decoder` does not have `output_dim()` columns.
    pub fn project_decoder(&self, decoder: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
        if decoder.ncols() == self.output_dim() {
            let owned = decoder.to_owned();
            Ok(fast_ab(&owned, &self.frame))
        } else {
            Err(format!(
                "GrassmannFrame::project_decoder: decoder cols {} must equal output dim {}",
                decoder.ncols(),
                self.output_dim()
            ))
        }
    }
    /// Largest principal angle between this frame and `other`.
    ///
    /// The cosine comes from the smallest singular value of `frame^T * other`,
    /// the sine from the largest singular value of the part of `other` not
    /// explained by the frame; both are clamped to `[0, 1]` and combined with
    /// `atan2`.
    ///
    /// # Errors
    /// Returns an error when `other` does not have `output_dim()` rows or when
    /// either SVD fails.
    pub fn max_principal_angle(&self, other: ArrayView2<'_, f64>) -> Result<f64, String> {
        if other.nrows() != self.output_dim() {
            return Err(format!(
                "GrassmannFrame::max_principal_angle: other rows {} must equal output dim {}",
                other.nrows(),
                self.output_dim()
            ));
        }
        let candidate = other.to_owned();
        let overlap = fast_atb(&self.frame, &candidate);
        let (_u, cosines, _vt) = overlap
            .svd(false, false)
            .map_err(|e| format!("GrassmannFrame::max_principal_angle: cos-SVD failed: {e}"))?;
        // Residual of `other` after removing its component inside the frame span.
        let inside = fast_ab(&self.frame, &overlap);
        let residual = &candidate - &inside;
        let (_u, sines, _vt) = residual
            .svd(false, false)
            .map_err(|e| format!("GrassmannFrame::max_principal_angle: sin-SVD failed: {e}"))?;
        let mut min_cos = 1.0_f64;
        for &value in cosines.iter() {
            min_cos = f64::min(min_cos, value);
        }
        let min_cos = min_cos.clamp(0.0, 1.0);
        let mut max_sin = 0.0_f64;
        for &value in sines.iter() {
            max_sin = f64::max(max_sin, value);
        }
        let max_sin = max_sin.clamp(0.0, 1.0);
        Ok(max_sin.atan2(min_cos))
    }
}
/// Running `output_dim x rank` accumulator of `targets^T * coords` blocks.
#[derive(Debug, Clone)]
pub struct GrassmannCrossMoment {
// Accumulated cross-moment; starts at zero and grows with every `accumulate` call.
moment: Array2<f64>,
}
impl GrassmannCrossMoment {
pub fn new(output_dim: usize, rank: usize) -> Self {
Self {
moment: Array2::<f64>::zeros((output_dim, rank)),
}
}
pub fn accumulate(
&mut self,
targets: ArrayView2<'_, f64>,
coords: ArrayView2<'_, f64>,
) -> Result<(), String> {
if targets.ncols() != self.moment.nrows() || coords.ncols() != self.moment.ncols() {
return Err(format!(
"GrassmannCrossMoment::accumulate: expected targets (·,{}) and coords (·,{}); \
got (·,{}) and (·,{})",
self.moment.nrows(),
self.moment.ncols(),
targets.ncols(),
coords.ncols()
));
}
if targets.nrows() != coords.nrows() {
return Err(format!(
"GrassmannCrossMoment::accumulate: targets rows {} must equal coords rows {}",
targets.nrows(),
coords.nrows()
));
}
let block = fast_atb(&targets.to_owned(), &coords.to_owned());
self.moment += █
Ok(())
}
pub fn moment(&self) -> ArrayView2<'_, f64> {
self.moment.view()
}
pub fn polar_frame(&self) -> Result<GrassmannFrame, String> {
GrassmannFrame::polar_update(self.moment.view())
}
}
/// One manifold atom: a basis evaluated at `n` observations plus the decoder that
/// maps its `m` basis columns to `p` outputs.
#[derive(Debug, Clone)]
pub struct SaeManifoldAtom {
/// Caller-supplied label.
pub name: String,
/// Kind tag of the basis.
pub basis_kind: SaeAtomBasisKind,
/// Number of latent axes; the last dimension of `basis_jacobian`.
pub latent_dim: usize,
/// Basis values, shape `(n, m)`.
pub basis_values: Array2<f64>,
/// First latent derivatives of the basis, shape `(n, m, latent_dim)`.
pub basis_jacobian: Array3<f64>,
/// Decoder coefficients, shape `(m, p)`.
pub decoder_coefficients: Array2<f64>,
/// Smoothness penalty currently in use, shape `(m, m)`.
pub smooth_penalty: Array2<f64>,
/// Copy of the smoothness penalty exactly as passed to `new`.
pub smooth_penalty_raw: Array2<f64>,
/// Value returned by `smooth_penalty_nullity` for the supplied penalty.
pub smooth_penalty_order: usize,
/// Optional evaluator used by `refresh_basis` to recompute the basis.
pub basis_evaluator: Option<Arc<dyn SaeBasisEvaluator>>,
/// Optional evaluator that additionally provides second derivatives.
pub basis_second_jet: Option<Arc<dyn SaeBasisSecondJet>>,
/// Optional decoder frame; `None` means no frame is active.
pub decoder_frame: Option<GrassmannFrame>,
/// Homotopy parameter; `new` initialises it to `1.0`.
pub homotopy_eta: f64,
}
impl SaeManifoldAtom {
/// Validates the supplied arrays and assembles an atom.
///
/// With `n`/`m` the rows/columns of `basis_values` and `p` the columns of
/// `decoder_coefficients`, the Jacobian must be `(n, m, latent_dim)`, the
/// decoder must have `m` rows, the penalty must be `(m, m)` and `p` must be
/// positive. The new atom has no evaluator, no frame, `homotopy_eta == 1.0`,
/// and has `refresh_intrinsic_smooth_penalty` applied once.
///
/// # Errors
/// Returns an error for any shape mismatch above or when
/// `smooth_penalty_nullity` fails.
#[must_use = "build error must be handled"]
pub fn new(
    name: impl Into<String>,
    basis_kind: SaeAtomBasisKind,
    latent_dim: usize,
    basis_values: Array2<f64>,
    basis_jacobian: Array3<f64>,
    decoder_coefficients: Array2<f64>,
    smooth_penalty: Array2<f64>,
) -> Result<Self, String> {
    let (n, m) = basis_values.dim();
    let p = decoder_coefficients.ncols();
    let jacobian_shape = basis_jacobian.dim();
    if jacobian_shape != (n, m, latent_dim) {
        return Err(format!(
            "SaeManifoldAtom::new: basis_jacobian must be ({n}, {m}, {latent_dim}); got {:?}",
            jacobian_shape
        ));
    }
    let decoder_rows = decoder_coefficients.nrows();
    if decoder_rows != m {
        return Err(format!(
            "SaeManifoldAtom::new: decoder rows {} must equal basis size {m}",
            decoder_rows
        ));
    }
    let penalty_shape = smooth_penalty.dim();
    if penalty_shape != (m, m) {
        return Err(format!(
            "SaeManifoldAtom::new: smooth penalty must be ({m}, {m}); got {:?}",
            penalty_shape
        ));
    }
    if p == 0 {
        return Err("SaeManifoldAtom::new: decoder output dimension must be positive".into());
    }
    let smooth_penalty_order = smooth_penalty_nullity(&smooth_penalty)?;
    // Keep an untouched copy before the working penalty can be rewritten.
    let smooth_penalty_raw = smooth_penalty.clone();
    let mut atom = Self {
        name: name.into(),
        basis_kind,
        latent_dim,
        basis_values,
        basis_jacobian,
        decoder_coefficients,
        smooth_penalty,
        smooth_penalty_raw,
        smooth_penalty_order,
        basis_evaluator: None,
        basis_second_jet: None,
        decoder_frame: None,
        homotopy_eta: 1.0,
    };
    atom.refresh_intrinsic_smooth_penalty();
    Ok(atom)
}
/// Installs `evaluator` as the basis evaluator and clears any second-jet evaluator.
pub fn with_basis_evaluator(mut self, evaluator: Arc<dyn SaeBasisEvaluator>) -> Self {
self.basis_evaluator = Some(evaluator);
self.basis_second_jet = None;
self
}
/// Installs `evaluator` both as the second-jet evaluator and, through a shared
/// `Arc` clone, as the plain basis evaluator.
pub fn with_basis_second_jet(mut self, evaluator: Arc<dyn SaeBasisSecondJet>) -> Self {
// Coerce a clone of the same Arc to the base evaluator trait object.
let base: Arc<dyn SaeBasisEvaluator> = evaluator.clone();
self.basis_evaluator = Some(base);
self.basis_second_jet = Some(evaluator);
self
}
/// Re-evaluates the basis values and Jacobian at `coords`.
///
/// Does nothing when no evaluator is installed. When `homotopy_eta` is exactly
/// `1.0` the plain `evaluate` is used, otherwise `evaluate_phi_eta` with the
/// current eta. The stored arrays are replaced only if the new shapes match.
///
/// # Errors
/// Propagates evaluator errors and reports any shape mismatch.
pub fn refresh_basis(&mut self, coords: ArrayView2<'_, f64>) -> Result<(), String> {
    let evaluator = match self.basis_evaluator.as_ref() {
        Some(evaluator) => evaluator,
        None => return Ok(()),
    };
    let (phi, jet) = if self.homotopy_eta != 1.0 {
        let evaluated = evaluator.evaluate_phi_eta(coords, self.homotopy_eta)?;
        (evaluated.phi, evaluated.jet)
    } else {
        evaluator.evaluate(coords)?
    };
    let expected_phi = self.basis_values.dim();
    if phi.dim() != expected_phi {
        return Err(format!(
            "SaeManifoldAtom::refresh_basis: evaluator returned Phi {:?}, expected {:?}",
            phi.dim(),
            expected_phi
        ));
    }
    let expected_jet = self.basis_jacobian.dim();
    if jet.dim() != expected_jet {
        return Err(format!(
            "SaeManifoldAtom::refresh_basis: evaluator returned jet {:?}, expected {:?}",
            jet.dim(),
            expected_jet
        ));
    }
    self.basis_values = phi;
    self.basis_jacobian = jet;
    Ok(())
}
/// Number of observations: the row count of `basis_values`.
pub fn n_obs(&self) -> usize {
self.basis_values.nrows()
}
/// Number of basis columns: the column count of `basis_values`.
pub fn basis_size(&self) -> usize {
self.basis_values.ncols()
}
/// Output dimension: the column count of `decoder_coefficients`.
pub fn output_dim(&self) -> usize {
self.decoder_coefficients.ncols()
}
/// Rank of the active decoder frame, or the full output dimension when no
/// frame is active.
pub fn border_frame_rank(&self) -> usize {
    if let Some(frame) = self.decoder_frame.as_ref() {
        frame.rank()
    } else {
        self.output_dim()
    }
}
/// Product of the basis size and `border_frame_rank()`.
pub fn border_coeff_count(&self) -> usize {
self.basis_size() * self.border_frame_rank()
}
/// Manifold dimension of the active decoder frame, or `0` when none is active.
pub fn frame_manifold_dimension(&self) -> usize {
    self.decoder_frame
        .as_ref()
        .map_or(0, |frame| frame.manifold_dimension())
}
/// Counts the decoder singular values above `SAE_FRAME_RANK_CUTOFF` times the
/// largest one.
///
/// Returns `0` when the decoder has no outputs, no basis columns, or no
/// strictly positive singular value.
///
/// # Errors
/// Returns an error when the SVD fails.
pub fn decoder_numerical_rank(&self) -> Result<usize, String> {
    if self.output_dim() == 0 || self.basis_size() == 0 {
        return Ok(0);
    }
    let (_u, singular_values, _vt) = match self.decoder_coefficients.svd(false, false) {
        Ok(factors) => factors,
        Err(e) => {
            return Err(format!(
                "SaeManifoldAtom::decoder_numerical_rank: SVD failed: {e}"
            ));
        }
    };
    let largest = singular_values.iter().copied().fold(0.0_f64, f64::max);
    if largest > 0.0 {
        let threshold = SAE_FRAME_RANK_CUTOFF * largest;
        Ok(singular_values.iter().filter(|&&v| v > threshold).count())
    } else {
        Ok(0)
    }
}
/// Rank at which a decoder frame would be activated, if any.
///
/// The numerical rank is clamped to `1..=p`; `Some(r)` is returned only when
/// `r <= p * (1 - SAE_FRAME_ACTIVATION_MARGIN)` and `r < p`.
///
/// # Errors
/// Propagates errors from `decoder_numerical_rank`.
pub fn decoder_frame_activation_rank(&self) -> Result<Option<usize>, String> {
    let p = self.output_dim();
    if p == 0 || self.basis_size() == 0 {
        return Ok(None);
    }
    let r = self.decoder_numerical_rank()?.max(1).min(p);
    let within_margin = (r as f64) <= (p as f64) * (1.0 - SAE_FRAME_ACTIVATION_MARGIN);
    let strictly_smaller = p.saturating_sub(r) != 0;
    if within_margin && strictly_smaller {
        Ok(Some(r))
    } else {
        Ok(None)
    }
}
/// Activates a decoder frame when `decoder_frame_activation_rank` allows it.
///
/// On activation the frame columns are the leading rows of the decoder's right
/// SVD factor, the gauge holds the matching singular values, and the decoder
/// coefficients are replaced by their projection onto the frame. Returns the
/// rank used, or `None` (with the frame cleared) when no frame is activated.
///
/// # Errors
/// Propagates rank-estimation errors and reports SVD failures.
pub fn maybe_activate_decoder_frame(&mut self) -> Result<Option<usize>, String> {
let Some(r) = self.decoder_frame_activation_rank()? else {
self.decoder_frame = None;
return Ok(None);
};
let p = self.output_dim();
let (_w, sv, vt_opt) = self.decoder_coefficients.svd(false, true).map_err(|e| {
format!("SaeManifoldAtom::maybe_activate_decoder_frame: SVD failed: {e}")
})?;
let vt = vt_opt.ok_or_else(|| {
"SaeManifoldAtom::maybe_activate_decoder_frame: SVD returned no right factor"
.to_string()
})?;
// The right factor may provide fewer rows than the requested rank.
let available = vt.nrows();
let r_eff = r.min(available);
if r_eff == 0 || p.saturating_sub(r_eff) == 0 {
self.decoder_frame = None;
return Ok(None);
}
// Transpose the leading `r_eff` rows of `vt` into the columns of the frame.
let mut frame = Array2::<f64>::zeros((p, r_eff));
for col in 0..r_eff {
for row in 0..p {
frame[[row, col]] = vt[[col, row]];
}
}
// Missing singular values are recorded as zero.
let mut gauge = Array1::<f64>::zeros(r_eff);
for i in 0..r_eff {
gauge[i] = sv.get(i).copied().unwrap_or(0.0);
}
self.decoder_frame = Some(GrassmannFrame::from_oriented(frame, gauge));
let u_proj = self
.decoder_frame
.as_ref()
.expect("frame just set")
.frame()
.to_owned();
// Project the decoder onto the frame: C <- (C U) U^T.
let c_proj = self.decoder_coefficients.dot(&u_proj);
self.decoder_coefficients = c_proj.dot(&u_proj.t());
Ok(Some(r_eff))
}
/// Drops the active decoder frame, if any; the decoder coefficients are untouched.
pub fn deactivate_decoder_frame(&mut self) {
self.decoder_frame = None;
}
/// Decoder coefficients expressed in the active frame, or `None` when no frame
/// is active.
///
/// # Errors
/// Propagates errors from `GrassmannFrame::project_decoder`.
pub fn factored_coordinates(&self) -> Result<Option<Array2<f64>>, String> {
    let Some(frame) = self.decoder_frame.as_ref() else {
        return Ok(None);
    };
    let coords = frame.project_decoder(self.decoder_coefficients.view())?;
    Ok(Some(coords))
}
/// Maps frame coordinates back to full decoder coefficients using the active frame.
///
/// # Errors
/// Returns an error when no frame is active, and propagates errors from
/// `GrassmannFrame::reconstruct_decoder`.
pub fn reconstruct_decoder_coefficients(
    &self,
    coords: ArrayView2<'_, f64>,
) -> Result<Array2<f64>, String> {
    match self.decoder_frame.as_ref() {
        Some(frame) => frame.reconstruct_decoder(coords),
        None => Err(
            "SaeManifoldAtom::reconstruct_decoder_coefficients: no active frame".to_string(),
        ),
    }
}
/// Replaces the decoder coefficients with the reconstruction of `coords`
/// through the active frame.
///
/// # Errors
/// Returns an error when reconstruction fails or when the reconstructed shape
/// differs from the current decoder shape.
pub fn set_factored_coordinates(&mut self, coords: ArrayView2<'_, f64>) -> Result<(), String> {
    let reconstructed = self.reconstruct_decoder_coefficients(coords)?;
    let expected = self.decoder_coefficients.dim();
    if reconstructed.dim() == expected {
        self.decoder_coefficients = reconstructed;
        Ok(())
    } else {
        Err(format!(
            "SaeManifoldAtom::set_factored_coordinates: reconstructed decoder {:?} \
must match {:?}",
            reconstructed.dim(),
            expected
        ))
    }
}
/// Rebuilds the active frame from `cross_moment` and re-projects the decoder
/// coefficients onto the new frame.
///
/// # Errors
/// Returns an error when no frame is active, when the polar update fails, when
/// the new frame's output dimension differs from the decoder's, or when
/// projection / reconstruction fails.
pub fn refresh_frame_from_cross_moment(
    &mut self,
    cross_moment: ArrayView2<'_, f64>,
) -> Result<(), String> {
    if self.decoder_frame.is_none() {
        return Err("SaeManifoldAtom::refresh_frame_from_cross_moment: no active frame".into());
    }
    let updated = GrassmannFrame::polar_update(cross_moment)?;
    let frame_dim = updated.output_dim();
    let decoder_dim = self.output_dim();
    if frame_dim != decoder_dim {
        return Err(format!(
            "SaeManifoldAtom::refresh_frame_from_cross_moment: frame output dim {} \
must equal decoder output dim {}",
            frame_dim, decoder_dim
        ));
    }
    // Round-trip through the new frame so the decoder lies in its span.
    let in_frame = updated.project_decoder(self.decoder_coefficients.view())?;
    let reprojected = updated.reconstruct_decoder(in_frame.view())?;
    self.decoder_coefficients = reprojected;
    self.decoder_frame = Some(updated);
    Ok(())
}
/// Returns the decoded output for observation `row` as a freshly allocated vector
/// of length `output_dim()`.
pub fn decoded_row(&self, row: usize) -> Array1<f64> {
    let mut buffer = vec![0.0_f64; self.output_dim()];
    self.fill_decoded_row(row, &mut buffer);
    Array1::from_vec(buffer)
}
/// Writes `basis_values[row, :] · decoder_coefficients` into `out`.
///
/// `out` is overwritten (zeroed first); basis columns whose value is exactly zero
/// are skipped.
///
/// # Panics
/// Panics when `out.len() != output_dim()`.
pub fn fill_decoded_row(&self, row: usize, out: &mut [f64]) {
    let output_dim = self.output_dim();
    let basis_count = self.basis_size();
    assert_eq!(out.len(), output_dim);
    out.fill(0.0);
    for basis_col in 0..basis_count {
        let weight = self.basis_values[[row, basis_col]];
        if weight != 0.0 {
            for (out_col, slot) in out.iter_mut().enumerate() {
                *slot += weight * self.decoder_coefficients[[basis_col, out_col]];
            }
        }
    }
}
/// Returns the derivative of the decoded output for observation `row` with respect
/// to latent axis `latent_axis`, as a freshly allocated vector of length `output_dim()`.
pub fn decoded_derivative_row(&self, row: usize, latent_axis: usize) -> Array1<f64> {
    let mut buffer = vec![0.0_f64; self.output_dim()];
    self.fill_decoded_derivative_row(row, latent_axis, &mut buffer);
    Array1::from_vec(buffer)
}
/// Writes `basis_jacobian[row, :, latent_axis] · decoder_coefficients` into `out`.
///
/// `out` is overwritten (zeroed first); basis columns whose derivative is exactly
/// zero are skipped.
///
/// # Panics
/// Panics when `out.len() != output_dim()`.
pub fn fill_decoded_derivative_row(&self, row: usize, latent_axis: usize, out: &mut [f64]) {
    let output_dim = self.output_dim();
    let basis_count = self.basis_size();
    assert_eq!(out.len(), output_dim);
    out.fill(0.0);
    for basis_col in 0..basis_count {
        let slope = self.basis_jacobian[[row, basis_col, latent_axis]];
        if slope != 0.0 {
            for (out_col, slot) in out.iter_mut().enumerate() {
                *slot += slope * self.decoder_coefficients[[basis_col, out_col]];
            }
        }
    }
}
/// Rebuilds `smooth_penalty` from `smooth_penalty_raw` using a per-basis-column
/// rescaling derived from the decoded curve's speed.
///
/// The result is `D * smooth_penalty_raw * D` with `D` diagonal and
/// `D[i] = clamp(speed_i / center)^(0.5 * beta)`, where `beta = 0.5 - smooth_penalty_order`,
/// `speed_i` is the `phi^2`-weighted mean over rows of the squared norm of the decoded
/// derivative along latent axis 0, and `center` is the geometric mean of the positive,
/// finite speeds.
///
/// Falls back to copying `smooth_penalty_raw` unchanged when there are no basis
/// columns, the penalty order is 0, the latent dimension is not 1, or no usable
/// `center` can be computed.
pub fn refresh_intrinsic_smooth_penalty(&mut self) {
let m = self.basis_size();
// The reweighting is only applied to one-dimensional latents with a nonzero order.
if m == 0 || self.smooth_penalty_order == 0 || self.latent_dim != 1 {
self.smooth_penalty.assign(&self.smooth_penalty_raw);
return;
}
let n = self.n_obs();
let p = self.output_dim();
let beta = 0.5 - self.smooth_penalty_order as f64;
// act[col]: sum over rows of phi^2; num[col]: sum over rows of phi^2 * |d decoded / dt|^2.
let mut act = vec![0.0_f64; m];
let mut num = vec![0.0_f64; m];
let mut deriv = vec![0.0_f64; p];
for row in 0..n {
self.fill_decoded_derivative_row(row, 0, &mut deriv);
let mut speed_sq = 0.0_f64;
for &d in deriv.iter() {
speed_sq += d * d;
}
for col in 0..m {
let phi = self.basis_values[[row, col]];
let w = phi * phi;
if w == 0.0 {
continue;
}
act[col] += w;
num[col] += w * speed_sq;
}
}
// Per-column weighted mean squared speed; columns with no activation get 0.
let mut speeds = vec![0.0_f64; m];
let mut log_acc = 0.0_f64;
let mut log_cnt = 0usize;
for col in 0..m {
let s = if act[col] > 0.0 {
num[col] / act[col]
} else {
0.0
};
speeds[col] = s;
// Only positive finite speeds contribute to the geometric mean.
if s > 0.0 && s.is_finite() {
log_acc += s.ln();
log_cnt += 1;
}
}
let center = if log_cnt > 0 {
(log_acc / log_cnt as f64).exp()
} else {
0.0
};
// Without a usable center the relative speeds are undefined: keep the raw penalty.
if !(center > 0.0 && center.is_finite()) {
self.smooth_penalty.assign(&self.smooth_penalty_raw);
return;
}
// Relative speeds are clamped to [1e-6, 1e6] before exponentiation; a non-finite
// ratio is treated as the ceiling.
const RELATIVE_SPEED_FLOOR: f64 = 1.0e-6;
const RELATIVE_SPEED_CEIL: f64 = 1.0e6;
let mut root_w = vec![0.0_f64; m];
for col in 0..m {
let ratio = speeds[col] / center;
let ratio = if ratio.is_finite() {
ratio.clamp(RELATIVE_SPEED_FLOOR, RELATIVE_SPEED_CEIL)
} else {
RELATIVE_SPEED_CEIL
};
root_w[col] = ratio.powf(0.5 * beta);
}
// Symmetric congruence: scale row i and column j by their root weights.
for i in 0..m {
let ri = root_w[i];
for j in 0..m {
self.smooth_penalty[[i, j]] = ri * self.smooth_penalty_raw[[i, j]] * root_w[j];
}
}
}
}
/// Counts the eigenvalues of the symmetrized penalty `0.5 * (S + Sᵀ)` that are at or
/// below `SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF` times the largest eigenvalue.
///
/// Returns `Ok(0)` for an empty matrix or when no eigenvalue is strictly positive.
///
/// # Errors
/// Fails when the symmetric eigendecomposition fails.
fn smooth_penalty_nullity(s: &Array2<f64>) -> Result<usize, String> {
    let dim = s.ncols();
    if dim == 0 {
        return Ok(0);
    }
    let symmetrized =
        Array2::<f64>::from_shape_fn((dim, dim), |(i, j)| 0.5 * (s[[i, j]] + s[[j, i]]));
    let (eigenvalues, _eigenvectors) = symmetrized
        .eigh(Side::Lower)
        .map_err(|e| format!("smooth_penalty_nullity: eigh failed: {e}"))?;
    let largest = eigenvalues.iter().copied().fold(0.0_f64, f64::max);
    if largest > 0.0 {
        let cutoff = SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF * largest;
        let null_count = eigenvalues.iter().filter(|&&value| value <= cutoff).count();
        Ok(null_count)
    } else {
        Ok(0)
    }
}
/// How per-row assignment weights are produced from logits.
///
/// Every variant carries a `temperature`; the remaining fields are variant-specific
/// parameters checked by `validate`.
#[derive(Debug, Clone, Copy)]
pub enum AssignmentMode {
    /// Softmax assignment with a `sparsity` parameter (must be finite and positive).
    Softmax { temperature: f64, sparsity: f64 },
    /// IBP-MAP assignment with concentration `alpha` (must be finite and positive).
    IBPMap {
        temperature: f64,
        alpha: f64,
        learnable_alpha: bool,
    },
    /// JumpReLU assignment with a finite `threshold`.
    JumpReLU { temperature: f64, threshold: f64 },
}

impl AssignmentMode {
    /// Softmax mode with the given temperature and `sparsity = 1.0`.
    #[must_use]
    pub fn softmax(temperature: f64) -> Self {
        let sparsity = 1.0;
        Self::Softmax {
            temperature,
            sparsity,
        }
    }

    /// IBP-MAP mode with the given parameters (not validated here).
    #[must_use]
    pub fn ibp_map(temperature: f64, alpha: f64, learnable_alpha: bool) -> Self {
        Self::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        }
    }

    /// JumpReLU mode with the given parameters (not validated here).
    #[must_use]
    pub fn jumprelu(temperature: f64, threshold: f64) -> Self {
        Self::JumpReLU {
            temperature,
            threshold,
        }
    }

    /// Temperature stored in the current variant.
    pub fn temperature(&self) -> f64 {
        match self {
            Self::Softmax { temperature, .. } => *temperature,
            Self::IBPMap { temperature, .. } => *temperature,
            Self::JumpReLU { temperature, .. } => *temperature,
        }
    }

    /// Overwrites the temperature, rejecting non-finite or non-positive values.
    fn set_temperature(&mut self, new_temperature: f64) -> Result<(), String> {
        let acceptable = new_temperature.is_finite() && new_temperature > 0.0;
        if !acceptable {
            return Err(format!(
                "AssignmentMode: temperature must be finite and positive; got {new_temperature}"
            ));
        }
        let slot = match self {
            Self::Softmax { temperature, .. } => temperature,
            Self::IBPMap { temperature, .. } => temperature,
            Self::JumpReLU { temperature, .. } => temperature,
        };
        *slot = new_temperature;
        Ok(())
    }

    /// Checks the temperature first, then the variant-specific parameter.
    fn validate(&self) -> Result<(), String> {
        let temperature = self.temperature();
        let temperature_ok = temperature.is_finite() && temperature > 0.0;
        if !temperature_ok {
            return Err(format!(
                "AssignmentMode: temperature must be finite and positive; got {temperature}"
            ));
        }
        match *self {
            Self::Softmax { sparsity, .. } if !(sparsity.is_finite() && sparsity > 0.0) => {
                Err(format!(
                    "AssignmentMode::Softmax: sparsity must be finite and positive; got {sparsity}"
                ))
            }
            Self::IBPMap { alpha, .. } if !(alpha.is_finite() && alpha > 0.0) => Err(format!(
                "AssignmentMode::IBPMap: alpha must be finite and positive; got {alpha}"
            )),
            Self::JumpReLU { threshold, .. } if !threshold.is_finite() => Err(format!(
                "AssignmentMode::JumpReLU: threshold must be finite; got {threshold}"
            )),
            Self::Softmax { .. } | Self::IBPMap { .. } | Self::JumpReLU { .. } => Ok(()),
        }
    }
}
/// Per-observation assignment state: one logit per (row, atom) plus one latent
/// coordinate block per atom, interpreted according to `mode`.
///
/// The constructors check that `coords.len()` equals the number of logit columns and
/// that every coordinate block has as many observations as `logits` has rows. The
/// fields are public, so code mutating them directly must preserve that itself.
#[derive(Debug, Clone)]
pub struct SaeAssignment {
/// Assignment logits, one row per observation and one column per atom.
pub logits: Array2<f64>,
/// Latent coordinates, one entry per atom.
pub coords: Vec<LatentCoordValues>,
/// Rule that turns a logit row into assignment weights.
pub mode: AssignmentMode,
}
impl SaeAssignment {
    /// Builds a softmax assignment (see [`AssignmentMode::softmax`]) from logits and
    /// per-atom coordinates.
    ///
    /// # Errors
    /// Same as [`SaeAssignment::with_mode`].
    #[must_use = "build error must be handled"]
    pub fn new(
        logits: Array2<f64>,
        coords: Vec<LatentCoordValues>,
        temperature: f64,
    ) -> Result<Self, String> {
        let mode = AssignmentMode::softmax(temperature);
        Self::with_mode(logits, coords, mode)
    }

    /// Validates `mode`, the shapes, and the finiteness of every logit row, then builds
    /// the assignment. In softmax mode the logits are passed through
    /// `canonicalize_softmax_logits` before being stored.
    ///
    /// # Errors
    /// Fails when `mode` is invalid, when `coords.len()` differs from the number of
    /// logit columns, when a coordinate block has a different number of observations
    /// than `logits`, or when a logit row fails `validate_finite_logits`.
    #[must_use = "build error must be handled"]
    pub fn with_mode(
        mut logits: Array2<f64>,
        coords: Vec<LatentCoordValues>,
        mode: AssignmentMode,
    ) -> Result<Self, String> {
        mode.validate()?;
        let (n, k) = logits.dim();
        if coords.len() != k {
            return Err(format!(
                "SaeAssignment::new: coords length {} must equal K={k}",
                coords.len()
            ));
        }
        let mismatch = coords
            .iter()
            .enumerate()
            .find(|(_, coord)| coord.n_obs() != n);
        if let Some((atom, coord)) = mismatch {
            return Err(format!(
                "SaeAssignment::new: coord atom {atom} has n_obs={} but logits has {n}",
                coord.n_obs()
            ));
        }
        for row in 0..n {
            validate_finite_logits(logits.row(row), row)?;
        }
        if let AssignmentMode::Softmax { .. } = mode {
            canonicalize_softmax_logits(&mut logits);
        }
        Ok(Self {
            logits,
            coords,
            mode,
        })
    }

    /// Number of observations (logit rows).
    pub fn n_obs(&self) -> usize {
        self.logits.nrows()
    }

    /// Number of atoms (logit columns).
    pub fn k_atoms(&self) -> usize {
        self.logits.ncols()
    }

    /// Sum of the latent dimensions of all coordinate blocks.
    pub fn total_coord_dim(&self) -> usize {
        self.coords.iter().map(|coord| coord.latent_dim()).sum()
    }

    /// Number of assignment coordinates per row: `K - 1` (saturating) for softmax,
    /// `K` for the other modes.
    pub fn assignment_coord_dim(&self) -> usize {
        let k = self.k_atoms();
        match self.mode {
            AssignmentMode::Softmax { .. } => k.saturating_sub(1),
            AssignmentMode::IBPMap { .. } => k,
            AssignmentMode::JumpReLU { .. } => k,
        }
    }

    /// Width of one row block: assignment coordinates followed by all latent coordinates.
    pub fn row_block_dim(&self) -> usize {
        let assignment_part = self.assignment_coord_dim();
        let coord_part = self.total_coord_dim();
        assignment_part + coord_part
    }

    /// Start offset of each atom's coordinates inside a row block.
    pub fn coord_offsets(&self) -> Vec<usize> {
        let mut cursor = self.assignment_coord_dim();
        self.coords
            .iter()
            .map(|coord| {
                let start = cursor;
                cursor += coord.latent_dim();
                start
            })
            .collect()
    }

    /// Dense `(n_obs, k_atoms)` matrix of assignment weights.
    ///
    /// # Panics
    /// Panics under the same conditions as [`SaeAssignment::assignments_row`].
    pub fn assignments(&self) -> Array2<f64> {
        let rows = self.n_obs();
        let atoms = self.k_atoms();
        let mut out = Array2::<f64>::zeros((rows, atoms));
        for row in 0..rows {
            let weights = self.assignments_row(row);
            for (atom, slot) in out.row_mut(row).iter_mut().enumerate() {
                *slot = weights[atom];
            }
        }
        out
    }

    /// Assignment weights for one row.
    ///
    /// # Panics
    /// Panics when the row's logits fail validation; use
    /// [`SaeAssignment::try_assignments_row`] to handle that case.
    pub fn assignments_row(&self, row: usize) -> Array1<f64> {
        match self.try_assignments_row(row) {
            Ok(weights) => weights,
            Err(err) => panic!("assignment logits must be finite: {err:?}"),
        }
    }

    /// Assignment weights for one row, computed according to `mode`.
    ///
    /// A single-atom softmax short-circuits to the weight vector `[1.0]`.
    ///
    /// # Errors
    /// Fails when the row's logits fail `validate_finite_logits`.
    pub fn try_assignments_row(&self, row: usize) -> Result<Array1<f64>, String> {
        validate_finite_logits(self.logits.row(row), row)?;
        let weights = match self.mode {
            AssignmentMode::Softmax { .. } if self.k_atoms() == 1 => Array1::from_vec(vec![1.0]),
            AssignmentMode::Softmax { temperature, .. } => {
                softmax_row(self.logits.row(row), temperature)
            }
            AssignmentMode::IBPMap {
                temperature, alpha, ..
            } => ibp_map_row(self.logits.row(row), temperature, alpha),
            AssignmentMode::JumpReLU {
                temperature,
                threshold,
            } => jumprelu_row(self.logits.row(row), temperature, threshold),
        };
        Ok(weights)
    }

    /// Flattens all per-row variables into one vector of length `n_obs * row_block_dim`.
    ///
    /// Each row block holds the first `assignment_coord_dim()` logits followed by every
    /// atom's latent coordinates at the positions given by [`SaeAssignment::coord_offsets`].
    pub fn flatten_ext_coords(&self) -> Array1<f64> {
        let rows = self.n_obs();
        let width = self.row_block_dim();
        let atoms = self.k_atoms();
        let logit_dim = self.assignment_coord_dim();
        let offsets = self.coord_offsets();
        let mut flat = Array1::<f64>::zeros(rows * width);
        for row in 0..rows {
            let base = row * width;
            for col in 0..logit_dim {
                flat[base + col] = self.logits[[row, col]];
            }
            for atom in 0..atoms {
                let coord = &self.coords[atom];
                let values = coord.row(row);
                let start = base + offsets[atom];
                for axis in 0..coord.latent_dim() {
                    flat[start + axis] = values[axis];
                }
            }
        }
        flat
    }

    /// Builds an assignment from raw coordinate matrices, wrapping each with
    /// `LatentCoordValues::from_matrix` and `LatentIdMode::None`.
    ///
    /// # Errors
    /// Same as [`SaeAssignment::with_mode`].
    #[must_use = "build error must be handled"]
    pub fn from_blocks_with_mode(
        logits: Array2<f64>,
        coord_blocks: Vec<Array2<f64>>,
        mode: AssignmentMode,
    ) -> Result<Self, String> {
        let mut coords: Vec<LatentCoordValues> = Vec::with_capacity(coord_blocks.len());
        for block in &coord_blocks {
            coords.push(LatentCoordValues::from_matrix(
                block.view(),
                LatentIdMode::None,
            ));
        }
        Self::with_mode(logits, coords, mode)
    }

    /// Like [`SaeAssignment::from_blocks_with_mode`], but pairs each coordinate matrix
    /// with the manifold at the same position.
    ///
    /// # Errors
    /// Fails when `coord_blocks` and `manifolds` differ in length, and otherwise under
    /// the same conditions as [`SaeAssignment::with_mode`].
    #[must_use = "build error must be handled"]
    pub fn from_blocks_with_mode_and_manifolds(
        logits: Array2<f64>,
        coord_blocks: Vec<Array2<f64>>,
        manifolds: Vec<LatentManifold>,
        mode: AssignmentMode,
    ) -> Result<Self, String> {
        if coord_blocks.len() != manifolds.len() {
            return Err(format!(
                "SaeAssignment::from_blocks_with_mode_and_manifolds: coord block length {} != manifold length {}",
                coord_blocks.len(),
                manifolds.len()
            ));
        }
        let mut coords: Vec<LatentCoordValues> = Vec::with_capacity(coord_blocks.len());
        for (block, manifold) in coord_blocks.iter().zip(manifolds) {
            coords.push(LatentCoordValues::from_matrix_with_manifold(
                block.view(),
                LatentIdMode::None,
                manifold,
            ));
        }
        Self::with_mode(logits, coords, mode)
    }
}
/// Log-scale strengths: a sparsity strength, a smoothness strength, and one vector of
/// ARD log-strengths per entry of `log_ard`.
#[derive(Debug, Clone)]
pub struct SaeManifoldRho {
    /// Natural log of the sparsity strength.
    pub log_lambda_sparse: f64,
    /// Natural log of the smoothness strength.
    pub log_lambda_smooth: f64,
    /// ARD log-strength blocks; the flat layout concatenates them in order.
    pub log_ard: Vec<Array1<f64>>,
}

impl SaeManifoldRho {
    /// Bundles the given log-strengths without validation.
    #[must_use]
    pub fn new(log_lambda_sparse: f64, log_lambda_smooth: f64, log_ard: Vec<Array1<f64>>) -> Self {
        Self {
            log_lambda_sparse,
            log_lambda_smooth,
            log_ard,
        }
    }

    /// Sparsity strength, `exp` of the clamped log value.
    pub fn lambda_sparse(&self) -> f64 {
        Self::stable_exp_strength(self.log_lambda_sparse)
    }

    /// Smoothness strength, `exp` of the clamped log value.
    pub fn lambda_smooth(&self) -> f64 {
        Self::stable_exp_strength(self.log_lambda_smooth)
    }

    /// `exp(log_strength)` with the argument clamped to `[-700, 700]`.
    pub(crate) fn stable_exp_strength(log_strength: f64) -> f64 {
        const MAX_LOG_STRENGTH: f64 = 700.0;
        const MIN_LOG_STRENGTH: f64 = -700.0;
        let bounded = log_strength.clamp(MIN_LOG_STRENGTH, MAX_LOG_STRENGTH);
        bounded.exp()
    }

    /// Flat layout: `[log_lambda_sparse, log_lambda_smooth, log_ard[0]..., log_ard[1]..., ...]`.
    pub fn to_flat(&self) -> Array1<f64> {
        let ard_len: usize = self.log_ard.iter().map(|block| block.len()).sum();
        let mut values = Vec::with_capacity(2 + ard_len);
        values.push(self.log_lambda_sparse);
        values.push(self.log_lambda_smooth);
        for block in &self.log_ard {
            values.extend(block.iter().copied());
        }
        Array1::from_vec(values)
    }

    /// Rebuilds a value from the flat layout of [`SaeManifoldRho::to_flat`], using
    /// `self` only as the template for the ARD block sizes.
    ///
    /// # Panics
    /// Panics when `flat.len()` differs from `2 + Σ log_ard[k].len()`.
    pub fn from_flat(&self, flat: ArrayView1<'_, f64>) -> SaeManifoldRho {
        let ard_len: usize = self.log_ard.iter().map(|block| block.len()).sum();
        assert_eq!(
            flat.len(),
            2 + ard_len,
            "SaeManifoldRho::from_flat: flat length {} != 2 + Σ d_k = {}",
            flat.len(),
            2 + ard_len
        );
        let mut cursor = 2usize;
        let log_ard: Vec<Array1<f64>> = self
            .log_ard
            .iter()
            .map(|template| {
                let width = template.len();
                let start = cursor;
                cursor += width;
                Array1::from_iter((0..width).map(|j| flat[start + j]))
            })
            .collect();
        SaeManifoldRho {
            log_lambda_sparse: flat[0],
            log_lambda_smooth: flat[1],
            log_ard,
        }
    }
}
/// Row-wise linear operators used by the SAE Kronecker-structured system.
///
/// All buffers are plain slices; implementations index them directly and panic on
/// slices that are too short.
pub trait SaeKroneckerRow {
    /// Overwrites `u_out` with the row's beta-Jacobian applied to `x_beta`.
    fn apply_jbeta(&self, row: usize, x_beta: &[f64], u_out: &mut [f64]);
    /// Accumulates (adds) the transposed beta-Jacobian applied to `u` into `y_beta`.
    fn scatter_jbeta_t(&self, row: usize, u: &[f64], y_beta: &mut [f64]);
    /// Overwrites the leading entries of `w_out` with the row's local Jacobian applied to `u`.
    fn apply_l(&self, row: usize, u: &[f64], w_out: &mut [f64]);
    /// Accumulates (adds) the transposed local Jacobian applied to `v` into `u_out`.
    fn apply_l_t(&self, row: usize, v: &[f64], u_out: &mut [f64]);
}

/// Per-row storage for [`SaeKroneckerRow`].
///
/// * `p` — output dimension.
/// * `a_phi[row]` — `(beta_base, phi)` loads: the row reads/writes the `p` consecutive
///   beta entries starting at `beta_base`, scaled by `phi`.
/// * `local_jac[row]` — row-major `q_i × p` local Jacobian, with `q_i = len / p`.
#[derive(Debug, Clone)]
pub struct SaeKroneckerRows {
    p: usize,
    a_phi: Vec<Vec<(usize, f64)>>,
    local_jac: Vec<Vec<f64>>,
}

impl SaeKroneckerRows {
    /// Bundles the per-row loads and local Jacobians.
    ///
    /// # Panics
    /// Panics when `a_phi` and `local_jac` do not have the same number of rows.
    pub fn new(p: usize, a_phi: Vec<Vec<(usize, f64)>>, local_jac: Vec<Vec<f64>>) -> Self {
        assert_eq!(
            a_phi.len(),
            local_jac.len(),
            "SaeKroneckerRows: a_phi rows ({}) != local_jac rows ({})",
            a_phi.len(),
            local_jac.len(),
        );
        Self {
            p,
            a_phi,
            local_jac,
        }
    }
}

impl SaeKroneckerRow for SaeKroneckerRows {
    fn apply_jbeta(&self, row: usize, x_beta: &[f64], u_out: &mut [f64]) {
        u_out.fill(0.0);
        let p = self.p;
        for &(beta_base, phi) in &self.a_phi[row] {
            if phi == 0.0 {
                continue;
            }
            // Slicing once keeps the out-of-range panic but hoists the bounds checks.
            let block = &x_beta[beta_base..beta_base + p];
            for (dst, &src) in u_out[..p].iter_mut().zip(block) {
                *dst += phi * src;
            }
        }
    }

    fn scatter_jbeta_t(&self, row: usize, u: &[f64], y_beta: &mut [f64]) {
        let p = self.p;
        for &(beta_base, phi) in &self.a_phi[row] {
            if phi == 0.0 {
                continue;
            }
            let block = &mut y_beta[beta_base..beta_base + p];
            for (dst, &src) in block.iter_mut().zip(&u[..p]) {
                *dst += phi * src;
            }
        }
    }

    fn apply_l(&self, row: usize, u: &[f64], w_out: &mut [f64]) {
        let jac = &self.local_jac[row];
        let p = self.p;
        if p == 0 {
            // With an empty output space every contraction is an empty sum. Dividing
            // `jac.len()` by `p` would panic, so report zeros instead.
            w_out.fill(0.0);
            return;
        }
        let q_i = jac.len() / p;
        for c in 0..q_i {
            let jac_row = &jac[c * p..(c + 1) * p];
            w_out[c] = jac_row
                .iter()
                .zip(&u[..p])
                .fold(0.0_f64, |acc, (&g, &x)| acc + g * x);
        }
    }

    fn apply_l_t(&self, row: usize, v: &[f64], u_out: &mut [f64]) {
        let jac = &self.local_jac[row];
        let p = self.p;
        if p == 0 {
            // Nothing to accumulate into; also avoids the division by zero below.
            return;
        }
        let q_i = jac.len() / p;
        for c in 0..q_i {
            let vc = v[c];
            if vc == 0.0 {
                continue;
            }
            let jac_row = &jac[c * p..(c + 1) * p];
            for (dst, &g) in u_out[..p].iter_mut().zip(jac_row) {
                *dst += g * vc;
            }
        }
    }
}
struct SaeFrameKroneckerRows {
inner: SaeKroneckerRows,
projection: FrameProjection,
factored_a_phi: Vec<Vec<(usize, f64, usize)>>,
}
impl SaeFrameKroneckerRows {
fn new(
p: usize,
projection: FrameProjection,
a_phi: Vec<Vec<(usize, f64)>>,
local_jac: Vec<Vec<f64>>,
) -> Result<Self, String> {
let mut factored_a_phi: Vec<Vec<(usize, f64, usize)>> = Vec::with_capacity(a_phi.len());
for row_loads in &a_phi {
let mut row_out: Vec<(usize, f64, usize)> = Vec::with_capacity(row_loads.len());
for &(beta_base, phi) in row_loads {
let mut atom_idx = None;
for k in 0..projection.basis_sizes.len() {
let lo = projection.beta_offsets[k];
let hi = lo + projection.basis_sizes[k] * p;
if beta_base >= lo && beta_base < hi {
atom_idx = Some(k);
break;
}
}
let k = atom_idx.ok_or_else(|| {
format!(
"SaeFrameKroneckerRows::new: beta_base {beta_base} not in any atom block"
)
})?;
let basis_col = (beta_base - projection.beta_offsets[k]) / p;
let c_base = projection.border_offsets[k] + basis_col * projection.ranks[k];
row_out.push((c_base, phi, k));
}
factored_a_phi.push(row_out);
}
let inner = SaeKroneckerRows::new(p, a_phi, local_jac);
Ok(Self {
inner,
projection,
factored_a_phi,
})
}
fn apply_jbeta_factored(&self, row: usize, x_c: &[f64], u_out: &mut [f64]) {
for val in u_out.iter_mut() {
*val = 0.0;
}
for &(c_base, phi, atom) in &self.factored_a_phi[row] {
if phi == 0.0 {
continue;
}
self.projection
.accumulate_row_lift(atom, c_base, phi, x_c, u_out);
}
}
fn scatter_jbeta_factored_t(&self, row: usize, u: &[f64], y_c: &mut [f64]) {
for &(c_base, phi, atom) in &self.factored_a_phi[row] {
if phi == 0.0 {
continue;
}
self.projection
.accumulate_row_project(atom, c_base, phi, u, y_c);
}
}
fn apply_l(&self, row: usize, u: &[f64], w_out: &mut [f64]) {
self.inner.apply_l(row, u, w_out);
}
fn apply_l_t(&self, row: usize, v: &[f64], u_out: &mut [f64]) {
self.inner.apply_l_t(row, v, u_out);
}
}
/// Additive decomposition of the SAE manifold loss.
#[derive(Debug, Clone, Copy)]
pub struct SaeManifoldLoss {
    /// Data-fit contribution.
    pub data_fit: f64,
    /// Assignment-sparsity contribution.
    pub assignment_sparsity: f64,
    /// Smoothness contribution.
    pub smoothness: f64,
    /// ARD contribution.
    pub ard: f64,
}

impl SaeManifoldLoss {
    /// Sum of all four components, accumulated in field order.
    pub const fn total(&self) -> f64 {
        let fit_and_sparsity = self.data_fit + self.assignment_sparsity;
        let with_smoothness = fit_and_sparsity + self.smoothness;
        with_smoothness + self.ard
    }

    /// Negated total loss.
    pub const fn evidence_proxy(&self) -> f64 {
        let total = self.total();
        -total
    }
}
/// Separate channels of the outer gradient with respect to the rho parameters.
#[derive(Debug, Clone)]
pub struct SaeOuterRhoGradientComponents {
    /// Explicit channel.
    pub explicit: Array1<f64>,
    /// Log-determinant trace channel.
    pub logdet_trace: Array1<f64>,
    /// Occam channel.
    pub occam: Array1<f64>,
    /// Third-order correction channel; only meaningful when
    /// `third_order_correction_available` is `true`.
    pub third_order_correction: Array1<f64>,
    /// Whether `third_order_correction` has been populated.
    pub third_order_correction_available: bool,
}

impl SaeOuterRhoGradientComponents {
    /// `explicit + logdet_trace + occam`, leaving out the third-order correction.
    #[must_use]
    pub fn gradient_excluding_unavailable_correction(&self) -> Array1<f64> {
        let explicit_plus_trace = &self.explicit + &self.logdet_trace;
        explicit_plus_trace + &self.occam
    }

    /// The three-channel gradient plus the third-order correction.
    ///
    /// # Panics
    /// Panics when the third-order correction channel is not available.
    #[must_use]
    pub fn gradient_with_available_correction(&self) -> Array1<f64> {
        assert!(
            self.third_order_correction_available,
            "gradient_with_available_correction: third-order correction channel \
             is not populated for this fit; use \
             gradient_excluding_unavailable_correction() and account for the \
             missing term explicitly"
        );
        let base = self.gradient_excluding_unavailable_correction();
        base + &self.third_order_correction
    }
}
/// A vector split into a `t` part and a `beta` part.
// NOTE(review): presumably `t` holds the stacked per-row variables and `beta` the shared
// border variables of the arrow system — confirm against the solver code.
#[derive(Debug, Clone)]
pub struct SaeArrowVector {
/// The `t` component.
pub t: Array1<f64>,
/// The `beta` component.
pub beta: Array1<f64>,
}
/// Identifies one local variable of a row: either an atom's logit or one axis of an
/// atom's latent coordinate.
#[derive(Debug, Clone, Copy)]
enum SaeLocalRowVar {
/// The assignment logit of `atom`.
Logit { atom: usize },
/// Latent coordinate `axis` of `atom`.
Coord { atom: usize, axis: usize },
}
/// One border channel, addressed by atom and basis column.
// NOTE(review): the exact meaning of `index` and `output` is not visible in this part of
// the file; presumably a flat border index and a length-`p` output direction — confirm.
#[derive(Debug, Clone)]
struct SaeBorderChannel {
/// Atom the channel belongs to.
atom: usize,
/// Basis column within that atom.
basis_col: usize,
/// Index associated with the channel.
index: usize,
/// Output-space values associated with the channel.
output: Vec<f64>,
}
/// Per-row jet data keyed by the row's local variables.
// NOTE(review): the field semantics below are inferred from names and nesting depth only
// (first/second derivative blocks over `vars`, plus beta-related blocks) — confirm
// against the code that fills this struct.
#[derive(Debug, Clone)]
struct SaeRowJets {
/// Local variables of the row.
vars: Vec<SaeLocalRowVar>,
/// Doubly nested block; presumably first-order terms per local variable.
first: Vec<Vec<f64>>,
/// Triply nested block; presumably second-order terms per pair of local variables.
second: Vec<Vec<Vec<f64>>>,
/// Doubly nested beta-related block.
beta: Vec<Vec<f64>>,
/// Triply nested beta-related derivative block.
beta_deriv: Vec<Vec<Vec<f64>>>,
/// Triply nested beta/local derivative block.
beta_l_deriv: Vec<Vec<Vec<f64>>>,
}
/// Dot product of two slices; when the lengths differ, the extra tail of the longer
/// slice is ignored.
fn sae_dot(a: &[f64], b: &[f64]) -> f64 {
    let products = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs);
    products.sum()
}
/// Derivatives of a scaled sigmoid recovered from its value.
///
/// With `sig = value / scale` (taken as `0` when `scale <= 0`), returns
/// `(value, scale·sig·(1−sig)·inv_tau, scale·sig·(1−sig)·(1−2·sig)·inv_tau²)`.
fn sae_sigmoid_derivatives_from_value(value: f64, inv_tau: f64, scale: f64) -> (f64, f64, f64) {
    let sig = match scale > 0.0 {
        true => value / scale,
        false => 0.0,
    };
    // Shared bell-shaped factor of both derivatives.
    let bell = scale * sig * (1.0 - sig);
    let first = bell * inv_tau;
    let second = bell * (1.0 - 2.0 * sig) * inv_tau * inv_tau;
    (value, first, second)
}
/// Size limit of 512 points.
// NOTE(review): presumably caps the number of evaluation points of a shape band — confirm
// at the use site.
pub const SHAPE_BAND_MAX_POINTS: usize = 512;
/// Size limit of `2^24` entries.
// NOTE(review): presumably the largest decoder covariance payload that is materialised —
// confirm at the use site.
pub const SAE_DECODER_COV_PAYLOAD_MAX_ENTRIES: usize = 1 << 24;
/// Shape-uncertainty summary for a single atom.
// NOTE(review): array shapes are not visible here; presumably the three band arrays share
// their leading (point) dimension — confirm against the producer.
#[derive(Debug, Clone)]
pub struct SaeAtomShapeUncertainty {
/// Decoder covariance, when one is stored for this atom.
pub decoder_covariance: Option<Array2<f64>>,
/// Coordinates at which the band is evaluated.
pub band_coords: Array2<f64>,
/// Band mean.
pub band_mean: Array2<f64>,
/// Band standard deviation.
pub band_sd: Array2<f64>,
}
/// Shape-uncertainty summary for a fit: one dispersion value plus one
/// [`SaeAtomShapeUncertainty`] per entry of `atoms`.
#[derive(Debug, Clone)]
pub struct SaeShapeUncertainty {
/// Dispersion value associated with the summaries.
pub dispersion: f64,
/// Per-atom summaries.
pub atoms: Vec<SaeAtomShapeUncertainty>,
}
/// Compressed per-row variable layout restricted to each row's active atoms.
///
/// In the compressed layout of a row, the first `active_atoms[row].len()` slots hold one
/// value per active atom and the coordinates of the active atoms follow, starting at
/// `coord_starts[row][j]` for the `j`-th active atom.
#[derive(Debug, Clone)]
pub struct SaeRowLayout {
/// Active atom indices for every row.
pub active_atoms: Vec<Vec<usize>>,
/// Per row, the start of each active atom's coordinates in the compressed layout.
pub coord_starts: Vec<Vec<usize>>,
/// Per atom, the start of its coordinates in the full (uncompressed) row layout.
pub coord_offsets_full: Vec<usize>,
/// Per atom, the number of latent coordinates.
pub coord_dims: Vec<usize>,
}
impl SaeRowLayout {
    /// Builds the layout for JumpReLU assignments: an atom is active in a row when
    /// `jumprelu_in_optimization_band(logit, threshold, temperature)` holds.
    fn from_jumprelu(
        n: usize,
        k_atoms: usize,
        threshold: f64,
        temperature: f64,
        logits: &Array2<f64>,
        coord_dims: Vec<usize>,
        coord_offsets_full: Vec<usize>,
    ) -> Self {
        let per_row: Vec<Vec<usize>> = (0..n)
            .map(|row| {
                let row_logits = logits.row(row);
                (0..k_atoms)
                    .filter(|&k| {
                        jumprelu_in_optimization_band(row_logits[k], threshold, temperature)
                    })
                    .collect()
            })
            .collect();
        Self::from_active_atoms(per_row, coord_dims, coord_offsets_full)
    }

    /// Builds the layout from dense assignment weights.
    ///
    /// Per row, the `max(k_active_cap, 1)` atoms of largest absolute weight are
    /// considered and those with `|weight| > cutoff` are kept. If none survives, the
    /// single top-ranked atom is kept so that every non-empty row has an active atom.
    /// The active list is returned in ascending atom order.
    ///
    /// NaN weights are ranked last. The previous comparator mapped incomparable pairs to
    /// `Equal`, which is not a total order; `sort_by` then yields an unspecified order
    /// and may panic. Ordering for non-NaN weights is unchanged.
    fn from_dense_weights(
        assignments: &[Array1<f64>],
        k_active_cap: usize,
        cutoff: f64,
        coord_dims: Vec<usize>,
        coord_offsets_full: Vec<usize>,
    ) -> Self {
        let cap = k_active_cap.max(1);
        let mut per_row = Vec::with_capacity(assignments.len());
        for a in assignments {
            // Ranking key: absolute weight, with NaN demoted below every real value.
            let magnitude = |atom: usize| -> f64 {
                let m = a[atom].abs();
                if m.is_nan() { f64::NEG_INFINITY } else { m }
            };
            let mut idx: Vec<usize> = (0..a.len()).collect();
            // Stable descending sort, so ties keep ascending atom order.
            idx.sort_by(|&i, &j| magnitude(j).total_cmp(&magnitude(i)));
            let mut active: Vec<usize> = idx
                .iter()
                .copied()
                .take(cap)
                .filter(|&k_idx| a[k_idx].abs() > cutoff)
                .collect();
            if active.is_empty() {
                if let Some(&top) = idx.first() {
                    active.push(top);
                }
            }
            active.sort_unstable();
            per_row.push(active);
        }
        Self::from_active_atoms(per_row, coord_dims, coord_offsets_full)
    }

    /// Computes the compressed coordinate starts for the given active sets.
    ///
    /// In each row the first `active.len()` slots are reserved (one per active atom) and
    /// the coordinates of the active atoms follow in order.
    ///
    /// # Panics
    /// Panics when an active atom index is out of range for `coord_dims`.
    fn from_active_atoms(
        active_atoms: Vec<Vec<usize>>,
        coord_dims: Vec<usize>,
        coord_offsets_full: Vec<usize>,
    ) -> Self {
        let coord_starts: Vec<Vec<usize>> = active_atoms
            .iter()
            .map(|active| {
                let mut cursor = active.len();
                active
                    .iter()
                    .map(|&k| {
                        let start = cursor;
                        cursor += coord_dims[k];
                        start
                    })
                    .collect()
            })
            .collect();
        Self {
            active_atoms,
            coord_starts,
            coord_offsets_full,
            coord_dims,
        }
    }

    /// Number of compressed variables in `row`: one per active atom plus their
    /// coordinate dimensions.
    pub fn row_q_active(&self, row: usize) -> usize {
        let active = &self.active_atoms[row];
        let coord_sum: usize = active.iter().map(|&k| self.coord_dims[k]).sum();
        active.len() + coord_sum
    }

    /// Scatters a compressed row vector into the full row layout.
    ///
    /// `out` is zeroed first. For the `j`-th active atom `k`, `delta_t_row[j]` goes to
    /// `out[k]` and its coordinates go to `out[coord_offsets_full[k]..]`.
    ///
    /// # Panics
    /// Panics when `delta_t_row` or `out` is too short for the row's layout.
    pub fn expand_row(&self, row: usize, delta_t_row: &[f64], out: &mut [f64]) {
        out.fill(0.0);
        let active = &self.active_atoms[row];
        let starts = &self.coord_starts[row];
        for (j, &k) in active.iter().enumerate() {
            out[k] = delta_t_row[j];
            let d = self.coord_dims[k];
            let full_off = self.coord_offsets_full[k];
            let src = &delta_t_row[starts[j]..starts[j] + d];
            out[full_off..full_off + d].copy_from_slice(src);
        }
    }
}
/// Outcome of the global-optimality check; both variants carry the decision margin.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum GlobalOptimalityVerdict {
    /// The check passed.
    CertifiedGlobal { margin: f64 },
    /// The check did not pass (or could not be evaluated).
    Uncertified { margin: f64 },
}

impl GlobalOptimalityVerdict {
    /// Margin stored in the verdict, regardless of the variant.
    pub fn margin(&self) -> f64 {
        match *self {
            Self::CertifiedGlobal { margin } => margin,
            Self::Uncertified { margin } => margin,
        }
    }

    /// `true` only for [`GlobalOptimalityVerdict::CertifiedGlobal`].
    pub fn is_certified(&self) -> bool {
        match self {
            Self::CertifiedGlobal { .. } => true,
            Self::Uncertified { .. } => false,
        }
    }
}
/// Multiplier applied to the curvature bound in the certificate's curvature factor.
pub const SAE_CERT_CURVATURE_CONSTANT: f64 = 1.0;
/// Leading constant of the incoherence budget.
pub const SAE_CERT_INCOHERENCE_BUDGET: f64 = 0.125;

/// Decides whether the measured dictionary incoherence fits inside the certificate budget.
///
/// The budget is
/// `SAE_CERT_INCOHERENCE_BUDGET · a² · (1 − 1/snr) · (1 − C·κ) / K`
/// with `a = max(activity_floor, 0)`, `κ = max(kappa_max, 0)`,
/// `C = SAE_CERT_CURVATURE_CONSTANT` and `K = k_atoms`. The verdict is `CertifiedGlobal`
/// when `budget − mu_hat > 0`, otherwise `Uncertified`; both carry that margin.
///
/// Returns `Uncertified` with margin `-∞` when any input is non-finite, `k_atoms` is 0,
/// the SNR proxy is not strictly positive, or either factor is non-positive.
pub fn curved_dictionary_global_optimality_verdict(
    mu_hat: f64,
    kappa_max: f64,
    activity_floor: f64,
    snr_proxy: f64,
    k_atoms: usize,
) -> GlobalOptimalityVerdict {
    let undecidable = GlobalOptimalityVerdict::Uncertified {
        margin: f64::NEG_INFINITY,
    };
    let inputs_finite = mu_hat.is_finite()
        && kappa_max.is_finite()
        && activity_floor.is_finite()
        && snr_proxy.is_finite();
    if !inputs_finite || k_atoms == 0 {
        return undecidable;
    }
    // A signal-to-noise ratio must be positive. Without this guard a negative (or
    // negative-zero) proxy makes `1 - 1/snr` positive and could yield a false certificate.
    if snr_proxy <= 0.0 {
        return undecidable;
    }
    let curvature_factor = 1.0 - SAE_CERT_CURVATURE_CONSTANT * kappa_max.max(0.0);
    let snr_factor = 1.0 - 1.0 / snr_proxy;
    if curvature_factor <= 0.0 || snr_factor <= 0.0 {
        return undecidable;
    }
    let a = activity_floor.max(0.0);
    let budget =
        SAE_CERT_INCOHERENCE_BUDGET * a * a * snr_factor * curvature_factor / k_atoms as f64;
    let margin = budget - mu_hat;
    if margin > 0.0 {
        GlobalOptimalityVerdict::CertifiedGlobal { margin }
    } else {
        GlobalOptimalityVerdict::Uncertified { margin }
    }
}
/// Quantities that feed the global-optimality certificate, together with its verdict.
#[derive(Clone, Debug)]
pub struct CertificateInputs {
/// Dictionary incoherence estimate.
pub mu_hat: f64,
/// Curvature bound per atom.
pub per_atom_kappa_hat: Vec<f64>,
/// Mean assignment weight per atom over all rows.
pub per_atom_mean_activity: Vec<f64>,
/// Largest assignment weight per atom over all rows.
pub per_atom_peak_activity: Vec<f64>,
/// Smallest per-atom mean activity (0 when not finite).
pub mean_activity_floor: f64,
/// Smallest per-atom peak activity (0 when not finite).
pub peak_activity_floor: f64,
/// Signal power divided by the dispersion.
pub snr_proxy: f64,
/// Dispersion used for the SNR proxy.
pub dispersion: f64,
/// Verdict computed from the quantities above.
pub global_optimality: GlobalOptimalityVerdict,
/// Human-readable summary of the verdict.
pub note: String,
}
/// Bundle of post-fit diagnostic reports.
#[derive(Clone, Debug)]
pub struct SaeManifoldFitDiagnostics {
/// Atom two-lens report.
pub atom_two_lens: crate::inference::atom_lens::AtomTwoLensReport,
/// Residual gauge report.
pub residual_gauge: crate::sae_identifiability::ResidualGaugeReport,
/// Incoherence certificate inputs, when they were computed.
pub incoherence_report: Option<CertificateInputs>,
}
/// Trust diagnostics for a fit.
// NOTE(review): presumably `atom_trust[k]` mirrors `atoms[k].trust_score` — confirm
// against the producer.
#[derive(Clone, Debug)]
pub struct SaeTrustDiagnostics {
/// One trust value per atom.
pub atom_trust: Vec<f64>,
/// Detailed diagnostics per atom.
pub atoms: Vec<SaeAtomTrustDiagnostics>,
}
/// Trust diagnostics for a single atom.
// NOTE(review): how the individual scores are computed and combined is not visible in
// this part of the file; the field notes below restate the names only.
#[derive(Clone, Debug)]
pub struct SaeAtomTrustDiagnostics {
/// Overall trust score of the atom.
pub trust_score: f64,
/// Smallest tangent singular value.
pub sigma_min_tangent: f64,
/// Largest tangent singular value.
pub sigma_max_tangent: f64,
/// Score derived from the tangent conditioning.
pub tangent_condition_score: f64,
/// Coverage score.
pub coverage: f64,
/// Activation frequency of the atom.
pub activation_frequency: f64,
/// Whether the atom is flagged as untyped.
pub untyped: bool,
/// Number of tokens on which the atom is active.
pub active_token_count: usize,
}
/// Builds the incoherence certificate inputs using the dispersion stored on `term`.
///
/// # Errors
/// Fails when `term` has no stored certificate dispersion, and otherwise under the same
/// conditions as [`dictionary_incoherence_report_with_dispersion`].
pub fn dictionary_incoherence_report(term: &SaeManifoldTerm) -> Result<CertificateInputs, String> {
    match term.certificate_dispersion {
        Some(dispersion) => dictionary_incoherence_report_with_dispersion(term, dispersion),
        None => Err(
            "dictionary_incoherence_report: fitted reconstruction dispersion is unavailable"
                .to_string(),
        ),
    }
}
/// Computes every certificate input for `term` and evaluates the global-optimality
/// verdict with the given `dispersion`.
///
/// The verdict uses the frame incoherence, the largest per-atom curvature bound, the
/// smallest per-atom *peak* activity, and `signal_power / dispersion` as the SNR proxy,
/// where `signal_power` is the mean square of the fitted values.
///
/// # Errors
/// Fails when `dispersion` is not finite and positive, or when the incoherence or a
/// curvature bound cannot be computed.
pub fn dictionary_incoherence_report_with_dispersion(
    term: &SaeManifoldTerm,
    dispersion: f64,
) -> Result<CertificateInputs, String> {
    let dispersion_ok = dispersion.is_finite() && dispersion > 0.0;
    if !dispersion_ok {
        return Err(format!(
            "dictionary_incoherence_report: dispersion must be finite and positive, got {dispersion}"
        ));
    }
    let mu_hat = dictionary_frame_incoherence(term)?;
    let mut per_atom_kappa_hat = Vec::with_capacity(term.atoms.len());
    for atom_idx in 0..term.atoms.len() {
        per_atom_kappa_hat.push(atom_curvature_bound(term, atom_idx)?);
    }

    // Per-atom mean and peak assignment weight over all rows.
    let assignments = term.assignment.assignments();
    let (n, k_atoms) = assignments.dim();
    let mut per_atom_mean_activity = Vec::with_capacity(k_atoms);
    let mut per_atom_peak_activity = Vec::with_capacity(k_atoms);
    for atom_idx in 0..k_atoms {
        let (sum, peak) = (0..n)
            .map(|row| assignments[[row, atom_idx]])
            .fold((0.0_f64, 0.0_f64), |(sum, peak), value| {
                (sum + value, peak.max(value))
            });
        let mean = if n > 0 { sum / n as f64 } else { 0.0 };
        per_atom_mean_activity.push(mean);
        per_atom_peak_activity.push(peak);
    }

    // Smallest entry of a list, or 0 when that minimum is not finite (e.g. empty list).
    let finite_floor = |values: &[f64]| -> f64 {
        let lowest = values.iter().copied().fold(f64::INFINITY, f64::min);
        if lowest.is_finite() { lowest } else { 0.0 }
    };
    let mean_activity_floor = finite_floor(&per_atom_mean_activity);
    let peak_activity_floor = finite_floor(&per_atom_peak_activity);

    let fitted = term.fitted();
    let signal_power = if fitted.is_empty() {
        0.0
    } else {
        fitted.iter().map(|v| v * v).sum::<f64>() / fitted.len() as f64
    };
    let snr_proxy = signal_power / dispersion;
    let kappa_max = per_atom_kappa_hat.iter().copied().fold(0.0_f64, f64::max);

    let global_optimality = curved_dictionary_global_optimality_verdict(
        mu_hat,
        kappa_max,
        peak_activity_floor,
        snr_proxy,
        k_atoms,
    );
    let note = match global_optimality {
        GlobalOptimalityVerdict::CertifiedGlobal { margin } => format!(
            "global optimality CERTIFIED up to the residual gauge group \
             (margin {margin:.3e}); μ̂={mu_hat:.3e}, κ̂_max={kappa_max:.3e}, \
             a_floor={peak_activity_floor:.3e}, SNR={snr_proxy:.3e}"
        ),
        GlobalOptimalityVerdict::Uncertified { margin } => format!(
            "global optimality UNCERTIFIED (margin {margin:.3e}; cannot decide — \
             multistart/homotopy genuinely needed); μ̂={mu_hat:.3e}, \
             κ̂_max={kappa_max:.3e}, a_floor={peak_activity_floor:.3e}, \
             SNR={snr_proxy:.3e}"
        ),
    };
    Ok(CertificateInputs {
        mu_hat,
        per_atom_kappa_hat,
        per_atom_mean_activity,
        per_atom_peak_activity,
        mean_activity_floor,
        peak_activity_floor,
        snr_proxy,
        dispersion,
        global_optimality,
        note,
    })
}
/// Largest singular value of `Fⱼᵀ Fₖ` over all pairs of distinct atoms, where `Fⱼ` is the
/// atom's certificate output frame. Pairs involving an empty frame are skipped; the
/// result is `0` when no pair contributes.
///
/// # Errors
/// Fails when a frame cannot be built or an SVD fails.
fn dictionary_frame_incoherence(term: &SaeManifoldTerm) -> Result<f64, String> {
    let atom_count = term.k_atoms();
    let mut frames = Vec::with_capacity(atom_count);
    for atom_idx in 0..atom_count {
        frames.push(certificate_output_frame(term, atom_idx)?);
    }
    let mut worst = 0.0_f64;
    for (j, left) in frames.iter().enumerate() {
        for (k, right) in frames.iter().enumerate().skip(j + 1) {
            if left.ncols() == 0 || right.ncols() == 0 {
                continue;
            }
            let overlap = fast_atb(left, right);
            let (_u, singular, _vt) = overlap.svd(false, false).map_err(|e| {
                format!("dictionary_frame_incoherence: SVD failed for atom pair ({j}, {k}): {e}")
            })?;
            worst = singular.iter().copied().fold(worst, f64::max);
        }
    }
    Ok(worst)
}
/// Output-space frame of one atom, as a `p × rank` matrix.
///
/// When the atom has an active decoder frame, `term.frame_output_matrix` is returned.
/// Otherwise the frame consists of the leading right singular vectors of the decoder
/// coefficients whose singular value exceeds `SAE_FRAME_RANK_CUTOFF` times the largest
/// one; a decoder with no positive singular value yields a `p × 0` matrix.
///
/// # Errors
/// Fails when the SVD fails or returns no right factor.
fn certificate_output_frame(
    term: &SaeManifoldTerm,
    atom_idx: usize,
) -> Result<Array2<f64>, String> {
    let atom = &term.atoms[atom_idx];
    if atom.decoder_frame.is_some() {
        return Ok(term.frame_output_matrix(atom_idx));
    }
    let p = atom.output_dim();
    let (_u, singular, right) = atom
        .decoder_coefficients
        .svd(false, true)
        .map_err(|e| format!("certificate_output_frame: SVD failed for atom {atom_idx}: {e}"))?;
    let largest = singular.iter().copied().fold(0.0_f64, f64::max);
    if !(largest > 0.0) {
        return Ok(Array2::<f64>::zeros((p, 0)));
    }
    let cutoff = SAE_FRAME_RANK_CUTOFF * largest;
    let numerical_rank = singular.iter().filter(|&&value| value > cutoff).count();
    let Some(vt) = right else {
        return Err(format!(
            "certificate_output_frame: SVD returned no right factor for atom {atom_idx}"
        ));
    };
    let rank = numerical_rank.min(vt.nrows());
    // Transpose the leading `rank` rows of Vᵀ into columns.
    let frame = Array2::<f64>::from_shape_fn((p, rank), |(row, col)| vt[[col, row]]);
    Ok(frame)
}
/// Upper bound on the curvature of one atom's decoded manifold over all observations.
///
/// For every row the tangent matrix (decoded first derivatives, one column per latent
/// axis) is reduced by [`tangent_frame_rank`] to a scale and an orthonormal tangent
/// frame. For every pair of latent axes, the decoded second derivative is projected off
/// that frame and the ratio `‖perp‖ / scale` is maximised.
///
/// Returns `f64::INFINITY` as soon as a row has a zero tangent scale but a nonzero
/// normal second derivative.
///
/// # Errors
/// Fails when the atom has no analytic second jet, when evaluating it fails, when its
/// shape is not `(n, m, d, d)`, or when the tangent SVD fails.
fn atom_curvature_bound(term: &SaeManifoldTerm, atom_idx: usize) -> Result<f64, String> {
    let atom = &term.atoms[atom_idx];
    let coords = term.assignment.coords[atom_idx].as_matrix();
    let second = atom
        .basis_evaluator
        .as_ref()
        .and_then(|evaluator| evaluator.second_jet_dyn(coords.view()))
        .ok_or_else(|| {
            format!(
                "atom_curvature_bound: atom {atom_idx} has no analytic second jet; cannot compute kappa_hat"
            )
        })?
        .map_err(|e| format!("atom_curvature_bound: atom {atom_idx} second jet failed: {e}"))?;
    let n = atom.n_obs();
    let m = atom.basis_size();
    let d = atom.latent_dim;
    let p = atom.output_dim();
    if second.dim() != (n, m, d, d) {
        return Err(format!(
            "atom_curvature_bound: atom {atom_idx} second jet shape {:?} must be ({n}, {m}, {d}, {d})",
            second.dim()
        ));
    }
    let mut max_kappa = 0.0_f64;
    let mut tangent = Array2::<f64>::zeros((p, d));
    let mut second_vec = vec![0.0_f64; p];
    // One scratch column reused for every (row, axis): `fill_decoded_derivative_row`
    // zeroes its output before accumulating, so no per-iteration allocation is needed.
    let mut tangent_col = vec![0.0_f64; p];
    for row in 0..n {
        for axis in 0..d {
            atom.fill_decoded_derivative_row(row, axis, &mut tangent_col);
            for (out, &value) in tangent_col.iter().enumerate() {
                tangent[[out, axis]] = value;
            }
        }
        let (tangent_scale, q) = tangent_frame_rank(tangent.view())?;
        for axis_a in 0..d {
            for axis_b in 0..d {
                // Decoded second derivative along (axis_a, axis_b).
                second_vec.fill(0.0);
                for basis_col in 0..m {
                    let h = second[[row, basis_col, axis_a, axis_b]];
                    if h == 0.0 {
                        continue;
                    }
                    for (out, slot) in second_vec.iter_mut().enumerate() {
                        *slot += h * atom.decoder_coefficients[[basis_col, out]];
                    }
                }
                let perp_norm = projected_perp_norm(&second_vec, q.view());
                if tangent_scale > 0.0 {
                    max_kappa = max_kappa.max(perp_norm / tangent_scale);
                } else if perp_norm > 0.0 {
                    // Degenerate tangent with nonzero normal acceleration: unbounded.
                    return Ok(f64::INFINITY);
                }
            }
        }
    }
    Ok(max_kappa)
}
/// Reduces a `p × d` tangent matrix to `(scale, frame)`.
///
/// `frame` holds the left singular vectors whose singular value exceeds
/// `SAE_FRAME_RANK_CUTOFF` times the largest one, and `scale` is the square of the
/// smallest such singular value. An empty matrix, or one with no positive singular
/// value, yields `(0.0, p × 0)`.
///
/// # Errors
/// Fails when the SVD fails or returns no left factor.
fn tangent_frame_rank(tangent: ArrayView2<'_, f64>) -> Result<(f64, Array2<f64>), String> {
    let (p, d) = (tangent.nrows(), tangent.ncols());
    let degenerate = || (0.0, Array2::<f64>::zeros((p, 0)));
    if p == 0 || d == 0 {
        return Ok(degenerate());
    }
    let (left, singular, _vt) = tangent
        .to_owned()
        .svd(true, false)
        .map_err(|e| format!("tangent_frame_rank: SVD failed: {e}"))?;
    let largest = singular.iter().copied().fold(0.0_f64, f64::max);
    if !(largest > 0.0) {
        return Ok(degenerate());
    }
    let cutoff = SAE_FRAME_RANK_CUTOFF * largest;
    let numerical_rank = singular.iter().filter(|&&value| value > cutoff).count();
    let smallest_kept = singular
        .iter()
        .copied()
        .filter(|value| *value > cutoff)
        .fold(f64::INFINITY, f64::min);
    let Some(u) = left else {
        return Err("tangent_frame_rank: SVD returned no U".to_string());
    };
    let rank = numerical_rank.min(u.ncols());
    let frame = Array2::<f64>::from_shape_fn((p, rank), |(row, col)| u[[row, col]]);
    Ok((smallest_kept * smallest_kept, frame))
}
/// Euclidean norm of `vector` after subtracting its component along every column of
/// `tangent_frame`.
///
/// Each coefficient is taken against the original `vector`, so the result is the norm
/// of the orthogonal complement only when the columns are orthonormal.
fn projected_perp_norm(vector: &[f64], tangent_frame: ArrayView2<'_, f64>) -> f64 {
    let rows = tangent_frame.nrows();
    let mut residual = vector.to_vec();
    for axis in 0..tangent_frame.ncols() {
        let coeff = (0..rows).fold(0.0_f64, |acc, out| {
            acc + tangent_frame[[out, axis]] * vector[out]
        });
        if coeff != 0.0 {
            for out in 0..rows {
                residual[out] -= coeff * tangent_frame[[out, axis]];
            }
        }
    }
    let squared_norm = residual.iter().map(|v| v * v).sum::<f64>();
    squared_norm.sqrt()
}
/// An SAE manifold term: a set of atoms plus the assignment that ties observations to
/// them, together with cached solver state.
///
/// `SaeManifoldTerm::new` initialises every private field to its empty value (`None`,
/// empty vector, `false`, or a `0 × 0` workspace).
#[derive(Debug)]
pub struct SaeManifoldTerm {
/// Atoms of the dictionary; all share the same observation count and output dimension.
pub atoms: Vec<SaeManifoldAtom>,
/// Assignment logits and per-atom latent coordinates.
pub assignment: SaeAssignment,
/// Optional temperature schedule.
temperature_schedule: Option<GumbelTemperatureSchedule>,
/// Most recent row layout, if any.
last_row_layout: Option<SaeRowLayout>,
/// Optional row metric.
row_metric: Option<crate::inference::row_metric::RowMetric>,
/// Recorded collapse events.
collapse_events: Vec<CollapseEvent>,
/// Optional per-row loss weights.
row_loss_weights: Option<Vec<f64>>,
// NOTE(review): presumably records whether frames were active in the last solve — confirm.
last_frames_active: bool,
/// Scratch matrix; it is not copied by `Clone` (the clone starts with a `0 × 0` one).
border_hbb_workspace: Array2<f64>,
/// Dispersion used by `dictionary_incoherence_report`; set through
/// `set_certificate_dispersion`.
certificate_dispersion: Option<f64>,
/// Optional curvature-walk report.
curvature_walk_report: Option<CurvatureWalkReport>,
}
/// Manual `Clone`: every field is cloned except `border_hbb_workspace`, which is
/// replaced by an empty `0 × 0` matrix in the copy.
impl Clone for SaeManifoldTerm {
fn clone(&self) -> Self {
Self {
atoms: self.atoms.clone(),
assignment: self.assignment.clone(),
temperature_schedule: self.temperature_schedule.clone(),
last_row_layout: self.last_row_layout.clone(),
row_metric: self.row_metric.clone(),
collapse_events: self.collapse_events.clone(),
row_loss_weights: self.row_loss_weights.clone(),
last_frames_active: self.last_frames_active,
// The workspace is deliberately not copied; the clone starts with an empty one.
border_hbb_workspace: Array2::<f64>::zeros((0, 0)),
certificate_dispersion: self.certificate_dispersion,
curvature_walk_report: self.curvature_walk_report.clone(),
}
}
}
/// Snapshot of the mutable parts of a term.
// NOTE(review): presumably used to save and restore state around trial steps; the meaning
// of the four arrays per atom is not visible in this part of the file — confirm.
#[derive(Debug)]
struct SaeManifoldMutableState {
/// One tuple of arrays per atom.
atoms: Vec<(Array2<f64>, Array3<f64>, Array2<f64>, Array2<f64>)>,
/// Assignment logits.
logits: Array2<f64>,
/// Per-atom latent coordinates.
coords: Vec<LatentCoordValues>,
/// Row layout stored with the snapshot, if any.
last_row_layout: Option<SaeRowLayout>,
}
impl SaeManifoldTerm {
/// Builds a term from `atoms` and `assignment` after checking that their
/// shapes agree.
///
/// # Errors
/// Returns a message when `atoms` is empty, when the assignment shape is not
/// `(n_obs, atoms.len())`, or when any atom disagrees with atom 0 on `n_obs`
/// or output dimension, or with its coordinate block on latent dimension.
#[must_use = "build error must be handled"]
pub fn new(atoms: Vec<SaeManifoldAtom>, assignment: SaeAssignment) -> Result<Self, String> {
    let Some(reference) = atoms.first() else {
        return Err("SaeManifoldTerm::new: at least one atom required".into());
    };
    let n = reference.n_obs();
    let p = reference.output_dim();
    let atom_count = atoms.len();
    let shape_ok = assignment.n_obs() == n && assignment.k_atoms() == atom_count;
    if !shape_ok {
        return Err(format!(
            "SaeManifoldTerm::new: assignment shape ({}, {}) does not match atoms ({n}, {})",
            assignment.n_obs(),
            assignment.k_atoms(),
            atom_count
        ));
    }
    for (k, atom) in atoms.iter().enumerate() {
        let rows = atom.n_obs();
        if rows != n {
            return Err(format!(
                "SaeManifoldTerm::new: atom {k} has n_obs={} but atom 0 has {n}",
                rows
            ));
        }
        let width = atom.output_dim();
        if width != p {
            return Err(format!(
                "SaeManifoldTerm::new: atom {k} output_dim={} but atom 0 has {p}",
                width
            ));
        }
        let coord_dim = assignment.coords[k].latent_dim();
        if atom.latent_dim != coord_dim {
            return Err(format!(
                "SaeManifoldTerm::new: atom {k} latent_dim={} but assignment coord has {}",
                atom.latent_dim, coord_dim
            ));
        }
    }
    let term = Self {
        atoms,
        assignment,
        temperature_schedule: None,
        last_row_layout: None,
        row_metric: None,
        collapse_events: Vec::new(),
        row_loss_weights: None,
        last_frames_active: false,
        border_hbb_workspace: Array2::<f64>::zeros((0, 0)),
        certificate_dispersion: None,
        curvature_walk_report: None,
    };
    Ok(term)
}
/// Stores `dispersion` for later use by the diagnostics report.
///
/// # Errors
/// Rejects values that are not finite or not strictly positive.
pub fn set_certificate_dispersion(&mut self, dispersion: f64) -> Result<(), String> {
    let acceptable = dispersion.is_finite() && dispersion > 0.0;
    if acceptable {
        self.certificate_dispersion = Some(dispersion);
        return Ok(());
    }
    Err(format!(
        "SaeManifoldTerm::set_certificate_dispersion: dispersion must be finite and positive, got {dispersion}"
    ))
}
/// Installs per-row loss weights, rescaled so their mean is one.
///
/// An empty vector (for a term with zero rows) or a vector whose entries are
/// all equal is stored as `None`, i.e. uniform weighting.
///
/// # Errors
/// Fails when the length differs from `n_obs()` or any weight is non-finite
/// or not strictly positive.
pub fn set_row_loss_weights(&mut self, weights: Vec<f64>) -> Result<(), String> {
    let supplied = weights.len();
    if supplied != self.n_obs() {
        return Err(format!(
            "SaeManifoldTerm::set_row_loss_weights: {} weights for {} rows",
            supplied,
            self.n_obs()
        ));
    }
    let Some(&first) = weights.first() else {
        self.row_loss_weights = None;
        return Ok(());
    };
    if weights.iter().any(|w| !w.is_finite() || *w <= 0.0) {
        return Err(
            "SaeManifoldTerm::set_row_loss_weights: weights must be finite and strictly \
             positive"
                .to_string(),
        );
    }
    let is_uniform = !weights.iter().any(|w| *w != first);
    if is_uniform {
        self.row_loss_weights = None;
        return Ok(());
    }
    let mean = weights.iter().sum::<f64>() / supplied as f64;
    let normalized: Vec<f64> = weights.into_iter().map(|w| w / mean).collect();
    self.row_loss_weights = Some(normalized);
    Ok(())
}
/// Borrows the stored per-row loss weights, or `None` when none are stored.
pub fn row_loss_weights(&self) -> Option<&[f64]> {
self.row_loss_weights.as_deref()
}
/// Drops any stored per-row loss weights, leaving the field `None`.
pub fn clear_row_loss_weights(&mut self) {
self.row_loss_weights = None;
}
/// Installs `metric` as the term's row metric.
///
/// # Errors
/// Fails when the metric's row count or output dimension differs from the
/// term's.
pub fn set_row_metric(
    &mut self,
    metric: crate::inference::row_metric::RowMetric,
) -> Result<(), String> {
    let metric_rows = metric.n_rows();
    if metric_rows != self.n_obs() {
        return Err(format!(
            "SaeManifoldTerm::set_row_metric: metric has {} rows but term has {}",
            metric_rows,
            self.n_obs()
        ));
    }
    let metric_width = metric.p_out();
    if metric_width != self.output_dim() {
        return Err(format!(
            "SaeManifoldTerm::set_row_metric: metric output dim {} but term has {}",
            metric_width,
            self.output_dim()
        ));
    }
    self.row_metric = Some(metric);
    Ok(())
}
/// Borrows the installed row metric, if any.
pub fn row_metric(&self) -> Option<&crate::inference::row_metric::RowMetric> {
self.row_metric.as_ref()
}
/// Metric used by the diagnostics: a clone of the installed row metric, or
/// `RowMetric::euclidean(n_obs, output_dim)` when none is installed.
fn diagnostic_metric(&self) -> Result<crate::inference::row_metric::RowMetric, String> {
    if let Some(installed) = self.row_metric() {
        return Ok(installed.clone());
    }
    crate::inference::row_metric::RowMetric::euclidean(self.n_obs(), self.output_dim())
}
/// Assembles the fit diagnostics: the atom two-lens report, the residual-gauge
/// result, and (when a dispersion is available) the incoherence report.
///
/// With `isometry_pin_active` the residual gauge is computed from the model,
/// the atom views and one orbit-penalty operator per atom; otherwise it is
/// computed from the streamed curvature Gram and no operators are supplied.
/// The dispersion is `reconstruction_dispersion`, falling back to the stored
/// certificate dispersion.
///
/// # Errors
/// Propagates failures from the metric, model construction, residual-gauge
/// and incoherence computations.
pub fn fit_diagnostics_report(
    &self,
    per_atom_ard_variances: Option<&[Option<Array1<f64>>]>,
    isometry_pin_active: bool,
    reconstruction_dispersion: Option<f64>,
) -> Result<SaeManifoldFitDiagnostics, String> {
    let metric = self.diagnostic_metric()?;
    let atom_two_lens = crate::inference::atom_lens::atom_two_lens(self, &metric);
    let (certificate_model, streamed_curvature) =
        self.to_residual_gauge_model(metric, per_atom_ard_variances, isometry_pin_active)?;
    let views = self.atom_parameter_views();
    // One slot per atom; populated only when the isometry pin is active.
    let ops: Vec<Option<crate::sae_identifiability::OrbitPenaltyOperator>> = views
        .iter()
        .map(|view| match view {
            Some(v) if isometry_pin_active => {
                crate::sae_identifiability::isometry_orbit_penalty_operator(v, 1.0)
            }
            _ => None,
        })
        .collect();
    let residual_gauge = match (isometry_pin_active, streamed_curvature) {
        (true, _) => {
            crate::sae_identifiability::residual_gauge_exact(&certificate_model, &views, &ops)?
        }
        (false, Some((curvature_gram, root_rows))) => {
            crate::sae_identifiability::residual_gauge_exact_from_curvature_gram(
                &certificate_model,
                &views,
                &ops,
                curvature_gram,
                root_rows,
            )?
        }
        (false, None) => {
            return Err(
                "fit_diagnostics_report: missing streamed residual-gauge curvature for unpinned exact path"
                    .to_string(),
            );
        }
    };
    let incoherence_report = reconstruction_dispersion
        .or(self.certificate_dispersion)
        .map(|dispersion| dictionary_incoherence_report_with_dispersion(self, dispersion))
        .transpose()?;
    Ok(SaeManifoldFitDiagnostics {
        atom_two_lens,
        residual_gauge,
        incoherence_report,
    })
}
/// Per-atom trust diagnostics computed from an `(n_obs, k_atoms)` matrix of
/// assignment masses.
///
/// For each atom this reports the fraction of rows whose mass exceeds the
/// active-mass floor (`coverage`), the mean mass (`activation_frequency`),
/// the extreme tangent singular values over active rows, and their ratio
/// clamped to `[0, 1]`, which is also the trust score.
///
/// # Errors
/// Fails on a shape mismatch, on non-finite masses, or when the metric or the
/// tangent spectrum cannot be computed.
pub fn trust_diagnostics_report(
    &self,
    assignments: ArrayView2<'_, f64>,
) -> Result<SaeTrustDiagnostics, String> {
    let n = self.n_obs();
    let k_atoms = self.k_atoms();
    if assignments.dim() != (n, k_atoms) {
        return Err(format!(
            "trust_diagnostics_report: assignments shape {:?} must be ({n}, {k_atoms})",
            assignments.dim()
        ));
    }
    if assignments.iter().any(|v| !v.is_finite()) {
        return Err("trust_diagnostics_report: assignments must be finite".to_string());
    }
    let metric = self.diagnostic_metric()?;
    let active_threshold = crate::inference::atom_lens::SAE_TRUST_ACTIVE_MASS_FLOOR;
    let mut atoms = Vec::with_capacity(k_atoms);
    let mut atom_trust = Vec::with_capacity(k_atoms);
    for (atom_idx, atom) in self.atoms.iter().enumerate() {
        // Single pass over the column: total mass and active-row count.
        let (activation_sum, active_token_count) =
            (0..n).fold((0.0_f64, 0usize), |(total, count), row| {
                let mass = assignments[[row, atom_idx]];
                (total + mass, count + usize::from(mass > active_threshold))
            });
        let (coverage, activation_frequency) = if n > 0 {
            (
                active_token_count as f64 / n as f64,
                activation_sum / n as f64,
            )
        } else {
            (0.0, 0.0)
        };
        let (sigma_min_tangent, sigma_max_tangent) = self
            .atom_tangent_spectrum_from_assignments(
                atom_idx,
                assignments,
                &metric,
                active_threshold,
            )?;
        let tangent_condition_score = if sigma_max_tangent > 0.0 {
            (sigma_min_tangent / sigma_max_tangent).clamp(0.0, 1.0)
        } else {
            0.0
        };
        let trust_score = tangent_condition_score;
        let untyped = matches!(atom.basis_kind, SaeAtomBasisKind::Precomputed(_));
        atom_trust.push(trust_score);
        atoms.push(SaeAtomTrustDiagnostics {
            trust_score,
            sigma_min_tangent,
            sigma_max_tangent,
            tangent_condition_score,
            coverage,
            activation_frequency,
            untyped,
            active_token_count,
        });
    }
    Ok(SaeTrustDiagnostics { atom_trust, atoms })
}
/// Smallest and largest tangent singular value of atom `atom_idx`, taken from
/// a mass-weighted average of per-row metric pullbacks.
///
/// Only rows whose mass exceeds `active_threshold` contribute. The result is
/// `(sigma_min, sigma_max)` where each sigma is the square root of an
/// eigenvalue of the averaged `d x d` Gram matrix, clamped at zero.
/// Returns `(0.0, 0.0)` when the latent or output dimension is zero, when no
/// row is active, or when no finite minimum is found.
///
/// # Errors
/// Fails when the eigendecomposition fails.
fn atom_tangent_spectrum_from_assignments(
&self,
atom_idx: usize,
assignments: ArrayView2<'_, f64>,
metric: &crate::inference::row_metric::RowMetric,
active_threshold: f64,
) -> Result<(f64, f64), String> {
let atom = &self.atoms[atom_idx];
let d = atom.latent_dim;
let p = self.output_dim();
if d == 0 || p == 0 {
return Ok((0.0, 0.0));
}
let mut gram = Array2::<f64>::zeros((d, d));
let mut active_mass_sum = 0.0_f64;
// Row-major (p x d) Jacobian buffer for one observation: entry (out, axis)
// lives at `out * d + axis`.
let mut jac_row = vec![0.0_f64; p * d];
for row in 0..self.n_obs() {
let mass = assignments[[row, atom_idx]];
// Negated comparison: a NaN mass is skipped as well.
if !(mass > active_threshold) {
continue;
}
active_mass_sum += mass;
for axis in 0..d {
let start = axis;
let mut tangent = vec![0.0_f64; p];
atom.fill_decoded_derivative_row(row, axis, &mut tangent);
for out in 0..p {
jac_row[out * d + start] = tangent[out];
}
}
let row_pullback = metric.pullback(row, &jac_row, d);
// Accumulate the lower triangle only; it is mirrored after normalisation.
for axis_a in 0..d {
for axis_b in 0..=axis_a {
gram[[axis_a, axis_b]] += mass * row_pullback[[axis_a, axis_b]];
}
}
jac_row.fill(0.0);
}
if !(active_mass_sum > 0.0) {
return Ok((0.0, 0.0));
}
// Normalise by total active mass and fill in the upper triangle.
let inv_mass = 1.0 / active_mass_sum;
for axis_a in 0..d {
for axis_b in 0..=axis_a {
let value = gram[[axis_a, axis_b]] * inv_mass;
gram[[axis_a, axis_b]] = value;
gram[[axis_b, axis_a]] = value;
}
}
let (evals, _) = gram.eigh(Side::Lower).map_err(|e| {
format!(
"trust_diagnostics_report: atom {atom_idx} tangent eigendecomposition failed: {e}"
)
})?;
let mut sigma_min = f64::INFINITY;
let mut sigma_max = 0.0_f64;
for value in evals.iter().copied() {
// Negative eigenvalues are treated as zero before the square root.
let clamped = value.max(0.0);
let sigma = clamped.sqrt();
sigma_min = sigma_min.min(sigma);
sigma_max = sigma_max.max(sigma);
}
if sigma_min.is_finite() {
Ok((sigma_min, sigma_max))
} else {
Ok((0.0, 0.0))
}
}
/// One parameter view per atom, in atom order.
///
/// An entry is `None` for sphere atoms and for atoms whose coordinate matrix
/// is not `(n_obs, latent_dim)`. The second jet is included only when the
/// atom has a basis evaluator that yields one successfully; evaluator errors
/// are dropped, not reported.
fn atom_parameter_views(&self) -> Vec<Option<crate::sae_identifiability::AtomParameterView>> {
    let assignments = self.assignment.assignments();
    let n = self.n_obs();
    let mut views = Vec::with_capacity(self.atoms.len());
    for (k, atom) in self.atoms.iter().enumerate() {
        if matches!(atom.basis_kind, SaeAtomBasisKind::Sphere) {
            views.push(None);
            continue;
        }
        let coords = self.assignment.coords[k].as_matrix().to_owned();
        let shape_ok = coords.nrows() == n && coords.ncols() == atom.latent_dim;
        if !shape_ok {
            views.push(None);
            continue;
        }
        // Column `k` of the assignment matrix.
        let activations = Array1::<f64>::from_shape_fn(n, |row| assignments[[row, k]]);
        let basis_second_jet = match atom.basis_evaluator.as_ref() {
            Some(evaluator) => evaluator
                .second_jet_dyn(coords.view())
                .and_then(Result::ok),
            None => None,
        };
        views.push(Some(crate::sae_identifiability::AtomParameterView {
            basis_values: atom.basis_values.clone(),
            basis_jacobian: atom.basis_jacobian.clone(),
            decoder: atom.decoder_coefficients.clone(),
            coords,
            activations,
            basis_second_jet,
        }));
    }
    views
}
/// Lowers the term into a `FittedSaeManifold` for the residual-gauge analysis.
///
/// Returns the model together with an optional streamed curvature
/// `(gram, root_rows)`. With `isometry_pin_active` the model carries dense
/// per-row Jacobians and an isometry penalty root and the streamed curvature
/// is `None`; otherwise `jacobian_rows` is empty, the penalty root has zero
/// rows, and the curvature comes from
/// `residual_gauge_streamed_data_curvature`.
///
/// Parameters are laid out atom by atom; atom `k` owns `p * d_k` slots and
/// slot `(i, axis)` sits at `offset_k + i * d_k + axis`.
///
/// # Errors
/// Propagates failures from the streamed curvature computation.
fn to_residual_gauge_model(
&self,
metric: crate::inference::row_metric::RowMetric,
per_atom_ard_variances: Option<&[Option<Array1<f64>>]>,
isometry_pin_active: bool,
) -> Result<
(
crate::sae_identifiability::FittedSaeManifold,
Option<(Array2<f64>, usize)>,
),
String,
> {
use crate::sae_identifiability::{AtomTopology, FittedAtom, FittedSaeManifold};
let n = self.n_obs();
let p = self.output_dim();
let k = self.k_atoms();
let assignments = self.assignment.assignments();
let mut fitted_atoms: Vec<FittedAtom> = Vec::with_capacity(k);
let mut atom_offsets: Vec<usize> = Vec::with_capacity(k);
let mut atom_axis_dim: Vec<usize> = Vec::with_capacity(k);
let mut cursor = 0usize;
for (atom_idx, atom) in self.atoms.iter().enumerate() {
let d = atom.latent_dim;
// Map the basis kind (and latent dimension) to a topology label.
let topology = match (&atom.basis_kind, d) {
(SaeAtomBasisKind::Periodic, 1) | (SaeAtomBasisKind::Torus, 1) => {
AtomTopology::Circle
}
(SaeAtomBasisKind::Periodic, _) | (SaeAtomBasisKind::Torus, _) => {
AtomTopology::Torus { latent_dim: d }
}
(SaeAtomBasisKind::Sphere, _) => AtomTopology::Sphere,
(
SaeAtomBasisKind::Duchon
| SaeAtomBasisKind::EuclideanPatch
| SaeAtomBasisKind::Precomputed(_),
_,
) => AtomTopology::EuclideanPatch { latent_dim: d },
};
// First pass: assignment-weighted mean of the decoded derivative per axis.
let mut frame = Array2::<f64>::zeros((p, d));
let mut active_mass = 0.0_f64;
let mut tangent = vec![0.0_f64; p];
for row in 0..n {
let a_nk = assignments[[row, atom_idx]];
if !(a_nk > 0.0) {
continue;
}
active_mass += a_nk;
for axis in 0..d {
atom.fill_decoded_derivative_row(row, axis, &mut tangent);
for i in 0..p {
frame[[i, axis]] += a_nk * tangent[i];
}
}
}
if active_mass > 0.0 {
let inv = 1.0 / active_mass;
frame.mapv_inplace(|v| v * inv);
}
// Second pass: weighted squared deviation from the mean frame, relative to
// the weighted squared size of the tangents themselves.
let mut disp_num = 0.0_f64;
let mut disp_den = 0.0_f64;
for row in 0..n {
let a_nk = assignments[[row, atom_idx]];
if !(a_nk > 0.0) {
continue;
}
for axis in 0..d {
atom.fill_decoded_derivative_row(row, axis, &mut tangent);
for i in 0..p {
let dev = tangent[i] - frame[[i, axis]];
disp_num += a_nk * dev * dev;
disp_den += a_nk * tangent[i] * tangent[i];
}
}
}
let lowering_error = if disp_den > 0.0 {
(disp_num / disp_den).clamp(0.0, 1.0)
} else {
0.0
};
// ARD variances are kept only when supplied for this atom with length `d`.
let ard_variances = per_atom_ard_variances
.and_then(|all| all.get(atom_idx))
.and_then(|opt| opt.clone())
.filter(|v| v.len() == d);
fitted_atoms.push(FittedAtom {
name: atom.name.clone(),
topology,
frame,
ard_variances,
lowering_error,
});
atom_offsets.push(cursor);
atom_axis_dim.push(d);
cursor += p * d;
}
let param_dim = cursor;
let (jacobian_rows, streamed_curvature) = if isometry_pin_active {
// Dense path: one flattened (p x param_dim) Jacobian per observation.
let mut jacobian_rows: Vec<Vec<f64>> = Vec::with_capacity(n);
let mut tangent = vec![0.0_f64; p];
for row in 0..n {
let mut j_flat = vec![0.0_f64; p * param_dim];
for (atom_idx, atom) in self.atoms.iter().enumerate() {
let a_nk = assignments[[row, atom_idx]];
if !(a_nk > 0.0) {
continue;
}
let d = atom_axis_dim[atom_idx];
let base = atom_offsets[atom_idx];
for axis in 0..d {
atom.fill_decoded_derivative_row(row, axis, &mut tangent);
for i in 0..p {
j_flat[i * param_dim + base + i * d + axis] += a_nk * tangent[i];
}
}
}
jacobian_rows.push(j_flat);
}
(jacobian_rows, None)
} else {
// Streamed path: only the accumulated Gram is kept.
let streamed = self.residual_gauge_streamed_data_curvature(
&metric,
&atom_offsets,
&atom_axis_dim,
param_dim,
)?;
(Vec::new(), Some(streamed))
};
let isometry_penalty_root = if isometry_pin_active && param_dim > 0 {
// One root row per (atom, axis) whose mean-frame column is not all zero.
let mut root_rows: Vec<Array1<f64>> = Vec::new();
for (atom_idx, fitted) in fitted_atoms.iter().enumerate() {
let d = atom_axis_dim[atom_idx];
let base = atom_offsets[atom_idx];
for axis in 0..d {
let mut r = Array1::<f64>::zeros(param_dim);
let mut any = false;
for i in 0..p {
let v = fitted.frame[[i, axis]];
if v != 0.0 {
any = true;
}
r[base + i * d + axis] = v;
}
if any {
root_rows.push(r);
}
}
}
let mut root = Array2::<f64>::zeros((root_rows.len(), param_dim));
for (ri, r) in root_rows.iter().enumerate() {
root.row_mut(ri).assign(r);
}
root
} else {
Array2::<f64>::zeros((0, param_dim))
};
Ok((
FittedSaeManifold {
atoms: fitted_atoms,
jacobian_rows,
isometry_penalty_root,
metric,
},
streamed_curvature,
))
}
/// Accumulates the `param_dim x param_dim` data-curvature Gram one observation
/// at a time, without storing per-row Jacobians.
///
/// Returns `(gram, n * rank)` where `rank` is `metric.metric_rank()`. When
/// `metric.drives_gauge()` holds, each observation contributes `rank` rows
/// formed from the metric factor applied to the row Jacobian; otherwise it
/// contributes the `p` raw Jacobian rows.
/// NOTE(review): the returned row count is `n * rank` in both branches even
/// though the second branch accumulates `p` rows per observation — confirm
/// that callers expect this.
///
/// # Errors
/// Fails when the metric's output dimension differs from the term's, or when
/// the scratch row is not contiguous.
fn residual_gauge_streamed_data_curvature(
&self,
metric: &crate::inference::row_metric::RowMetric,
atom_offsets: &[usize],
atom_axis_dim: &[usize],
param_dim: usize,
) -> Result<(Array2<f64>, usize), String> {
let n = self.n_obs();
let p = self.output_dim();
if metric.p_out() != p {
return Err(format!(
"residual_gauge_streamed_data_curvature: metric output dim {} but term has {p}",
metric.p_out()
));
}
let rank = metric.metric_rank();
let mut gram = Array2::<f64>::zeros((param_dim, param_dim));
// Nothing to accumulate: return the zero Gram.
if param_dim == 0 || n == 0 || rank == 0 {
return Ok((gram, n * rank));
}
let assignments = self.assignment.assignments();
let mut tangent = vec![0.0_f64; p];
// Scratch buffers reused across observations.
let mut j_flat = vec![0.0_f64; p * param_dim];
let mut root_row = Array1::<f64>::zeros(param_dim);
for row in 0..n {
j_flat.fill(0.0);
for (atom_idx, atom) in self.atoms.iter().enumerate() {
let a_nk = assignments[[row, atom_idx]];
if !(a_nk > 0.0) {
continue;
}
let d = atom_axis_dim[atom_idx];
let base = atom_offsets[atom_idx];
for axis in 0..d {
atom.fill_decoded_derivative_row(row, axis, &mut tangent);
for i in 0..p {
j_flat[i * param_dim + base + i * d + axis] += a_nk * tangent[i];
}
}
}
if metric.drives_gauge() {
for r in 0..rank {
root_row.fill(0.0);
// root_row[c] = sum_i factor(row, i, r) * J[i, c]
for c in 0..param_dim {
let mut acc = 0.0_f64;
for i in 0..p {
acc += metric.factor_entry(row, i, r) * j_flat[i * param_dim + c];
}
root_row[c] = acc;
}
let row_slice = root_row.as_slice().ok_or_else(|| {
"residual_gauge_streamed_data_curvature: non-contiguous root row"
.to_string()
})?;
Self::accumulate_residual_gauge_gram_row(&mut gram, row_slice);
}
} else {
for i in 0..p {
let start = i * param_dim;
let end = start + param_dim;
Self::accumulate_residual_gauge_gram_row(&mut gram, &j_flat[start..end]);
}
}
}
// Only the lower triangle was accumulated; mirror it into the upper one.
for a in 0..param_dim {
for b in 0..a {
gram[[b, a]] = gram[[a, b]];
}
}
Ok((gram, n * rank))
}
/// Adds the outer product `row * row^T` into the lower triangle (diagonal
/// included) of `gram`, skipping pairs in which either factor is exactly zero.
fn accumulate_residual_gauge_gram_row(gram: &mut Array2<f64>, row: &[f64]) {
    for (a, &va) in row.iter().enumerate() {
        if va == 0.0 {
            continue;
        }
        for (b, &vb) in row[..=a].iter().enumerate() {
            if vb == 0.0 {
                continue;
            }
            gram[[a, b]] += va * vb;
        }
    }
}
/// Validates `sched`, applies the temperature it reports for its current
/// iteration count to the assignment mode, then stores the schedule.
///
/// # Errors
/// Propagates schedule validation and temperature-setting failures; on
/// failure the previously stored schedule is left in place.
pub fn set_temperature_schedule(
    &mut self,
    sched: GumbelTemperatureSchedule,
) -> Result<(), String> {
    sched.validate()?;
    let tau = sched.current_tau(sched.iter_count);
    self.assignment.mode.set_temperature(tau)?;
    self.temperature_schedule = Some(sched);
    Ok(())
}
/// Steps the stored schedule once and applies the resulting temperature.
///
/// Returns `Ok(None)` when no schedule is stored, otherwise the temperature
/// that was applied.
fn advance_temperature_schedule(&mut self) -> Result<Option<f64>, String> {
    match self.temperature_schedule.as_mut() {
        None => Ok(None),
        Some(schedule) => {
            schedule.validate()?;
            let tau = schedule.step();
            self.assignment.mode.set_temperature(tau)?;
            Ok(Some(tau))
        }
    }
}
/// Number of observations, as reported by the assignment.
pub fn n_obs(&self) -> usize {
self.assignment.n_obs()
}
/// Number of atoms in the term.
pub fn k_atoms(&self) -> usize {
self.atoms.len()
}
/// Derives a streaming plan from the problem size and a memory budget.
///
/// With a global GPU runtime the budget is a quarter of the summed per-device
/// memory budgets, and the chunk window is a sixteenth of the mean per-device
/// budget, floored at `CPU_L2_CACHE_BYTES * CHUNK_CACHE_MULTIPLE`. Without
/// one, the host constants below are used.
pub fn streaming_plan(&self) -> SaeStreamingPlan {
    const HOST_IN_CORE_BYTES: usize = 2 * 1024 * 1024 * 1024;
    const CPU_L2_CACHE_BYTES: usize = 1024 * 1024;
    const CHUNK_CACHE_MULTIPLE: usize = 8;
    let n_obs = self.n_obs();
    let total_basis = self
        .atoms
        .iter()
        .map(|atom| atom.basis_size())
        .sum::<usize>();
    let d_max = self
        .atoms
        .iter()
        .map(|atom| atom.latent_dim)
        .max()
        .unwrap_or(0);
    let (budget, chunk_window) =
        if let Some(rt) = crate::gpu::runtime::GpuRuntime::global() {
            let aggregate_budget: usize = rt
                .device_ordinals()
                .iter()
                .map(|&ord| rt.memory_budget_for(ord))
                .sum();
            let per_device_budget = aggregate_budget / rt.device_count().max(1);
            let floor = CPU_L2_CACHE_BYTES * CHUNK_CACHE_MULTIPLE;
            (aggregate_budget / 4, (per_device_budget / 16).max(floor))
        } else {
            (
                HOST_IN_CORE_BYTES,
                CPU_L2_CACHE_BYTES * CHUNK_CACHE_MULTIPLE,
            )
        };
    sae_streaming_plan_from_budget(
        n_obs,
        total_basis,
        self.k_atoms(),
        d_max,
        budget,
        chunk_window,
    )
}
/// Checks that every Psi-tier penalty in `registry` can be handled by this
/// term.
///
/// Non-Psi penalties, the two assignment-logit penalties and `NuclearNorm`
/// are accepted without further checks. Any other Psi-tier penalty must be
/// row-block supported; if at least one such penalty is present, all atoms
/// must also share the same coordinate latent dimension.
///
/// # Errors
/// Returns a message naming the offending penalty, or the two mismatching
/// latent dimensions.
fn validate_analytic_penalty_registry(
    &self,
    registry: &AnalyticPenaltyRegistry,
) -> Result<(), String> {
    let mut row_block_penalty_present = false;
    // Borrow the penalty list; the original text had this `&registry` mangled
    // into a `®` character, which is not valid Rust.
    for penalty in &registry.penalties {
        if penalty.tier() != PenaltyTier::Psi {
            continue;
        }
        let is_logit = matches!(
            penalty,
            AnalyticPenaltyKind::IBPAssignment(_)
                | AnalyticPenaltyKind::SoftmaxAssignmentSparsity(_)
        );
        if is_logit {
            continue;
        }
        if matches!(penalty, AnalyticPenaltyKind::NuclearNorm(_)) {
            continue;
        }
        if !sae_penalty_is_row_block_supported(penalty) {
            return Err(format!(
                "SAE-manifold term refuses analytic penalty {:?}: this kind \
                 has cross-row structure and cannot be expressed in the \
                 arrow-Schur row layout. Use only row-block-supported \
                 coord penalties (ARD, BlockOrthogonality, \
                 Sparsity/TopK/JumpReLU, RowPrecisionPrior, \
                 ParametricRowPrecisionPrior, ScadMcp, Isometry) on the \
                 coord latent block, or move the penalty to a non-SAE \
                 term",
                penalty.name()
            ));
        }
        row_block_penalty_present = true;
    }
    if row_block_penalty_present {
        // Row-block penalties need one shared latent dimension across atoms.
        let mut dims = self.assignment.coords.iter().map(|c| c.latent_dim());
        if let Some(first) = dims.next() {
            if let Some(mismatch) = dims.find(|d| *d != first) {
                return Err(format!(
                    "SAE-manifold term refuses row-block analytic penalty: \
                     atoms have heterogeneous coord latent dims (saw {first} \
                     and {mismatch}). Row-block penalties (ARD, \
                     BlockOrthogonality, ...) target the unified \"t\" \
                     latent block whose declared `d` matches one shape; \
                     per-atom dispatch with mixed `d_k` would silently \
                     truncate or expand axes. Configure all atoms with the \
                     same `atom_dim`, or split the row-block penalty into \
                     per-atom descriptors keyed to per-atom latent blocks"
                ));
            }
        }
    }
    Ok(())
}
/// Output dimension, read from the first atom.
///
/// # Panics
/// Panics when `atoms` is empty.
pub fn output_dim(&self) -> usize {
self.atoms[0].output_dim()
}
/// Length of the flattened decoder vector: the sum over atoms of
/// `basis_size * output_dim`.
pub fn beta_dim(&self) -> usize {
    let width = self.output_dim();
    self.atoms
        .iter()
        .fold(0usize, |total, atom| total + atom.basis_size() * width)
}
/// Moves the scratch matrix out of the term and returns it zeroed at shape
/// `(border_dim, border_dim)`; a fresh matrix is allocated when the stored
/// one has a different shape. The term is left holding an empty array.
fn take_border_hbb_workspace(&mut self, border_dim: usize) -> Array2<f64> {
    let wanted = (border_dim, border_dim);
    let mut scratch =
        std::mem::replace(&mut self.border_hbb_workspace, Array2::<f64>::zeros((0, 0)));
    if scratch.dim() == wanted {
        scratch.fill(0.0);
        scratch
    } else {
        Array2::<f64>::zeros(wanted)
    }
}
/// Moves `sys.hbb` back into the term's scratch slot, leaving an empty array
/// in `sys`.
fn reclaim_border_hbb_workspace(&mut self, sys: &mut ArrowSchurSystem) {
    self.border_hbb_workspace = std::mem::replace(&mut sys.hbb, Array2::<f64>::zeros((0, 0)));
}
/// Total number of border coefficients across all atoms.
pub fn factored_border_dim(&self) -> usize {
    self.atoms
        .iter()
        .fold(0usize, |total, atom| total + atom.border_coeff_count())
}
/// Sum of the atoms' frame-manifold dimensions.
pub fn grassmann_evidence_dimension(&self) -> usize {
    self.atoms
        .iter()
        .fold(0usize, |total, atom| total + atom.frame_manifold_dimension())
}
/// `true` when at least one atom currently holds a decoder frame.
pub fn frames_active(&self) -> bool {
    for atom in &self.atoms {
        if atom.decoder_frame.is_some() {
            return true;
        }
    }
    false
}
/// Alias for `frames_active`.
pub fn any_frame_active(&self) -> bool {
self.frames_active()
}
/// Alias for `factored_border_offsets`.
pub fn factored_beta_offsets(&self) -> Vec<usize> {
self.factored_border_offsets()
}
/// Output matrix for atom `atom_idx`: an owned copy of its decoder frame, or
/// the identity of size `output_dim` when the atom has no frame.
///
/// # Panics
/// Panics when `atom_idx` is out of range.
pub fn frame_output_matrix(&self, atom_idx: usize) -> Array2<f64> {
    let atom = &self.atoms[atom_idx];
    if let Some(frame) = atom.decoder_frame.as_ref() {
        frame.frame().to_owned()
    } else {
        Array2::<f64>::eye(atom.output_dim())
    }
}
/// Product `U_i^T U_j` of the frame output matrices of two atoms.
pub fn frame_cross_factor(&self, atom_i: usize, atom_j: usize) -> Array2<f64> {
    let left = self.frame_output_matrix(atom_i);
    let right = self.frame_output_matrix(atom_j);
    fast_atb(&left, &right)
}
/// Start offset of each atom's slice in the flattened factored border
/// (running sum of `border_coeff_count`, beginning at zero).
pub fn factored_border_offsets(&self) -> Vec<usize> {
    let mut next = 0usize;
    self.atoms
        .iter()
        .map(|atom| {
            let start = next;
            next += atom.border_coeff_count();
            start
        })
        .collect()
}
/// Flattens every atom's border coefficients into one vector.
///
/// For each atom the source matrix is its factored coordinates when
/// available, otherwise its decoder coefficients. Entry `(basis_col, j)` is
/// written at `offset + basis_col * r + j`, with `r` the atom's border frame
/// rank.
///
/// # Errors
/// Propagates failures from `factored_coordinates`.
pub fn flatten_factored_border(&self) -> Result<Array1<f64>, String> {
    let offsets = self.factored_border_offsets();
    let mut out = Array1::<f64>::zeros(self.factored_border_dim());
    for (atom_idx, atom) in self.atoms.iter().enumerate() {
        let off = offsets[atom_idx];
        let r = atom.border_frame_rank();
        let m = atom.basis_size();
        let factored = atom.factored_coordinates()?;
        // Read through a view: the frameless branch no longer clones the whole
        // decoder matrix just to copy entries out of it.
        let coords = match factored.as_ref() {
            Some(c) => c.view(),
            None => atom.decoder_coefficients.view(),
        };
        for basis_col in 0..m {
            for j in 0..r {
                out[off + basis_col * r + j] = coords[[basis_col, j]];
            }
        }
    }
    Ok(out)
}
/// Inverse of `flatten_factored_border`: writes `border` back into the atoms.
///
/// Each atom's `(basis_size, rank)` slice is installed as its factored
/// coordinates when the atom has a frame, otherwise as its decoder
/// coefficients.
///
/// # Errors
/// Fails when `border` has the wrong length or when setting factored
/// coordinates fails; atoms processed before a failure stay updated.
pub fn scatter_factored_border(&mut self, border: ArrayView1<'_, f64>) -> Result<(), String> {
    let expected = self.factored_border_dim();
    if border.len() != expected {
        return Err(format!(
            "SaeManifoldTerm::scatter_factored_border: border length {} must equal \
             factored border dim {expected}",
            border.len()
        ));
    }
    let offsets = self.factored_border_offsets();
    for (atom, &start) in self.atoms.iter_mut().zip(offsets.iter()) {
        let rank = atom.border_frame_rank();
        let n_basis = atom.basis_size();
        let block = Array2::<f64>::from_shape_fn((n_basis, rank), |(basis_col, j)| {
            border[start + basis_col * rank + j]
        });
        if atom.decoder_frame.is_some() {
            atom.set_factored_coordinates(block.view())?;
        } else {
            atom.decoder_coefficients = block;
        }
    }
    Ok(())
}
pub fn auto_activate_decoder_frames(&mut self) -> Result<usize, String> {
let mut activated = 0usize;
for atom in &mut self.atoms {
let expected_rank = atom.decoder_frame_activation_rank()?;
match (
expected_rank,
atom.decoder_frame.as_ref().map(GrassmannFrame::rank),
) {
(Some(expected), Some(current)) if expected == current => {
continue;
}
(None, Some(_)) => {
atom.deactivate_decoder_frame();
continue;
}
(None, None) => {
continue;
}
(Some(_), _) => {}
}
if atom.maybe_activate_decoder_frame()?.is_some() {
activated += 1;
}
}
Ok(activated)
}
/// Runs `auto_activate_decoder_frames` and then verifies the outcome.
///
/// # Errors
/// Fails when an atom with an audited rank has no frame or a frame of a
/// different rank, or when an atom without an audited rank still has a frame.
fn ensure_decoder_frames_active_for_current_decoder(&mut self) -> Result<(), String> {
    self.auto_activate_decoder_frames()?;
    for (atom_idx, atom) in self.atoms.iter().enumerate() {
        let audited = atom.decoder_frame_activation_rank()?;
        let held = atom.decoder_frame.as_ref().map(|frame| frame.rank());
        match (audited, held) {
            (None, None) => {}
            (Some(expected_rank), Some(rank)) if rank == expected_rank => {}
            (Some(expected_rank), Some(rank)) => {
                return Err(format!(
                    "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
                     atom {atom_idx} frame rank {} must equal audited rank {expected_rank}",
                    rank
                ));
            }
            (Some(expected_rank), None) => {
                return Err(format!(
                    "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
                     atom {atom_idx} has audited rank {expected_rank} but no active frame"
                ));
            }
            (None, Some(_)) => {
                return Err(format!(
                    "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
                     atom {atom_idx} kept a frame after the full-B predicate won"
                ));
            }
        }
    }
    Ok(())
}
/// Refreshes the decoder frame of every atom that has factored coordinates,
/// using a cross moment built from `target`, and returns how many frames were
/// refreshed.
///
/// Atoms without factored coordinates, and atoms whose cross moment is
/// exactly zero, are skipped. Returns `Ok(0)` immediately when there are no
/// observations.
///
/// # Errors
/// Propagates failures from assignment rows, factored coordinates, moment
/// accumulation and the frame refresh itself.
fn refresh_active_frames_from_data(
&mut self,
target: ArrayView2<'_, f64>,
) -> Result<usize, String> {
let n = self.n_obs();
let p = self.output_dim();
let k_atoms = self.k_atoms();
if n == 0 {
return Ok(0);
}
let mut assignments = Vec::with_capacity(n);
for row in 0..n {
assignments.push(self.assignment.try_assignments_row(row)?);
}
// Cache every atom's decoded output for every row.
let mut decoded = Array3::<f64>::zeros((n, k_atoms, p));
let mut dbuf = vec![0.0_f64; p];
for row in 0..n {
for atom_idx in 0..k_atoms {
self.atoms[atom_idx].fill_decoded_row(row, &mut dbuf);
for c in 0..p {
decoded[[row, atom_idx, c]] = dbuf[c];
}
}
}
// Full fitted value: assignment-weighted sum of the decoded outputs.
let mut fitted = Array2::<f64>::zeros((n, p));
for row in 0..n {
for atom_idx in 0..k_atoms {
let a = assignments[row][atom_idx];
if a == 0.0 {
continue;
}
for c in 0..p {
fitted[[row, c]] += a * decoded[[row, atom_idx, c]];
}
}
}
let mut refreshed = 0usize;
for atom_idx in 0..k_atoms {
let Some(coords_c) = self.atoms[atom_idx].factored_coordinates()? else {
continue;
};
let r = self.atoms[atom_idx].border_frame_rank();
let m = self.atoms[atom_idx].basis_size();
let mut cross = GrassmannCrossMoment::new(p, r);
let mut targets = Array2::<f64>::zeros((n, p));
let mut rcoords = Array2::<f64>::zeros((n, r));
for row in 0..n {
let a = assignments[row][atom_idx];
for c in 0..p {
// Residual with this atom's own contribution added back, then weighted
// by the atom's assignment mass.
let e = target[[row, c]] - fitted[[row, c]] + a * decoded[[row, atom_idx, c]];
targets[[row, c]] = a * e;
}
for j in 0..r {
// Mass-weighted product of the basis row with column `j` of the
// factored coordinates.
let mut acc = 0.0_f64;
for basis_col in 0..m {
acc += self.atoms[atom_idx].basis_values[[row, basis_col]]
* coords_c[[basis_col, j]];
}
rcoords[[row, j]] = a * acc;
}
}
cross.accumulate(targets.view(), rcoords.view())?;
if cross.moment().iter().all(|&v| v == 0.0) {
continue;
}
self.atoms[atom_idx].refresh_frame_from_cross_moment(cross.moment())?;
refreshed += 1;
}
Ok(refreshed)
}
/// Start offset of each atom's block in the flattened decoder vector; each
/// block has width `basis_size * output_dim`.
pub fn beta_offsets(&self) -> Vec<usize> {
    let width = self.output_dim();
    let mut next = 0usize;
    self.atoms
        .iter()
        .map(|atom| {
            let start = next;
            next += atom.basis_size() * width;
            start
        })
        .collect()
}
/// Half-open index range of each atom's block in the flattened decoder
/// vector, as a shared slice.
pub fn beta_block_offsets(&self) -> Arc<[std::ops::Range<usize>]> {
    let width = self.output_dim();
    let mut start = 0usize;
    let ranges: Vec<std::ops::Range<usize>> = self
        .atoms
        .iter()
        .map(|atom| {
            let block = atom.basis_size() * width;
            let range = start..start + block;
            start += block;
            range
        })
        .collect();
    Arc::from(ranges.into_boxed_slice())
}
/// Decides whether a sparse active-atom plan is needed and, if so, returns
/// `(active-atom cap, relative cutoff)`.
///
/// Returns `None` with at most one atom, with a non-Euclidean extended
/// coordinate manifold, when a dense `m_total x m_total` `f64` Gram fits the
/// budget, or when the output dimension is zero. The budget is a quarter of
/// the summed GPU memory budgets when a global runtime exists, else
/// `HOST_GRAM_BYTES`.
fn sparse_active_plan(&self) -> Option<(usize, f64)> {
    const BYTES_PER_F64: usize = 8;
    const HOST_GRAM_BYTES: usize = 2 * 1024 * 1024 * 1024;
    const RELATIVE_CUTOFF: f64 = 1.0e-3;
    let k_atoms = self.k_atoms();
    if k_atoms <= 1 || !self.ext_coord_manifold().is_euclidean() {
        return None;
    }
    let p = self.output_dim();
    let m_total = self
        .atoms
        .iter()
        .map(|a| a.basis_size())
        .sum::<usize>();
    let dense_gram_bytes = m_total
        .saturating_mul(m_total)
        .saturating_mul(BYTES_PER_F64);
    let budget = if let Some(rt) = crate::gpu::runtime::GpuRuntime::global() {
        let aggregate: usize = rt
            .device_ordinals()
            .iter()
            .map(|&ord| rt.memory_budget_for(ord))
            .sum();
        aggregate / 4
    } else {
        HOST_GRAM_BYTES
    };
    if dense_gram_bytes <= budget {
        return None;
    }
    // Largest atom count whose dense Gram still fits, assuming every atom has
    // the mean basis size.
    let m_atom = (m_total as f64 / k_atoms as f64).max(1.0);
    let budget_side = (budget as f64 / BYTES_PER_F64 as f64).sqrt();
    let max_active_basis = (budget_side / m_atom).floor();
    let k_active_cap = (max_active_basis as usize).clamp(1, k_atoms);
    (p != 0).then_some((k_active_cap, RELATIVE_CUTOFF))
}
/// Flattens all decoder coefficients into one vector, atom by atom and
/// row-major within each atom (`offset + basis_col * p + out_col`).
pub fn flatten_beta(&self) -> Array1<f64> {
    let p = self.output_dim();
    let offsets = self.beta_offsets();
    let mut flat = Array1::<f64>::zeros(self.beta_dim());
    for (atom, &start) in self.atoms.iter().zip(offsets.iter()) {
        for basis_col in 0..atom.basis_size() {
            let row_start = start + basis_col * p;
            for out_col in 0..p {
                flat[row_start + out_col] = atom.decoder_coefficients[[basis_col, out_col]];
            }
        }
    }
    flat
}
/// Inverse of `flatten_beta`: writes a flat vector back into the atoms'
/// decoder coefficients.
///
/// # Errors
/// Fails when `beta` does not have length `beta_dim()`.
pub fn set_flat_beta(&mut self, beta: ArrayView1<'_, f64>) -> Result<(), String> {
    let expected = self.beta_dim();
    if beta.len() != expected {
        return Err(format!(
            "set_flat_beta: beta length {} != expected {}",
            beta.len(),
            expected
        ));
    }
    let p = self.output_dim();
    let offsets = self.beta_offsets();
    for (atom, &start) in self.atoms.iter_mut().zip(offsets.iter()) {
        for basis_col in 0..atom.basis_size() {
            let row_start = start + basis_col * p;
            for out_col in 0..p {
                atom.decoder_coefficients[[basis_col, out_col]] = beta[row_start + out_col];
            }
        }
    }
    Ok(())
}
/// Fitted values; see `try_fitted` for the non-panicking form.
///
/// # Panics
/// Panics when `try_fitted` returns an error.
pub fn fitted(&self) -> Array2<f64> {
self.try_fitted().expect("assignment logits must be finite")
}
/// Fitted values as an `(n_obs, output_dim)` matrix: for every row, the sum
/// over atoms of the assignment weight times the atom's decoded row.
///
/// # Errors
/// Propagates failures from `try_assignments_row`.
pub fn try_fitted(&self) -> Result<Array2<f64>, String> {
    let n = self.n_obs();
    let p = self.output_dim();
    let mut fitted = Array2::<f64>::zeros((n, p));
    let mut decoded = vec![0.0_f64; p];
    for row in 0..n {
        let weights = self.assignment.try_assignments_row(row)?;
        for (atom_idx, atom) in self.atoms.iter().enumerate() {
            atom.fill_decoded_row(row, &mut decoded);
            let weight = weights[atom_idx];
            for (out_col, &value) in decoded.iter().enumerate() {
                fitted[[row, out_col]] += weight * value;
            }
        }
    }
    Ok(fitted)
}
/// Loss components for `target`; same as `loss_scaled` with a penalty scale
/// of `1.0`.
pub fn loss(
&self,
target: ArrayView2<'_, f64>,
rho: &SaeManifoldRho,
) -> Result<SaeManifoldLoss, String> {
self.loss_scaled(target, rho, 1.0)
}
/// Loss components for `target`, with the smoothness term multiplied by
/// `penalty_scale`.
///
/// `data_fit` is half the weighted sum of squared residuals. Residual rows
/// are whitened first when a row metric is installed and reports
/// `whitens_likelihood()`. Row weights default to one when no per-row loss
/// weights are stored. Within this function `penalty_scale` is applied only
/// to the smoothness term.
///
/// # Errors
/// Fails when `penalty_scale` is not finite and positive, when `target` is
/// not `(n_obs, output_dim)`, or when the fitted values or the ARD value
/// cannot be computed.
pub fn loss_scaled(
&self,
target: ArrayView2<'_, f64>,
rho: &SaeManifoldRho,
penalty_scale: f64,
) -> Result<SaeManifoldLoss, String> {
if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
return Err(format!(
"SaeManifoldTerm::loss_scaled: penalty_scale must be finite and positive; got {penalty_scale}"
));
}
if target.dim() != (self.n_obs(), self.output_dim()) {
return Err(format!(
"SaeManifoldTerm::loss: Z must be ({}, {}); got {:?}",
self.n_obs(),
self.output_dim(),
target.dim()
));
}
let fitted = self.try_fitted()?;
let mut data_fit = 0.0_f64;
let whitens = self
.row_metric
.as_ref()
.is_some_and(|metric| metric.whitens_likelihood());
// Residual buffer reused for every row.
let mut resid_row = ndarray::Array1::<f64>::zeros(target.ncols());
let row_loss_w = self.row_loss_weights.as_deref();
for row in 0..target.nrows() {
let w_row = row_loss_w.map_or(1.0, |w| w[row]);
for out_col in 0..target.ncols() {
resid_row[out_col] = target[[row, out_col]] - fitted[[row, out_col]];
}
match self.row_metric.as_ref() {
Some(metric) if whitens => {
for w in metric.whiten_residual_row(row, resid_row.view()) {
data_fit += 0.5 * w_row * w * w;
}
}
// No metric, or a metric that does not whiten: plain squared residuals.
_ => {
for &r in resid_row.iter() {
data_fit += 0.5 * w_row * r * r;
}
}
}
}
let assignment_sparsity = assignment_prior_value(&self.assignment, rho);
let smoothness = penalty_scale * self.decoder_smoothness_value(rho.lambda_smooth());
let ard = self.ard_value(rho)?;
Ok(SaeManifoldLoss {
data_fit,
assignment_sparsity,
smoothness,
ard,
})
}
/// Total value of the registry's analytic penalties at the current state,
/// evaluated with an all-zero rho vector.
///
/// `Ard` penalties and `Rho`-tier penalties contribute nothing here.
/// `penalty_scale` multiplies the `NuclearNorm` and all `Beta`-tier
/// contributions; the assignment-logit and coordinate penalties are added
/// unscaled.
///
/// # Errors
/// Returns `ArrowSchurError::SchurFactorFailed` when `penalty_scale` is not
/// finite and positive, or when a Psi-tier penalty that is not row-block
/// supported is encountered; also propagates failures from
/// `corrected_isometry_penalty`.
pub fn analytic_penalty_value_total(
&self,
registry: &AnalyticPenaltyRegistry,
penalty_scale: f64,
) -> Result<f64, ArrowSchurError> {
if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
return Err(ArrowSchurError::SchurFactorFailed {
reason: format!(
"SaeManifoldTerm::analytic_penalty_value_total: penalty_scale must be finite \
and positive; got {penalty_scale}"
),
});
}
// Every penalty is evaluated against a zero rho slice.
let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
let layout = registry.rho_layout();
let logits_flat = flat_logits(self.assignment.logits.view());
let beta = self.flatten_beta();
let mut value = 0.0_f64;
for (penalty, (rho_slice, tier, name)) in registry.penalties.iter().zip(layout.iter()) {
let rho_local = rho_global.slice(s![rho_slice.clone()]);
if matches!(penalty, AnalyticPenaltyKind::Ard(_)) {
continue;
}
match tier {
PenaltyTier::Psi => {
if matches!(
penalty,
AnalyticPenaltyKind::IBPAssignment(_)
| AnalyticPenaltyKind::SoftmaxAssignmentSparsity(_)
) {
// Assignment penalties act on the flattened logits.
value += penalty.value(logits_flat.view(), rho_local);
} else if let AnalyticPenaltyKind::NuclearNorm(base) = penalty {
// Nuclear norm is evaluated per atom on that atom's slice of beta.
for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
value += penalty_scale
* per_atom.value(beta.slice(s![start..end]), rho_local);
}
} else {
if !sae_penalty_is_row_block_supported(penalty) {
return Err(ArrowSchurError::SchurFactorFailed {
reason: format!(
"validate_analytic_penalty_registry should have refused \
non-row-block Psi-tier penalty {:?} (registry layout name \
{name:?})",
penalty.name()
),
});
}
// Row-block coordinate penalties are summed over atoms.
for atom_idx in 0..self.k_atoms() {
let coord = &self.assignment.coords[atom_idx];
if let AnalyticPenaltyKind::Isometry(iso) = penalty {
let corrected_kind =
self.corrected_isometry_penalty(iso, atom_idx, coord)?;
value += corrected_kind.value(coord.as_flat().view(), rho_local);
} else if sae_coord_penalty_is_origin_anchored_magnitude(penalty) {
// Use the restricted coordinates when a Euclidean restriction exists,
// otherwise the full flat coordinates.
match sae_coord_penalty_euclidean_restriction(coord) {
Some((_axes, compacted)) => {
value += penalty.value(compacted.view(), rho_local);
}
None => {
value += penalty.value(coord.as_flat().view(), rho_local);
}
}
} else {
value += penalty.value(coord.as_flat().view(), rho_local);
}
}
}
}
PenaltyTier::Beta => {
if let AnalyticPenaltyKind::DecoderIncoherence(base) = penalty {
if let Some(per_fit) = self.live_decoder_incoherence_penalty(base) {
value += penalty_scale * per_fit.value(beta.view(), rho_local);
}
} else if let AnalyticPenaltyKind::MechanismSparsity(base) = penalty {
// Per-atom penalties with an empty range are skipped; the rest are
// evaluated on the full beta vector.
for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
if start < end {
value += penalty_scale * per_atom.value(beta.view(), rho_local);
}
}
} else {
value += penalty_scale * penalty.value(beta.view(), rho_local);
}
}
PenaltyTier::Rho => {}
}
}
Ok(value)
}
/// Unscaled total of the decoder-side analytic penalties in `registry`.
///
/// Only decoder-incoherence, mechanism-sparsity and nuclear-norm entries
/// contribute; every other penalty kind is ignored. Each penalty is evaluated
/// against its slice of an all-zero rho vector.
pub fn analytic_decoder_penalty_value_total(
    &self,
    registry: &AnalyticPenaltyRegistry,
) -> Result<f64, ArrowSchurError> {
    let zero_rho = Array1::<f64>::zeros(registry.total_rho_count());
    let rho_layout = registry.rho_layout();
    let decoder_beta = self.flatten_beta();

    let mut total = 0.0_f64;
    for (penalty, (rho_slice, _, _)) in registry.penalties.iter().zip(rho_layout.iter()) {
        let rho_local = zero_rho.slice(s![rho_slice.clone()]);
        if let AnalyticPenaltyKind::DecoderIncoherence(base) = penalty {
            if let Some(per_fit) = self.live_decoder_incoherence_penalty(base) {
                total += per_fit.value(decoder_beta.view(), rho_local);
            }
        } else if let AnalyticPenaltyKind::MechanismSparsity(base) = penalty {
            for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
                // Empty beta ranges contribute nothing.
                if start >= end {
                    continue;
                }
                total += per_atom.value(decoder_beta.view(), rho_local);
            }
        } else if let AnalyticPenaltyKind::NuclearNorm(base) = penalty {
            for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
                total += per_atom.value(decoder_beta.slice(s![start..end]), rho_local);
            }
        }
    }
    Ok(total)
}
/// Total value of the isometry penalties in `registry`, summed over all atoms.
///
/// For every isometry entry and every atom, the penalty returned by
/// `corrected_isometry_penalty` is evaluated on that atom's flat coordinates
/// against the entry's slice of an all-zero rho vector.
///
/// # Errors
/// Propagates any error from `corrected_isometry_penalty`.
pub fn isometry_penalty_value_total(
    &self,
    registry: &AnalyticPenaltyRegistry,
) -> Result<f64, ArrowSchurError> {
    let zero_rho = Array1::<f64>::zeros(registry.total_rho_count());
    let rho_layout = registry.rho_layout();

    let mut total = 0.0_f64;
    for (penalty, (rho_slice, _, _)) in registry.penalties.iter().zip(rho_layout.iter()) {
        let AnalyticPenaltyKind::Isometry(iso) = penalty else {
            continue;
        };
        let rho_local = zero_rho.slice(s![rho_slice.clone()]);
        for atom_idx in 0..self.k_atoms() {
            let coord = &self.assignment.coords[atom_idx];
            let corrected_kind = self.corrected_isometry_penalty(iso, atom_idx, coord)?;
            total += corrected_kind.value(coord.as_flat().view(), rho_local);
        }
    }
    Ok(total)
}
/// Sum of the decoder-side analytic penalty total and the isometry penalty total.
///
/// # Errors
/// Propagates the first error from either component, decoder total first.
pub fn reml_extra_penalty_value_total(
    &self,
    registry: &AnalyticPenaltyRegistry,
) -> Result<f64, ArrowSchurError> {
    let decoder_total = self.analytic_decoder_penalty_value_total(registry)?;
    let isometry_total = self.isometry_penalty_value_total(registry)?;
    Ok(decoder_total + isometry_total)
}
/// Scaled loss total plus, when a registry is supplied, the analytic penalty total.
///
/// # Errors
/// Propagates errors from `loss_scaled`; errors from
/// `analytic_penalty_value_total` are rendered into a prefixed `String`.
pub fn penalized_objective_total(
    &self,
    target: ArrayView2<'_, f64>,
    rho: &SaeManifoldRho,
    registry: Option<&AnalyticPenaltyRegistry>,
    penalty_scale: f64,
) -> Result<f64, String> {
    let loss_total = self.loss_scaled(target, rho, penalty_scale)?.total();
    let Some(analytic_registry) = registry else {
        return Ok(loss_total);
    };
    let analytic_total = self
        .analytic_penalty_value_total(analytic_registry, penalty_scale)
        .map_err(|err| format!("SaeManifoldTerm::penalized_objective_total: {err}"))?;
    Ok(loss_total + analytic_total)
}
/// Decoder smoothness energy summed over atoms.
///
/// For each atom, adds `0.5 * lambda_smooth` times the sum of the elementwise
/// product between the atom's decoder coefficients and the matching output of
/// `batched_smooth_sb`.
fn decoder_smoothness_value(&self, lambda_smooth: f64) -> f64 {
    let mut penalty_coeff_pairs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> =
        Vec::with_capacity(self.atoms.len());
    for atom in self.atoms.iter() {
        penalty_coeff_pairs.push((atom.smooth_penalty.view(), atom.decoder_coefficients.view()));
    }
    let smooth_products = batched_smooth_sb(&penalty_coeff_pairs, false);

    // Explicit left fold from 0.0 keeps the floating-point accumulation order.
    self.atoms
        .iter()
        .zip(smooth_products.iter())
        .fold(0.0_f64, |running, (atom, sb)| {
            running + 0.5 * lambda_smooth * (&atom.decoder_coefficients * sb).sum()
        })
}
/// ARD prior energy over every atom, axis and observation, including the
/// per-axis normalisation term.
///
/// Atoms whose `log_ard` entry is empty are skipped. Non-periodic axes add
/// `energy - 0.5 * n * log_alpha`; periodic axes add
/// `energy + n * (-eta + ln I0(eta))` with `eta = alpha / (TAU / period)^2`.
///
/// # Errors
/// Returns a message when `rho.log_ard` does not have one entry per atom, or
/// when a non-empty entry's length differs from the atom's latent dimension.
fn ard_value(&self, rho: &SaeManifoldRho) -> Result<f64, String> {
    if rho.log_ard.len() != self.k_atoms() {
        return Err(format!(
            "ARD rho has {} atoms but term has {}",
            rho.log_ard.len(),
            self.k_atoms()
        ));
    }

    let n = self.n_obs();
    let mut total = 0.0;
    for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
        let d = coord.latent_dim();
        let atom_log_ard = &rho.log_ard[atom_idx];
        // An empty entry means ARD is off for this atom.
        if atom_log_ard.is_empty() {
            continue;
        }
        if atom_log_ard.len() != d {
            return Err(format!(
                "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
                atom_log_ard.len()
            ));
        }

        let periods = coord.effective_axis_periods();
        for axis in 0..d {
            let log_alpha = atom_log_ard[axis];
            let alpha = SaeManifoldRho::stable_exp_strength(log_alpha);
            let period = periods[axis];

            // Sum the per-observation prior energy for this axis, in row order.
            let energy = (0..n).fold(0.0_f64, |sum, row| {
                sum + ArdAxisPrior::eval(alpha, coord.row(row)[axis], period).value
            });

            total += match period {
                None => energy - 0.5 * (n as f64) * log_alpha,
                Some(p) => {
                    let kappa = std::f64::consts::TAU / p;
                    let eta = alpha / (kappa * kappa);
                    let log_i0 = bessel_i0(eta).ln();
                    energy + (n as f64) * (-eta + log_i0)
                }
            };
        }
    }
    Ok(total)
}
/// Assembles the arrow system with a penalty scale of `1.0`.
///
/// Thin entry point over `assemble_arrow_schur_scaled`; errors are passed through.
pub fn assemble_arrow_schur(
    &mut self,
    target: ArrayView2<'_, f64>,
    rho: &SaeManifoldRho,
    analytic_penalties: Option<&AnalyticPenaltyRegistry>,
) -> Result<ArrowSchurSystem, String> {
    let unit_penalty_scale = 1.0_f64;
    self.assemble_arrow_schur_scaled(target, rho, analytic_penalties, unit_penalty_scale)
}
/// Assembles the arrow system with the given `penalty_scale`, using
/// `SAE_DENSE_BETA_PENALTY_PROBE_MAX_DIM` as the dense beta-penalty probe threshold.
///
/// Errors from the underlying assembly are passed through.
pub fn assemble_arrow_schur_scaled(
    &mut self,
    target: ArrayView2<'_, f64>,
    rho: &SaeManifoldRho,
    analytic_penalties: Option<&AnalyticPenaltyRegistry>,
    penalty_scale: f64,
) -> Result<ArrowSchurSystem, String> {
    let probe_max_dim = SAE_DENSE_BETA_PENALTY_PROBE_MAX_DIM;
    self.assemble_arrow_schur_scaled_with_beta_penalty_probe_threshold(
        target,
        rho,
        analytic_penalties,
        penalty_scale,
        probe_max_dim,
    )
}
/// Assembles the arrow-structured system (`ArrowSchurSystem`) for this term at
/// its current state.
///
/// Validates `penalty_scale`, the shape of `target` and the per-atom lengths of
/// `rho.log_ard`; refreshes every atom's intrinsic smooth penalty; then fills one
/// `ArrowRowBlock` per observation and the shared beta-side gradient, row
/// operators and penalty operator.
///
/// `dense_beta_penalty_probe_max_dim` is the beta dimension above which, when
/// frames are engaged, the system is built with an empty `hbb` instead of the
/// dense workspace.
///
/// Side effects visible in this body: `self.border_hbb_workspace` may be reset
/// to an empty array, and on success `self.last_row_layout` and
/// `self.last_frames_active` are overwritten.
///
/// # Errors
/// Returns a `String` describing the first failed validation or failed helper.
fn assemble_arrow_schur_scaled_with_beta_penalty_probe_threshold(
&mut self,
target: ArrayView2<'_, f64>,
rho: &SaeManifoldRho,
analytic_penalties: Option<&AnalyticPenaltyRegistry>,
penalty_scale: f64,
dense_beta_penalty_probe_max_dim: usize,
) -> Result<ArrowSchurSystem, String> {
// ---- Input validation -------------------------------------------------
if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
return Err(format!(
"SaeManifoldTerm::assemble_arrow_schur_scaled: penalty_scale must be finite and positive; got {penalty_scale}"
));
}
if target.dim() != (self.n_obs(), self.output_dim()) {
return Err(format!(
"SaeManifoldTerm::assemble_arrow_schur: Z must be ({}, {}); got {:?}",
self.n_obs(),
self.output_dim(),
target.dim()
));
}
if rho.log_ard.len() != self.k_atoms() {
return Err(format!(
"SaeManifoldTerm::assemble_arrow_schur: log_ard length {} != K {}",
rho.log_ard.len(),
self.k_atoms()
));
}
// Each atom's ARD vector is either empty (disabled) or matches the atom dim.
for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
let ard_len = rho.log_ard[atom_idx].len();
let d = coord.latent_dim();
if ard_len != 0 && ard_len != d {
return Err(format!(
"SaeManifoldTerm::assemble_arrow_schur: log_ard atom {atom_idx} \
has len {ard_len}; expected 0 (disabled) or atom dim {d}"
));
}
}
for atom in &mut self.atoms {
atom.refresh_intrinsic_smooth_penalty();
}
// ---- Dimensions and offsets ------------------------------------------
let n = self.n_obs();
let p = self.output_dim();
let k_atoms = self.k_atoms();
let assignment_dim = self.assignment.assignment_coord_dim();
let q = self.assignment.row_block_dim();
let beta_dim = self.beta_dim();
let frame_projection = FrameProjection::new(self);
let beta_offsets = frame_projection.beta_offsets.clone();
let coord_offsets = self.assignment.coord_offsets();
// The smoothness strength carries the penalty scale from here on.
let lambda_smooth = rho.lambda_smooth() * penalty_scale;
let (assignment_grad, assignment_hdiag) =
assignment_prior_grad_hdiag(&self.assignment, rho)?;
// ---- Decoder smoothness: per-atom operators and beta gradient ---------
let mut smooth_ops: Vec<Arc<dyn BetaPenaltyOp>> = Vec::with_capacity(self.atoms.len());
let mut smooth_scaled_s: Vec<Array2<f64>> = Vec::with_capacity(self.atoms.len());
let mut smooth_grad_gb = vec![0.0_f64; beta_dim];
let sym_sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
.atoms
.iter()
.map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
.collect();
let sym_sb_all = batched_smooth_sb(&sym_sb_inputs, true);
for (atom_idx, atom) in self.atoms.iter().enumerate() {
let m = atom.basis_size();
let off = beta_offsets[atom_idx];
// scaled_s = lambda_smooth * symmetrised smooth penalty.
let mut scaled_s = Array2::<f64>::zeros((m, m));
for i in 0..m {
for j in 0..m {
let s_ij = 0.5 * (atom.smooth_penalty[[i, j]] + atom.smooth_penalty[[j, i]]);
scaled_s[[i, j]] = lambda_smooth * s_ij;
}
}
let sb = &sym_sb_all[atom_idx] * lambda_smooth;
// Beta index for (basis row i, output column) is off + i * p + out_col.
for out_col in 0..p {
for i in 0..m {
let beta_i = off + i * p + out_col;
smooth_grad_gb[beta_i] += sb[[i, out_col]];
}
}
smooth_ops.push(Arc::new(IdentityRightKroneckerPenaltyOp {
factor_a: scaled_s.clone(),
p,
global_offset: off,
k: beta_dim,
}));
smooth_scaled_s.push(scaled_s);
}
let coord_dims: Vec<usize> = self
.assignment
.coords
.iter()
.map(|c| c.latent_dim())
.collect();
// ---- Optional compact row layout --------------------------------------
// JumpReLU always gets a layout; softmax never does; IBP-MAP gets one only
// when `sparse_active_plan` returns a plan.
let row_layout: Option<SaeRowLayout> = match self.assignment.mode {
AssignmentMode::JumpReLU {
threshold,
temperature,
} => Some(SaeRowLayout::from_jumprelu(
n,
k_atoms,
threshold,
temperature,
&self.assignment.logits,
coord_dims.clone(),
self.assignment.coord_offsets(),
)),
AssignmentMode::Softmax { .. } => None,
AssignmentMode::IBPMap { .. } => {
match self.sparse_active_plan() {
Some((k_active_cap, relative_cutoff)) => {
let mut assignments_all = Vec::with_capacity(n);
for row in 0..n {
assignments_all.push(self.assignment.try_assignments_row(row)?);
}
// The cutoff is relative to the largest absolute assignment weight.
let peak = assignments_all
.iter()
.flat_map(|a| a.iter())
.fold(0.0_f64, |m, &v| m.max(v.abs()));
let cutoff = relative_cutoff * peak;
Some(SaeRowLayout::from_dense_weights(
&assignments_all,
k_active_cap,
cutoff,
coord_dims.clone(),
self.assignment.coord_offsets(),
))
}
None => None,
}
}
};
// ---- Structural decisions ---------------------------------------------
let whitens_likelihood = self
.row_metric
.as_ref()
.is_some_and(|metric| metric.whitens_likelihood());
// Frames are used only when some frame is active and the metric does not whiten.
let frames_engaged = self.any_frame_active() && !whitens_likelihood;
let dense_beta_curvature = !(frames_engaged && beta_dim > dense_beta_penalty_probe_max_dim);
let row_htbeta_dim = if frames_engaged {
self.factored_border_dim()
} else {
beta_dim
};
// ---- System allocation (per-row or uniform dims, dense or empty hbb) --
let mut sys = if let Some(ref layout) = row_layout {
let per_row_dims: Vec<usize> = (0..n).map(|row| layout.row_q_active(row)).collect();
if dense_beta_curvature {
let hbb_workspace = self.take_border_hbb_workspace(beta_dim);
ArrowSchurSystem::new_with_per_row_dims_and_hbb_and_htbeta_cols(
per_row_dims,
beta_dim,
hbb_workspace,
row_htbeta_dim,
)
} else {
self.border_hbb_workspace = Array2::<f64>::zeros((0, 0));
ArrowSchurSystem::new_with_per_row_dims_empty_hbb_and_htbeta_cols(
per_row_dims,
beta_dim,
row_htbeta_dim,
)
}
} else if dense_beta_curvature {
let hbb_workspace = self.take_border_hbb_workspace(beta_dim);
ArrowSchurSystem::new_with_hbb_and_htbeta_cols(
n,
q,
beta_dim,
hbb_workspace,
row_htbeta_dim,
)
} else {
self.border_hbb_workspace = Array2::<f64>::zeros((0, 0));
ArrowSchurSystem::new_with_empty_hbb_and_htbeta_cols(n, q, beta_dim, row_htbeta_dim)
};
for (i, g) in smooth_grad_gb.iter().enumerate() {
sys.gb[i] += g;
}
// ---- Scratch buffers reused across rows -------------------------------
let mut decoded = Array2::<f64>::zeros((k_atoms, p));
let mut dg_buf = vec![0.0_f64; p];
let mut fitted = Array1::<f64>::zeros(p);
let mut error = Array1::<f64>::zeros(p);
// Width of the whitened residual: metric rank when whitening, else p.
let w_dim = match self.row_metric.as_ref() {
Some(metric) if whitens_likelihood => metric.metric_rank(),
_ => p,
};
let mut error_white = vec![0.0_f64; w_dim];
let mut error_metric = Array1::<f64>::zeros(p);
let mut jac_white = vec![0.0_f64; q * w_dim.max(p)];
let m_total: usize = self.atoms.iter().map(|a| a.basis_size()).sum();
let mu_offsets: Vec<usize> = beta_offsets.iter().map(|&off| off / p).collect();
// Per atom-pair accumulators of weighted basis outer products.
let mut g_blocks: std::collections::BTreeMap<(usize, usize), Array2<f64>> =
std::collections::BTreeMap::new();
let ibp_prior_vec = match self.assignment.mode {
AssignmentMode::IBPMap { alpha, .. } => {
Some(ibp_stick_breaking_prior(k_atoms, alpha).to_vec())
}
_ => None,
};
let ibp_prior_slice = ibp_prior_vec.as_deref();
let row_loss_w = self.row_loss_weights.as_deref();
let mut decoded_scratch = vec![0.0_f64; p];
let mut kron_a_phi: Vec<Vec<(usize, f64)>> = Vec::with_capacity(n);
let mut kron_jac: Vec<Vec<f64>> = Vec::with_capacity(n);
let all_atoms_index: Vec<usize> = (0..k_atoms).collect();
let ard_axis_periods: Vec<Vec<Option<f64>>> = self
.assignment
.coords
.iter()
.map(|coord| coord.effective_axis_periods())
.collect();
// ---- Per-observation assembly -----------------------------------------
for row in 0..n {
let assignments = self.assignment.try_assignments_row(row)?;
fitted.fill(0.0);
// fitted = sum over (active) atoms of assignment weight * decoded row.
let row_active_owned: Option<&[usize]> =
row_layout.as_ref().map(|l| l.active_atoms[row].as_slice());
match row_active_owned {
Some(active) => {
for &atom_idx in active {
let a_k = assignments[atom_idx];
self.atoms[atom_idx].fill_decoded_row(row, &mut decoded_scratch);
for out_col in 0..p {
decoded[[atom_idx, out_col]] = decoded_scratch[out_col];
fitted[out_col] += a_k * decoded_scratch[out_col];
}
}
}
None => {
for atom_idx in 0..k_atoms {
let a_k = assignments[atom_idx];
self.atoms[atom_idx].fill_decoded_row(row, &mut decoded_scratch);
for out_col in 0..p {
decoded[[atom_idx, out_col]] = decoded_scratch[out_col];
fitted[out_col] += a_k * decoded_scratch[out_col];
}
}
}
}
for out_col in 0..p {
error[out_col] = fitted[out_col] - target[[row, out_col]];
}
// Row loss weights enter as sqrt(w) on both the residual and the Jacobian.
let sqrt_row_w = row_loss_w.map_or(1.0, |w| w[row].sqrt());
if sqrt_row_w != 1.0 {
for out_col in 0..p {
error[out_col] *= sqrt_row_w;
}
}
// error_white feeds the row blocks; error_metric feeds the beta gradient.
match self.row_metric.as_ref() {
Some(metric) if whitens_likelihood => {
let wr = metric.whiten_residual_row(row, error.view());
for (slot, &v) in error_white.iter_mut().zip(wr.iter()) {
*slot = v;
}
let mr = metric.apply_metric_row(row, error.view());
for (slot, &v) in error_metric.iter_mut().zip(mr.iter()) {
*slot = v;
}
}
_ => {
for out_col in 0..p {
error_white[out_col] = error[out_col];
error_metric[out_col] = error[out_col];
}
}
}
// Row Jacobian: logit rows first, then a_k * d(decoded)/d(coord axis) rows.
let (q_row, mut local_jac_row) = if let Some(ref layout) = row_layout {
let active = &layout.active_atoms[row];
let starts = &layout.coord_starts[row];
let q_active = layout.row_q_active(row);
let mut jac_compact = Array2::<f64>::zeros((q_active, p));
let logits_row = self.assignment.logits.row(row);
for (j, &k) in active.iter().enumerate() {
fill_active_atom_logit_jvp(
ActiveAtomLogitJvp {
mode: self.assignment.mode,
k,
logit_k: logits_row[k],
a_k: assignments[k],
decoded_k: decoded.row(k),
fitted: fitted.view(),
ibp_prior: ibp_prior_slice,
compact_index: j,
},
&mut jac_compact,
);
}
for (j, &k) in active.iter().enumerate() {
let d = self.atoms[k].latent_dim;
let a_k = assignments[k];
let coord_start = starts[j];
for axis in 0..d {
self.atoms[k].fill_decoded_derivative_row(row, axis, &mut dg_buf);
for out_col in 0..p {
jac_compact[[coord_start + axis, out_col]] = a_k * dg_buf[out_col];
}
}
}
(q_active, jac_compact)
} else {
let mut jac_row = Array2::<f64>::zeros((q, p));
fill_assignment_logit_jvp_rows(
self.assignment.mode,
self.assignment.logits.row(row),
assignments.view(),
decoded.view(),
fitted.view(),
ibp_prior_slice,
&mut jac_row,
);
for atom_idx in 0..k_atoms {
let d = self.atoms[atom_idx].latent_dim;
let off = coord_offsets[atom_idx];
let a_k = assignments[atom_idx];
for axis in 0..d {
self.atoms[atom_idx].fill_decoded_derivative_row(row, axis, &mut dg_buf);
for out_col in 0..p {
jac_row[[off + axis, out_col]] = a_k * dg_buf[out_col];
}
}
}
(q, jac_row)
};
if sqrt_row_w != 1.0 {
for a in 0..q_row {
for out_col in 0..p {
local_jac_row[[a, out_col]] *= sqrt_row_w;
}
}
}
// jac_white is row-major with stride w_dim: metric factor applied when
// whitening, plain copy otherwise.
if whitens_likelihood {
if let Some(metric) = self.row_metric.as_ref() {
for a in 0..q_row {
for k in 0..w_dim {
let mut acc = 0.0;
for out_col in 0..p {
acc += metric.factor_entry(row, out_col, k)
* local_jac_row[[a, out_col]];
}
jac_white[a * w_dim + k] = acc;
}
}
}
} else {
for a in 0..q_row {
for out_col in 0..p {
jac_white[a * w_dim + out_col] = local_jac_row[[a, out_col]];
}
}
}
// Row block: gt = jac_white * error_white, htt = jac_white * jac_white^T.
let mut block = ArrowRowBlock::new(q_row, row_htbeta_dim);
for a in 0..q_row {
let mut g = 0.0;
for k in 0..w_dim {
g += jac_white[a * w_dim + k] * error_white[k];
}
block.gt[a] += g;
for b in 0..q_row {
let mut h = 0.0;
for k in 0..w_dim {
h += jac_white[a * w_dim + k] * jac_white[b * w_dim + k];
}
block.htt[[a, b]] += h;
}
}
// Assignment prior: gradient and diagonal curvature on the logit slots.
let assignment_base = row * k_atoms;
if let Some(ref layout) = row_layout {
let active = &layout.active_atoms[row];
for (j, &k) in active.iter().enumerate() {
block.gt[j] += assignment_grad[assignment_base + k];
block.htt[[j, j]] += assignment_hdiag[assignment_base + k];
}
} else {
for free_idx in 0..assignment_dim {
block.gt[free_idx] += assignment_grad[assignment_base + free_idx];
block.htt[[free_idx, free_idx]] += assignment_hdiag[assignment_base + free_idx];
}
}
// ARD prior on coordinate slots; the curvature is clamped at zero from below.
if let Some(ref layout) = row_layout {
let active = &layout.active_atoms[row];
let starts = &layout.coord_starts[row];
for (j, &k) in active.iter().enumerate() {
let coord = &self.assignment.coords[k];
let d = coord.latent_dim();
if rho.log_ard[k].is_empty() {
continue;
}
if rho.log_ard[k].len() != d {
return Err(format!(
"ARD rho atom {k} has len {} but atom dim is {d}",
rho.log_ard[k].len()
));
}
let row_t = coord.row(row);
let periods = &ard_axis_periods[k];
for axis in 0..d {
let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[k][axis]);
let prior = ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
block.gt[starts[j] + axis] += prior.grad;
block.htt[[starts[j] + axis, starts[j] + axis]] += prior.hess.max(0.0);
}
}
} else {
for atom_idx in 0..k_atoms {
let coord = &self.assignment.coords[atom_idx];
let d = coord.latent_dim();
if rho.log_ard[atom_idx].is_empty() {
continue;
}
if rho.log_ard[atom_idx].len() != d {
return Err(format!(
"ARD rho atom {atom_idx} has len {} but atom dim is {d}",
rho.log_ard[atom_idx].len()
));
}
let off = coord_offsets[atom_idx];
let row_t = coord.row(row);
let periods = &ard_axis_periods[atom_idx];
for axis in 0..d {
let alpha =
SaeManifoldRho::stable_exp_strength(rho.log_ard[atom_idx][axis]);
let prior = ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
block.gt[off + axis] += prior.grad;
block.htt[[off + axis, off + axis]] += prior.hess.max(0.0);
}
}
}
// Beta side: weighted basis values w = a_k * phi * sqrt(row weight).
let row_active: &[usize] = match row_layout {
Some(ref layout) => layout.active_atoms[row].as_slice(),
None => &all_atoms_index,
};
let mut a_phi: Vec<(usize, f64)> = Vec::with_capacity(row_active.len() * 4);
let mut weighted_phi: Vec<(usize, Vec<f64>)> = Vec::with_capacity(row_active.len());
for &atom_idx in row_active {
let atom = &self.atoms[atom_idx];
let atom_beta_off = beta_offsets[atom_idx];
let m = atom.basis_size();
let a_k = assignments[atom_idx];
let mut wphi = Vec::with_capacity(m);
for basis_col in 0..m {
let phi = atom.basis_values[[row, basis_col]];
let w = a_k * phi * sqrt_row_w;
a_phi.push((atom_beta_off + basis_col * p, w));
wphi.push(w);
}
weighted_phi.push((atom_idx, wphi));
}
// Beta gradient: gb += w * error_metric over each basis column's p outputs.
for &(beta_base_i, j_beta_i) in a_phi.iter() {
if j_beta_i == 0.0 {
continue;
}
for out_col in 0..p {
sys.gb[beta_base_i + out_col] += j_beta_i * error_metric[out_col];
}
}
// Accumulate outer products of the weighted basis values per atom pair.
for ai in 0..weighted_phi.len() {
let (atom_i, ref wphi_i) = weighted_phi[ai];
let m_i = wphi_i.len();
for aj in 0..weighted_phi.len() {
let (atom_j, ref wphi_j) = weighted_phi[aj];
let m_j = wphi_j.len();
let blk = g_blocks
.entry((atom_i, atom_j))
.or_insert_with(|| Array2::<f64>::zeros((m_i, m_j)));
for li in 0..m_i {
let wi = wphi_i[li];
if wi == 0.0 {
continue;
}
for lj in 0..m_j {
blk[[li, lj]] += wi * wphi_j[lj];
}
}
}
}
kron_a_phi.push(a_phi);
// Keep the (row-weighted, unwhitened) Jacobian flattened row-major for the
// row operators built below.
let mut jac_flat = vec![0.0_f64; q_row * p];
for c in 0..q_row {
for j in 0..p {
jac_flat[c * p + j] = local_jac_row[[c, j]];
}
}
kron_jac.push(jac_flat);
sys.rows[row] = block;
}
// ---- Riemannian geometry (dense layout only) --------------------------
// The raw row gradients are captured before the geometry is applied to the
// system; on a non-Euclidean manifold each stored Jacobian column is then
// passed through `project_vector_to_gradient_tangent`.
if row_layout.is_none() {
let raw_gt_rows: Vec<Array1<f64>> = sys.rows.iter().map(|row| row.gt.clone()).collect();
self.apply_sae_riemannian_geometry(&mut sys);
let manifold = self.ext_coord_manifold();
if !manifold.is_euclidean() {
let ext = self.ext_coord_matrix();
let mut t_buf = vec![0.0_f64; q];
let mut col_buf = Array1::<f64>::zeros(q);
for row_idx in 0..n {
let ext_row = ext.row(row_idx);
for (slot, &v) in t_buf.iter_mut().zip(ext_row.iter()) {
*slot = v;
}
let t_i = ArrayView1::from(t_buf.as_slice());
let raw_gt = raw_gt_rows[row_idx].view();
let jac_flat = &mut kron_jac[row_idx];
let q_row = jac_flat.len() / p;
for j in 0..p {
for c in 0..q_row {
col_buf[c] = jac_flat[c * p + j];
}
let projected_col = manifold.project_vector_to_gradient_tangent(
t_i,
raw_gt.slice(ndarray::s![..q_row]),
col_buf.slice(ndarray::s![..q_row]),
);
for c in 0..q_row {
jac_flat[c * p + j] = projected_col[c];
}
}
}
}
}
// ---- Row <-> beta coupling operators ----------------------------------
// The first closure maps a beta-side vector to a row-side vector; the second
// is the reverse direction. Both go through a length-p intermediate.
if frames_engaged {
let kron = Arc::new(SaeFrameKroneckerRows::new(
p,
frame_projection.clone(),
kron_a_phi,
kron_jac,
)?);
let kron_t = Arc::clone(&kron);
let p_dim = p;
sys.set_row_htbeta_operator(
move |row_idx, x, out| {
let out_slice = out.as_slice_mut().expect("out is always standard-layout");
let mut u_p = vec![0.0_f64; p_dim];
if let Some(xs) = x.as_slice() {
kron.apply_jbeta_factored(row_idx, xs, &mut u_p);
} else {
let x_vec: Vec<f64> = x.iter().copied().collect();
kron.apply_jbeta_factored(row_idx, &x_vec, &mut u_p);
}
kron.apply_l(row_idx, &u_p, out_slice);
},
move |row_idx, v, out| {
let out_slice = out.as_slice_mut().expect("out is always standard-layout");
let mut u_p = vec![0.0_f64; p_dim];
if let Some(vs) = v.as_slice() {
kron_t.apply_l_t(row_idx, vs, &mut u_p);
} else {
let v_vec: Vec<f64> = v.iter().copied().collect();
kron_t.apply_l_t(row_idx, &v_vec, &mut u_p);
}
kron_t.scatter_jbeta_factored_t(row_idx, &u_p, out_slice);
},
);
} else {
let kron = Arc::new(SaeKroneckerRows::new(p, kron_a_phi, kron_jac));
let kron_t = Arc::clone(&kron);
let p_dim = p;
sys.set_row_htbeta_operator(
move |row_idx, x, out| {
let out_slice = out.as_slice_mut().expect("out is always standard-layout");
let mut u_p = vec![0.0_f64; p_dim];
if let Some(xs) = x.as_slice() {
kron.apply_jbeta(row_idx, xs, &mut u_p);
} else {
let x_vec: Vec<f64> = x.iter().copied().collect();
kron.apply_jbeta(row_idx, &x_vec, &mut u_p);
}
kron.apply_l(row_idx, &u_p, out_slice);
},
move |row_idx, v, out| {
let out_slice = out.as_slice_mut().expect("out is always standard-layout");
let mut u_p = vec![0.0_f64; p_dim];
if let Some(vs) = v.as_slice() {
kron_t.apply_l_t(row_idx, vs, &mut u_p);
} else {
let v_vec: Vec<f64> = v.iter().copied().collect();
kron_t.apply_l_t(row_idx, &v_vec, &mut u_p);
}
kron_t.scatter_jbeta_t(row_idx, &u_p, out_slice);
},
);
}
// ---- Analytic penalties -----------------------------------------------
let mut beta_penalty_assembly = SaeBetaPenaltyAssembly::default();
let factored_row_projection = if frames_engaged && analytic_penalties.is_some() {
Some(&frame_projection)
} else {
None
};
if let Some(registry) = analytic_penalties {
self.validate_analytic_penalty_registry(registry)
.map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
beta_penalty_assembly = self
.add_sae_analytic_penalty_contributions(
&mut sys,
registry,
penalty_scale,
row_layout.as_ref(),
dense_beta_curvature,
factored_row_projection,
)
.map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
}
// ---- Beta-side penalty operator ---------------------------------------
if frames_engaged {
// Factored path: the beta gradient is projected to the border space and
// the system's beta dimension becomes `border_dim`.
let off_c = &frame_projection.border_offsets;
let ranks = &frame_projection.ranks;
let basis_sizes = &frame_projection.basis_sizes;
let border_dim = frame_projection.border_dim();
let gb_c = frame_projection.project_border_vec(sys.gb.view());
let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::with_capacity(g_blocks.len());
for ((atom_i, atom_j), data) in g_blocks.into_iter() {
// All-zero blocks are dropped.
if data.iter().all(|&v| v == 0.0) {
continue;
}
let w = self.frame_cross_factor(atom_i, atom_j);
frame_blocks.push(FactoredFrameGBlock {
atom_i,
atom_j,
g: data,
w,
});
}
let data_op =
FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks)?;
let mut ops: Vec<Arc<dyn BetaPenaltyOp>> = Vec::with_capacity(self.atoms.len() + 2);
for k in 0..self.atoms.len() {
let r = ranks[k];
ops.push(Arc::new(IdentityRightKroneckerPenaltyOp {
factor_a: smooth_scaled_s[k].clone(),
p: r,
global_offset: off_c[k],
k: border_dim,
}));
}
ops.push(Arc::new(data_op));
// Analytic beta curvature: project what was written densely, or build
// it directly in the factored space when it was deferred.
if beta_penalty_assembly.dense_written {
let hbb_c =
self.project_dense_penalty_to_factored(sys.hbb.view(), &frame_projection);
ops.push(Arc::new(DensePenaltyOp(hbb_c)));
} else if beta_penalty_assembly.deferred_factored {
let registry =
analytic_penalties.expect("deferred beta curvature requires registry");
let hbb_c = self.build_factored_beta_penalty_curvature(
registry,
penalty_scale,
&frame_projection,
);
ops.push(Arc::new(DensePenaltyOp(hbb_c)));
}
sys.k = border_dim;
sys.gb = gb_c;
self.reclaim_border_hbb_workspace(&mut sys);
let mut block_ranges: Vec<std::ops::Range<usize>> =
Vec::with_capacity(self.atoms.len());
for k in 0..self.atoms.len() {
let start = off_c[k];
block_ranges.push(start..start + basis_sizes[k] * ranks[k]);
}
sys.set_block_offsets(Arc::from(block_ranges.into_boxed_slice()));
sys.set_penalty_op(Arc::new(CompositePenaltyOp { k: border_dim, ops }));
} else {
// Unfactored path: smoothness ops + sparse data blocks (+ dense analytic
// curvature when it was written).
sys.set_block_offsets(self.beta_block_offsets());
let g_sparse_blocks: Vec<SparseGBlock> = g_blocks
.into_iter()
.filter_map(|((atom_i, atom_j), data)| {
if data.iter().all(|&v| v == 0.0) {
None
} else {
Some(SparseGBlock {
row_off: mu_offsets[atom_i],
col_off: mu_offsets[atom_j],
data,
})
}
})
.collect();
let mut ops: Vec<Arc<dyn BetaPenaltyOp>> = smooth_ops;
ops.push(Arc::new(SparseBlockKroneckerPenaltyOp {
p,
dim_a: m_total,
k: beta_dim,
blocks: g_sparse_blocks,
}));
if beta_penalty_assembly.dense_written {
ops.push(Arc::new(DensePenaltyOp(sys.hbb.clone())));
}
sys.set_penalty_op(Arc::new(CompositePenaltyOp { k: beta_dim, ops }));
self.reclaim_border_hbb_workspace(&mut sys);
}
// Remember which layout and frame mode this assembly used.
self.last_row_layout = row_layout;
self.last_frames_active = frames_engaged;
Ok(sys)
}
/// Projects a dense beta-side penalty block `hbb` through `projection`.
///
/// Delegates directly to `FrameProjection::project_block`; no state on `self`
/// is read here.
fn project_dense_penalty_to_factored(
&self,
hbb: ArrayView2<'_, f64>,
projection: &FrameProjection,
) -> Array2<f64> {
projection.project_block(hbb)
}
/// Builds the analytic beta-penalty curvature directly in the factored border space.
///
/// Returns a `border_dim x border_dim` matrix accumulated from every beta-tier
/// penalty and every Psi-tier nuclear-norm penalty in `registry`. ARD entries
/// and all other tiers are skipped. Each penalty sees its slice of an all-zero
/// rho vector and the term's current flattened beta.
fn build_factored_beta_penalty_curvature(
    &self,
    registry: &AnalyticPenaltyRegistry,
    penalty_scale: f64,
    projection: &FrameProjection,
) -> Array2<f64> {
    let zero_rho = Array1::<f64>::zeros(registry.total_rho_count());
    let rho_layout = registry.rho_layout();
    let current_beta = self.flatten_beta();
    let mut curvature =
        Array2::<f64>::zeros((projection.border_dim(), projection.border_dim()));

    for (penalty, (rho_slice, tier, _)) in registry.penalties.iter().zip(rho_layout.iter()) {
        if let AnalyticPenaltyKind::Ard(_) = penalty {
            continue;
        }
        let rho_local = zero_rho.slice(s![rho_slice.clone()]);
        let feeds_beta_curvature = match tier {
            PenaltyTier::Beta => true,
            PenaltyTier::Psi => matches!(penalty, AnalyticPenaltyKind::NuclearNorm(_)),
            _ => false,
        };
        if feeds_beta_curvature {
            self.add_factored_beta_penalty_curvature_for_penalty(
                &mut curvature,
                penalty,
                current_beta.view(),
                rho_local,
                penalty_scale,
                projection,
            );
        }
    }
    curvature
}
/// Adds one penalty's PSD-majorizer curvature to `hbb_c`, column by column.
///
/// Each border coordinate `(atom, basis_col, frame_col)` is lifted into beta
/// space as a probe vector, pushed through the penalty's `psd_majorizer_hvp`
/// at `target_beta`, projected back, and accumulated into that coordinate's
/// column of `hbb_c`, scaled by `penalty_scale`.
///
/// * Decoder incoherence uses the live per-fit penalty with full-length probes
///   and returns without touching `hbb_c` when no live penalty exists.
/// * Mechanism sparsity and nuclear norm probe each live atom on its own beta
///   block `start..end`; mechanism sparsity first retargets a clone of the
///   per-atom penalty to block-local indices.
/// * Every other penalty kind is probed with full-length vectors on `penalty`.
///
/// # Panics
/// Panics if a live mechanism-sparsity or nuclear-norm block does not start at
/// one of `projection.beta_offsets`.
fn add_factored_beta_penalty_curvature_for_penalty(
    &self,
    hbb_c: &mut Array2<f64>,
    penalty: &AnalyticPenaltyKind,
    target_beta: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
    penalty_scale: f64,
    projection: &FrameProjection,
) {
    // Column of `hbb_c` owned by border coordinate (atom, basis_col, frame_col).
    let border_col = |atom: usize, basis_col: usize, frame_col: usize| -> usize {
        projection.border_offsets[atom] + basis_col * projection.ranks[atom] + frame_col
    };

    match penalty {
        AnalyticPenaltyKind::DecoderIncoherence(base) => {
            let Some(per_fit) = self.live_decoder_incoherence_penalty(base) else {
                return;
            };
            let mut probe = Array1::<f64>::zeros(self.beta_dim());
            for k in 0..self.atoms.len() {
                for basis_col in 0..projection.basis_sizes[k] {
                    for frame_col in 0..projection.ranks[k] {
                        probe.fill(0.0);
                        projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
                        let col = border_col(k, basis_col, frame_col);
                        let hv = per_fit.psd_majorizer_hvp(target_beta, rho_local, probe.view());
                        projection
                            .project_border_vec(hv.view())
                            .iter()
                            .enumerate()
                            .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
                    }
                }
            }
        }
        AnalyticPenaltyKind::MechanismSparsity(base) => {
            for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
                let atom_idx = projection
                    .beta_offsets
                    .iter()
                    .position(|&offset| offset == start)
                    .expect("live mechanism-sparsity offset must match an SAE atom");
                let block_len = end - start;
                // Retarget the penalty so it indexes the atom-local block
                // instead of the global beta vector.
                let mut local_penalty = per_atom.clone();
                local_penalty.target = PsiSlice {
                    range: 0..block_len,
                    latent_dim: Some(projection.basis_sizes[atom_idx]),
                };
                let block = target_beta.slice(s![start..end]);
                let mut probe = Array1::<f64>::zeros(block_len);
                for basis_col in 0..projection.basis_sizes[atom_idx] {
                    for frame_col in 0..projection.ranks[atom_idx] {
                        probe.fill(0.0);
                        projection.lift_local_axis_into(&mut probe, atom_idx, basis_col, frame_col);
                        let col = border_col(atom_idx, basis_col, frame_col);
                        let hv = local_penalty.psd_majorizer_hvp(block, rho_local, probe.view());
                        projection.project_local_atom_vec_into(
                            atom_idx,
                            hv.view(),
                            hbb_c.column_mut(col),
                            penalty_scale,
                        );
                    }
                }
            }
        }
        AnalyticPenaltyKind::NuclearNorm(base) => {
            for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
                let atom_idx = projection
                    .beta_offsets
                    .iter()
                    .position(|&offset| offset == start)
                    .expect("live nuclear-norm offset must match an SAE atom");
                let block = target_beta.slice(s![start..end]);
                let block_len = end - start;
                let mut probe = Array1::<f64>::zeros(block_len);
                for basis_col in 0..projection.basis_sizes[atom_idx] {
                    for frame_col in 0..projection.ranks[atom_idx] {
                        probe.fill(0.0);
                        projection.lift_local_axis_into(&mut probe, atom_idx, basis_col, frame_col);
                        let col = border_col(atom_idx, basis_col, frame_col);
                        let hv = per_atom.psd_majorizer_hvp(block, rho_local, probe.view());
                        projection.project_local_atom_vec_into(
                            atom_idx,
                            hv.view(),
                            hbb_c.column_mut(col),
                            penalty_scale,
                        );
                    }
                }
            }
        }
        _ => {
            let mut probe = Array1::<f64>::zeros(self.beta_dim());
            for k in 0..self.atoms.len() {
                for basis_col in 0..projection.basis_sizes[k] {
                    for frame_col in 0..projection.ranks[k] {
                        probe.fill(0.0);
                        projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
                        let col = border_col(k, basis_col, frame_col);
                        let hv = penalty.psd_majorizer_hvp(target_beta, rho_local, probe.view());
                        projection
                            .project_border_vec(hv.view())
                            .iter()
                            .enumerate()
                            .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
                    }
                }
            }
        }
    }
}
/// Reshapes the flattened extended coordinates into an `n_obs x row_block_dim`
/// matrix, reading the flat buffer in row-major order.
fn ext_coord_matrix(&self) -> Array2<f64> {
    let n_rows = self.n_obs();
    let n_cols = self.assignment.row_block_dim();
    let flat = self.assignment.flatten_ext_coords();
    Array2::<f64>::from_shape_fn((n_rows, n_cols), |(row, col)| flat[row * n_cols + col])
}
/// Manifold describing one extended-coordinate row.
///
/// Assignment slots are Euclidean. A Euclidean atom contributes one Euclidean
/// factor per latent axis; a non-Euclidean atom contributes a single clone of
/// its own manifold. The result is `LatentManifold::Product` when at least one
/// atom is non-Euclidean and plain `LatentManifold::Euclidean` otherwise.
fn ext_coord_manifold(&self) -> LatentManifold {
    let mut factors = Vec::with_capacity(self.assignment.row_block_dim());
    factors.extend(
        std::iter::repeat_with(|| LatentManifold::Euclidean)
            .take(self.assignment.assignment_coord_dim()),
    );

    let mut has_constrained_atom = false;
    for coord in self.assignment.coords.iter() {
        if !coord.manifold().is_euclidean() {
            has_constrained_atom = true;
            factors.push(coord.manifold().clone());
            continue;
        }
        factors.extend(
            std::iter::repeat_with(|| LatentManifold::Euclidean).take(coord.latent_dim()),
        );
    }

    match has_constrained_atom {
        true => LatentManifold::Product(factors),
        false => LatentManifold::Euclidean,
    }
}
/// Applies the Riemannian latent geometry of the extended coordinates to `sys`.
///
/// Does nothing when the extended-coordinate manifold is Euclidean.
fn apply_sae_riemannian_geometry(&self, sys: &mut ArrowSchurSystem) {
    let manifold = self.ext_coord_manifold();
    if !manifold.is_euclidean() {
        let ext_coords = self.ext_coord_matrix();
        let latent_values = LatentCoordValues::from_matrix_with_manifold(
            ext_coords.view(),
            LatentIdMode::None,
            manifold,
        );
        sys.apply_riemannian_latent_geometry(&latent_values);
    }
}
/// Numerical rank of the symmetric part of `s`.
///
/// Counts eigenvalues of `0.5 * (s + s^T)` strictly above
/// `SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF` times the largest eigenvalue. Returns `0`
/// for a matrix with no columns or when no eigenvalue is positive.
///
/// # Errors
/// Returns a message when the eigendecomposition fails.
fn symmetric_rank(s: &Array2<f64>) -> Result<usize, String> {
    let dim = s.ncols();
    if dim == 0 {
        return Ok(0);
    }

    let symmetrized =
        Array2::<f64>::from_shape_fn((dim, dim), |(i, j)| 0.5 * (s[[i, j]] + s[[j, i]]));
    let (eigenvalues, _eigenvectors) = symmetrized
        .eigh(Side::Lower)
        .map_err(|e| format!("SaeManifoldTerm::symmetric_rank: eigh failed: {e}"))?;

    let largest = eigenvalues.iter().fold(0.0_f64, |best, &v| best.max(v));
    if largest > 0.0 {
        let tol = SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF * largest;
        Ok(eigenvalues.iter().filter(|&&v| v > tol).count())
    } else {
        Ok(0)
    }
}
/// Evaluates the REML criterion and the fitted loss at `rho`.
///
/// Uses the streaming-exact path when the term's streaming plan asks for
/// streaming; otherwise runs the cached path and discards the factor cache.
///
/// # Errors
/// Propagates the error of whichever path is taken.
pub fn reml_criterion(
    &mut self,
    target: ArrayView2<'_, f64>,
    rho: &SaeManifoldRho,
    registry: Option<&AnalyticPenaltyRegistry>,
    inner_max_iter: usize,
    learning_rate: f64,
    ridge_ext_coord: f64,
    ridge_beta: f64,
) -> Result<(f64, SaeManifoldLoss), String> {
    if !self.streaming_plan().streaming {
        let (criterion, loss, _cache) = self.reml_criterion_with_cache(
            target,
            rho,
            registry,
            inner_max_iter,
            learning_rate,
            ridge_ext_coord,
            ridge_beta,
        )?;
        return Ok((criterion, loss));
    }
    self.reml_criterion_streaming_exact(
        target,
        rho,
        registry,
        inner_max_iter,
        learning_rate,
        ridge_ext_coord,
        ridge_beta,
    )
}
/// Evaluates the REML/Laplace criterion at `rho` and also returns the factor cache.
///
/// Steps, in order:
/// 1. run the joint fit on a private clone of `rho`;
/// 2. refine the inner solve via `converge_inner_for_undamped_logdet` using direct solve
///    options with ill-conditioning tolerated;
/// 3. read the log-determinant from the returned cache;
/// 4. combine `loss.total() + extra_penalty_energy + 0.5 * log_det - occam`.
///
/// The extra penalty energy is taken from `registry` when one is supplied and is 0 otherwise.
///
/// # Errors
/// Returns an error string when the fit, the refinement, the Occam term or the registry
/// penalty evaluation fails, or when the cache cannot provide a log-determinant.
pub fn reml_criterion_with_cache(
    &mut self,
    target: ArrayView2<'_, f64>,
    rho: &SaeManifoldRho,
    registry: Option<&AnalyticPenaltyRegistry>,
    inner_max_iter: usize,
    learning_rate: f64,
    ridge_ext_coord: f64,
    ridge_beta: f64,
) -> Result<(f64, SaeManifoldLoss, ArrowFactorCache), String> {
    let mut working_rho = rho.clone();
    let mut loss = self.run_joint_fit_arrow_schur(
        target,
        &mut working_rho,
        registry,
        inner_max_iter,
        learning_rate,
        ridge_ext_coord,
        ridge_beta,
    )?;
    let solve_options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
    let cache = self.converge_inner_for_undamped_logdet(
        target,
        rho,
        &mut working_rho,
        registry,
        inner_max_iter,
        learning_rate,
        ridge_ext_coord,
        ridge_beta,
        &mut loss,
        &solve_options,
    )?;
    let log_det = match arrow_log_det_from_cache(&cache) {
        Some(value) => value,
        None => {
            return Err(
                "SaeManifoldTerm::reml_criterion: arrow_log_det_from_cache returned None at \
                 ridge=0 Direct mode (no dense Schur factor); the joint Hessian log-det is \
                 required for the Laplace normaliser"
                    .to_string(),
            );
        }
    };
    let occam = self.reml_occam_term(rho)?;
    let extra_penalty_energy = if let Some(reg) = registry {
        self.reml_extra_penalty_value_total(reg)
            .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?
    } else {
        0.0
    };
    let criterion = loss.total() + extra_penalty_energy + 0.5 * log_det - occam;
    Ok((criterion, loss, cache))
}
/// Recognises the per-row factorisation failure that the undamped evidence path retries on.
///
/// True only for `ArrowSchurError::PerRowFactorFailed` whose `reason` text contains both
/// marker phrases below; every other error yields `false`.
fn is_undamped_evidence_row_non_pd(err: &ArrowSchurError) -> bool {
    match err {
        ArrowSchurError::PerRowFactorFailed { reason, .. } => {
            let mentions_non_pd = reason.contains("H_tt is non-PD at base ridge");
            mentions_non_pd && reason.contains("evidence mode preserves the genuine Cholesky")
        }
        _ => false,
    }
}
/// Drives the inner solve towards stationarity and returns the factor cache of the last
/// Newton solve performed with both numeric solver arguments set to `0.0`.
///
/// Each pass assembles the arrow-Schur system at `rho`, measures the squared norm of the
/// assembled gradient (`rows[..].gt` plus `gb`), and calls
/// `solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, options)`. The cache is returned as
/// soon as either the gradient norm or the quotient Newton step norm is within its tolerance.
/// Both tolerances scale with `1 + ‖iterate‖`, where the iterate norm runs over the assignment
/// logits, the latent coordinates and the decoder coefficients. Otherwise further
/// `run_joint_fit_arrow_schur` iterations are run with `rho_fixed` (overwriting `*loss`) and
/// the check repeats, until `max(16 * inner_max_iter, 64)` inner iterations have been counted.
///
/// With `inner_max_iter == 0` a single assemble + solve is performed and its cache is returned
/// without any convergence check.
///
/// # Errors
/// Returns an error string when assembly or the solve fails, when the gradient or step norm
/// is non-finite, when a non-PD per-row block persists at a stationary point, or when the
/// refinement budget is exhausted before either tolerance is met.
fn converge_inner_for_undamped_logdet(
&mut self,
target: ArrayView2<'_, f64>,
rho: &SaeManifoldRho,
rho_fixed: &mut SaeManifoldRho,
registry: Option<&AnalyticPenaltyRegistry>,
inner_max_iter: usize,
learning_rate: f64,
ridge_ext_coord: f64,
ridge_beta: f64,
loss: &mut SaeManifoldLoss,
options: &ArrowSolveOptions,
) -> Result<ArrowFactorCache, String> {
// Zero-iteration request: factor once at the current iterate and skip all checks.
if inner_max_iter == 0 {
let sys = self
.assemble_arrow_schur(target, rho, registry)
.map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
let factored = solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, options)
.map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
return Ok(factored.2);
}
// The counter starts at `inner_max_iter`; presumably this accounts for the fit the caller
// already ran (both callers visible in this file run one first) — confirm if reused.
let mut total_inner_iter = inner_max_iter;
let max_refine_iter = inner_max_iter.max(1).saturating_mul(16).max(64);
loop {
let sys = self
.assemble_arrow_schur(target, rho, registry)
.map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
// Squared gradient norm over every per-row block plus the border block.
let grad_norm_sq: f64 = sys
.rows
.iter()
.map(|row| row.gt.iter().map(|&v| v * v).sum::<f64>())
.sum::<f64>()
+ sys.gb.iter().map(|&v| v * v).sum::<f64>();
// Squared norm of the current iterate, used only to scale the tolerances.
let mut iterate_norm_sq = 0.0_f64;
for &v in self.assignment.logits.iter() {
iterate_norm_sq += v * v;
}
for coords in &self.assignment.coords {
let matrix = coords.as_matrix();
for &v in matrix.iter() {
iterate_norm_sq += v * v;
}
}
for atom in &self.atoms {
for &v in atom.decoder_coefficients.iter() {
iterate_norm_sq += v * v;
}
}
let grad_norm = grad_norm_sq.sqrt();
let iterate_scale = 1.0 + iterate_norm_sq.sqrt();
let step_tolerance = SAE_MANIFOLD_INNER_STEP_REL_TOL * iterate_scale;
let grad_tolerance = SAE_MANIFOLD_INNER_GRAD_REL_TOL * iterate_scale;
if !grad_norm_sq.is_finite() {
return Err(format!(
"SaeManifoldTerm::reml_criterion: undamped inner KKT residual is non-finite \
at the inner optimum (‖g‖²={grad_norm_sq}); the joint Hessian \
factorisation is degenerate at this ρ"
));
}
let (delta_t, delta_beta, cache): (Array1<f64>, Array1<f64>, ArrowFactorCache) =
match solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, options) {
Ok(factored) => factored,
// A non-PD per-row block is retried with more inner iterations unless the iterate
// is already stationary or the budget is spent.
Err(err) if Self::is_undamped_evidence_row_non_pd(&err) => {
if grad_norm <= grad_tolerance {
return Err(format!(
"SaeManifoldTerm::reml_criterion: stationary undamped evidence \
factorization still has a non-PD per-row H_tt block \
(‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}); {err}"
));
}
if total_inner_iter >= max_refine_iter {
return Err(format!(
"SaeManifoldTerm::reml_criterion: undamped evidence \
factorization hit a non-PD per-row H_tt block before KKT \
stationarity (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}) \
and the refinement budget was exhausted after \
{total_inner_iter} inner iterations; {err}"
));
}
// Run at most `inner_max_iter` more iterations, capped by the remaining budget.
let remaining = max_refine_iter - total_inner_iter;
let refine_iter = inner_max_iter.max(1).min(remaining);
*loss = self.run_joint_fit_arrow_schur(
target,
rho_fixed,
registry,
refine_iter,
learning_rate,
ridge_ext_coord,
ridge_beta,
)?;
total_inner_iter += refine_iter;
continue;
}
// Any other solver failure is fatal.
Err(err) => {
return Err(format!("SaeManifoldTerm::reml_criterion: {err}"));
}
};
let step_norm_sq: f64 = delta_t.iter().map(|&v| v * v).sum::<f64>()
+ delta_beta.iter().map(|&v| v * v).sum::<f64>();
if !step_norm_sq.is_finite() {
return Err(format!(
"SaeManifoldTerm::reml_criterion: undamped inner residual is non-finite at \
the inner optimum (‖Δ‖²={step_norm_sq}, ‖g‖²={grad_norm_sq}); the joint \
Hessian factorisation is degenerate at this ρ"
));
}
let step_norm = step_norm_sq.sqrt();
// The convergence test uses the step norm reported by `quotient_newton_step_norm_sq`;
// the raw norm is kept only for the diagnostic message below.
let quotient_step_norm_sq =
self.quotient_newton_step_norm_sq(delta_t.view(), delta_beta.view(), step_norm_sq)?;
let quotient_step_norm = quotient_step_norm_sq.sqrt();
// Converged when EITHER criterion is met.
if grad_norm <= grad_tolerance || quotient_step_norm <= step_tolerance {
return Ok(cache);
}
if total_inner_iter >= max_refine_iter {
return Err(format!(
"SaeManifoldTerm::reml_criterion: inner solve did not converge at fixed ρ; \
neither the KKT gradient ‖g‖={grad_norm:.6e} (tol {grad_tolerance:.6e}) nor \
the quotient Newton step ‖Π⊥gauge Δ‖={quotient_step_norm:.6e} \
(raw ‖Δ‖={step_norm:.6e}, tol {step_tolerance:.6e}) met \
tolerance after {total_inner_iter} inner iterations. Refusing to rank an \
off-optimum Laplace criterion."
));
}
// Not converged yet: spend another slice of the budget and re-check.
let remaining = max_refine_iter - total_inner_iter;
let refine_iter = inner_max_iter.max(1).min(remaining);
*loss = self.run_joint_fit_arrow_schur(
target,
rho_fixed,
registry,
refine_iter,
learning_rate,
ridge_ext_coord,
ridge_beta,
)?;
total_inner_iter += refine_iter;
}
}
/// Occam term of the REML criterion at `rho`.
///
/// Sums `border_frame_rank() * symmetric_rank(smooth_penalty)` over all atoms to obtain the
/// penalised channel dimension, then returns
/// `0.5 * penalised_dim * log_lambda_smooth - 0.5 * grassmann_dim * log_lambda_smooth`.
///
/// # Errors
/// Propagates a failure from `symmetric_rank`.
fn reml_occam_term(&self, rho: &SaeManifoldRho) -> Result<f64, String> {
    let penalized_channel_dim = self.atoms.iter().try_fold(0usize, |total, atom| {
        let penalty_rank = Self::symmetric_rank(&atom.smooth_penalty)?;
        Ok::<usize, String>(total + atom.border_frame_rank() * penalty_rank)
    })?;
    let grassmann_dim = self.grassmann_evidence_dimension();
    let log_lambda = rho.log_lambda_smooth;
    let occam_penalty = 0.5 * (penalized_channel_dim as f64) * log_lambda;
    let frame_dim_term = 0.5 * (grassmann_dim as f64) * log_lambda;
    Ok(occam_penalty - frame_dim_term)
}
/// Derivative of the Occam term with respect to `log_lambda_smooth`.
///
/// Returns `0.5 * (penalised_dim - grassmann_dim)`, where the penalised dimension is the sum
/// over atoms of `border_frame_rank() * symmetric_rank(smooth_penalty)`.
///
/// # Errors
/// Propagates a failure from `symmetric_rank`.
fn reml_occam_log_lambda_smooth_derivative(&self) -> Result<f64, String> {
    let penalized_channel_dim = self.atoms.iter().try_fold(0usize, |total, atom| {
        let penalty_rank = Self::symmetric_rank(&atom.smooth_penalty)?;
        Ok::<usize, String>(total + atom.border_frame_rank() * penalty_rank)
    })?;
    let grassmann_dim = self.grassmann_evidence_dimension();
    let dim_gap = (penalized_channel_dim as f64) - (grassmann_dim as f64);
    Ok(0.5 * dim_gap)
}
/// Streaming variant of the REML/Laplace criterion; returns `(criterion, loss)`.
///
/// Runs the joint fit on a private clone of `rho`, refines the inner solve with
/// `converge_inner_for_undamped_logdet` (its cache is dropped immediately), then obtains the
/// log-determinant from `streaming_exact_arrow_log_det` and combines
/// `loss.total() + extra_penalty_energy + 0.5 * log_det - occam`.
///
/// # Errors
/// Returns an error string when any of the fit, refinement, log-determinant, Occam term or
/// registry penalty evaluations fails.
pub fn reml_criterion_streaming_exact(
    &mut self,
    target: ArrayView2<'_, f64>,
    rho: &SaeManifoldRho,
    registry: Option<&AnalyticPenaltyRegistry>,
    inner_max_iter: usize,
    learning_rate: f64,
    ridge_ext_coord: f64,
    ridge_beta: f64,
) -> Result<(f64, SaeManifoldLoss), String> {
    let mut working_rho = rho.clone();
    let mut loss = self.run_joint_fit_arrow_schur(
        target,
        &mut working_rho,
        registry,
        inner_max_iter,
        learning_rate,
        ridge_ext_coord,
        ridge_beta,
    )?;
    let solve_options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
    // Only the convergence side effect is needed here; release the cache right away.
    drop(self.converge_inner_for_undamped_logdet(
        target,
        rho,
        &mut working_rho,
        registry,
        inner_max_iter,
        learning_rate,
        ridge_ext_coord,
        ridge_beta,
        &mut loss,
        &solve_options,
    )?);
    let log_det = self.streaming_exact_arrow_log_det(target, rho, registry)?;
    let occam = self.reml_occam_term(rho)?;
    let extra_penalty_energy = if let Some(reg) = registry {
        self.reml_extra_penalty_value_total(reg)
            .map_err(|err| format!("SaeManifoldTerm::reml_criterion_streaming_exact: {err}"))?
    } else {
        0.0
    };
    let criterion = loss.total() + extra_penalty_energy + 0.5 * log_det - occam;
    Ok((criterion, loss))
}
/// Computes the arrow-system log-determinant by streaming over row chunks.
///
/// Rows are processed in contiguous chunks. For each chunk the logits, coordinates and
/// (when present) row loss weights are sliced out, a chunk term is materialised, and its
/// arrow-Schur system is assembled with `penalty_scale = chunk_rows / n_total`, so the
/// penalty scales of all chunks sum to 1. Each chunk contributes its per-row block
/// log-determinant and its reduced Schur matrix; the reduced matrices are summed and the
/// log-determinant of the accumulated matrix is added at the end.
///
/// The chunk size taken from the streaming plan is clamped to `1..=max(n_total, 1)`. Without
/// the lower bound a plan reporting a chunk size of 0 would never advance `start` and the
/// loop would not terminate.
///
/// # Errors
/// Returns an error string when `target` does not have shape `(n_obs, output_dim)`, or when
/// chunk materialisation, assembly, or either log-determinant computation fails.
pub fn streaming_exact_arrow_log_det(
    &mut self,
    target: ArrayView2<'_, f64>,
    rho: &SaeManifoldRho,
    registry: Option<&AnalyticPenaltyRegistry>,
) -> Result<f64, String> {
    if target.dim() != (self.n_obs(), self.output_dim()) {
        return Err(format!(
            "SaeManifoldTerm::streaming_exact_arrow_log_det: target must be ({}, {}); got {:?}",
            self.n_obs(),
            self.output_dim(),
            target.dim()
        ));
    }
    let n_total = self.n_obs();
    // Lower bound of 1 guarantees forward progress of the chunk loop.
    let chunk_size = self.streaming_plan().chunk_size.clamp(1, n_total.max(1));
    let border_dim = if self.frames_active() {
        self.factored_border_dim()
    } else {
        self.beta_dim()
    };
    let mut schur_acc = Array2::<f64>::zeros((border_dim, border_dim));
    let mut log_det_tt = 0.0_f64;
    let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
    let mut start = 0usize;
    while start < n_total {
        let end = (start + chunk_size).min(n_total);
        let penalty_scale = (end - start) as f64 / n_total as f64;
        let chunk_logits = self.assignment.logits.slice(s![start..end, ..]).to_owned();
        let chunk_coords: Vec<Array2<f64>> = self
            .assignment
            .coords
            .iter()
            .map(|coord| coord.as_matrix().slice(s![start..end, ..]).to_owned())
            .collect();
        let mut chunk = self.materialize_chunk(chunk_logits, chunk_coords)?;
        if let Some(w) = self.row_loss_weights.as_deref() {
            chunk.row_loss_weights = Some(w[start..end].to_vec());
        }
        let z_chunk = target.slice(s![start..end, ..]);
        let sys = chunk
            .assemble_arrow_schur_scaled(z_chunk, rho, registry, penalty_scale)
            .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
        let mut streaming = StreamingArrowSchur::from_system(&sys, sys.rows.len().max(1));
        let (chunk_log_det_tt, chunk_schur) = streaming
            .reduced_schur_and_log_det_tt(0.0, 0.0, &options)
            .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
        log_det_tt += chunk_log_det_tt;
        // Accumulate only the leading `border_dim × border_dim` block of the chunk matrix.
        for row in 0..border_dim {
            for col in 0..border_dim {
                schur_acc[[row, col]] += chunk_schur[[row, col]];
            }
        }
        start = end;
    }
    let log_det_schur = StreamingArrowSchur::reduced_schur_log_det(&schur_acc, &options)
        .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
    Ok(log_det_tt + log_det_schur)
}
/// Per-atom, per-axis sums of the ARD prior's `sq_equiv` value over all rows.
///
/// For each coordinate block the prior is evaluated at unit strength (`1.0`) with the
/// block's effective axis periods, and `sq_equiv` is accumulated row by row.
fn ard_coord_sumsq(&self) -> Vec<Array1<f64>> {
    self.assignment
        .coords
        .iter()
        .map(|coord| {
            let periods = coord.effective_axis_periods();
            let mut sumsq = Array1::<f64>::zeros(coord.latent_dim());
            for obs in 0..coord.n_obs() {
                let point = coord.row(obs);
                for (axis, slot) in sumsq.iter_mut().enumerate() {
                    *slot += ArdAxisPrior::eval(1.0, point[axis], periods[axis]).sq_equiv;
                }
            }
            sumsq
        })
        .collect()
}
/// Sums the latent-block inverse diagonal over rows, per atom and per latent axis.
///
/// When a row layout is recorded, only the atoms active in each row contribute and their
/// block start comes from `layout.coord_starts`; otherwise every atom contributes at its
/// `coord_offsets()` position. Entries are read at `row_offsets[row] + block_start + axis`.
///
/// # Errors
/// Propagates a failure from `latent_block_inverse_diagonal`.
fn ard_inverse_traces(
    &self,
    cache: &ArrowFactorCache,
) -> Result<Vec<Array1<f64>>, ArrowSchurError> {
    let inv_diag = cache.latent_block_inverse_diagonal()?;
    let n = self.n_obs();
    let coord_offsets = self.assignment.coord_offsets();
    let mut traces: Vec<Array1<f64>> = self
        .assignment
        .coords
        .iter()
        .map(|coord| Array1::<f64>::zeros(coord.latent_dim()))
        .collect();
    // Shared accumulation for one atom block starting at `flat_start` in `inv_diag`.
    let mut accumulate = |atom: usize, flat_start: usize| {
        let dim = self.assignment.coords[atom].latent_dim();
        for axis in 0..dim {
            traces[atom][axis] += inv_diag[flat_start + axis];
        }
    };
    let layout = self.last_row_layout.as_ref();
    for row in 0..n {
        let row_base = cache.row_offsets[row];
        if let Some(layout) = layout {
            let active = &layout.active_atoms[row];
            let starts = &layout.coord_starts[row];
            for (pos, &atom) in active.iter().enumerate() {
                accumulate(atom, row_base + starts[pos]);
            }
        } else {
            for atom in 0..self.k_atoms() {
                accumulate(atom, row_base + coord_offsets[atom]);
            }
        }
    }
    Ok(traces)
}
/// Explicit per-axis derivative terms for the ARD log-precisions.
///
/// For every atom with a non-empty ARD vector and every axis, the result is the sum over
/// rows of the prior `value` at strength `stable_exp_strength(log_ard)` plus a normaliser
/// term: `-0.5 * n` for a non-periodic axis, or `n * eta * (-1 + I1(eta) / I0(eta))` with
/// `eta = alpha / kappa²`, `kappa = TAU / period` for a periodic one. Atoms with an empty ARD
/// vector yield an empty array.
///
/// # Errors
/// Returns an error string when `rho.log_ard` has the wrong number of atoms, or when a
/// non-empty per-atom vector does not match the atom's latent dimension.
fn ard_log_precision_explicit_derivatives(
    &self,
    rho: &SaeManifoldRho,
) -> Result<Vec<Array1<f64>>, String> {
    let k_atoms = self.k_atoms();
    if rho.log_ard.len() != k_atoms {
        return Err(format!(
            "ARD rho has {} atoms but term has {}",
            rho.log_ard.len(),
            k_atoms
        ));
    }
    let n = self.n_obs() as f64;
    let mut derivatives = Vec::with_capacity(k_atoms);
    for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
        let d = coord.latent_dim();
        let log_precisions = &rho.log_ard[atom_idx];
        let mut per_axis = Array1::<f64>::zeros(log_precisions.len());
        if !log_precisions.is_empty() {
            if log_precisions.len() != d {
                return Err(format!(
                    "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
                    rho.log_ard[atom_idx].len()
                ));
            }
            let periods = coord.effective_axis_periods();
            for axis in 0..d {
                let alpha = SaeManifoldRho::stable_exp_strength(log_precisions[axis]);
                let period = periods[axis];
                let energy = (0..coord.n_obs()).fold(0.0_f64, |acc, row| {
                    acc + ArdAxisPrior::eval(alpha, coord.row(row)[axis], period).value
                });
                let normalizer = period.map_or(-0.5 * n, |p| {
                    let kappa = std::f64::consts::TAU / p;
                    let eta = alpha / (kappa * kappa);
                    let i0 = bessel_i0(eta);
                    let i1 = bessel_i1(eta);
                    n * eta * (-1.0 + i1 / i0)
                });
                per_axis[axis] = energy + normalizer;
            }
        }
        derivatives.push(per_axis);
    }
    Ok(derivatives)
}
/// Per-atom, per-axis trace terms pairing the latent-block inverse diagonal with the ARD
/// prior curvature.
///
/// Each contribution is `0.5 * inv_diag[entry] * max(prior.hess, 0)`, with the prior evaluated
/// at strength `stable_exp_strength(log_ard[atom][axis])`. Atoms whose ARD vector is empty
/// are skipped and receive a zero-length array. Entry positions follow the recorded row
/// layout when present and `coord_offsets()` otherwise.
///
/// # Errors
/// Propagates a failure from `latent_block_inverse_diagonal`.
fn ard_log_precision_hessian_trace(
    &self,
    rho: &SaeManifoldRho,
    cache: &ArrowFactorCache,
) -> Result<Vec<Array1<f64>>, ArrowSchurError> {
    let inv_diag = cache.latent_block_inverse_diagonal()?;
    let n = self.n_obs();
    let coord_offsets = self.assignment.coord_offsets();
    let ard_axis_periods: Vec<Vec<Option<f64>>> = self
        .assignment
        .coords
        .iter()
        .map(|coord| coord.effective_axis_periods())
        .collect();
    let mut traces: Vec<Array1<f64>> = self
        .assignment
        .coords
        .iter()
        .enumerate()
        .map(|(atom, coord)| {
            let len = if rho.log_ard[atom].is_empty() {
                0
            } else {
                coord.latent_dim()
            };
            Array1::<f64>::zeros(len)
        })
        .collect();
    // Shared accumulation for one (row, atom) block starting at `flat_start` in `inv_diag`.
    let mut accumulate = |row: usize, atom: usize, flat_start: usize| {
        let coord = &self.assignment.coords[atom];
        let dim = coord.latent_dim();
        for axis in 0..dim {
            let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
            let t = coord.row(row)[axis];
            let prior = ArdAxisPrior::eval(alpha, t, ard_axis_periods[atom][axis]);
            traces[atom][axis] += 0.5 * inv_diag[flat_start + axis] * prior.hess.max(0.0);
        }
    };
    let layout = self.last_row_layout.as_ref();
    for row in 0..n {
        let row_base = cache.row_offsets[row];
        if let Some(layout) = layout {
            let active = &layout.active_atoms[row];
            let starts = &layout.coord_starts[row];
            for (pos, &atom) in active.iter().enumerate() {
                if rho.log_ard[atom].is_empty() {
                    continue;
                }
                accumulate(row, atom, row_base + starts[pos]);
            }
        } else {
            for atom in 0..self.k_atoms() {
                if rho.log_ard[atom].is_empty() {
                    continue;
                }
                accumulate(row, atom, row_base + coord_offsets[atom]);
            }
        }
    }
    Ok(traces)
}
/// Sum over atoms of the elementwise product of the decoder coefficients with the matching
/// output of `batched_smooth_sb`.
///
/// `batched_smooth_sb` receives one `(smooth_penalty, decoder_coefficients)` view pair per
/// atom together with the flag `true`.
fn decoder_smoothness_quadratic_form(&self) -> f64 {
    let mut pairs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = Vec::new();
    for atom in &self.atoms {
        pairs.push((atom.smooth_penalty.view(), atom.decoder_coefficients.view()));
    }
    let products = batched_smooth_sb(&pairs, true);
    self.atoms
        .iter()
        .zip(products.iter())
        .fold(0.0_f64, |total, (atom, sb)| {
            total + (&atom.decoder_coefficients * sb).sum()
        })
}
/// Trace of the Schur inverse applied to the scaled, symmetrised smoothness penalty.
///
/// For every atom, basis index `mu` and output channel `oc`, a right-hand side holding
/// `lambda_smooth * 0.5 * (S[nu, mu] + S[mu, nu])` at border position
/// `offset + nu * width + oc` is solved with `cache.schur_inverse_apply`, and the solution
/// entry at `offset + mu * width + oc` is added to the trace. The channel width is the
/// atom's border frame rank when frames are active and `output_dim()` otherwise.
///
/// NOTE(review): this uses `self.frames_active()`, while other cache-indexed helpers in this
/// file use `self.last_frames_active && cache.k == self.factored_border_dim()` — confirm the
/// two always agree for the caches passed in here.
///
/// # Errors
/// Propagates a failure from `schur_inverse_apply`.
fn decoder_smoothness_effective_dof(
    &self,
    cache: &ArrowFactorCache,
    lambda_smooth: f64,
) -> Result<f64, ArrowSchurError> {
    let p = self.output_dim();
    let (offsets, widths): (Vec<usize>, Vec<usize>) = if self.frames_active() {
        let ranks: Vec<usize> = self.atoms.iter().map(|a| a.border_frame_rank()).collect();
        (self.factored_beta_offsets(), ranks)
    } else {
        (self.beta_offsets(), self.atoms.iter().map(|_| p).collect())
    };
    let mut rhs = Array1::<f64>::zeros(cache.k);
    let mut trace = 0.0_f64;
    for (atom_idx, atom) in self.atoms.iter().enumerate() {
        let penalty = &atom.smooth_penalty;
        let basis = atom.basis_size();
        let offset = offsets[atom_idx];
        let width = widths[atom_idx];
        for mu in 0..basis {
            for oc in 0..width {
                rhs.fill(0.0);
                for nu in 0..basis {
                    let s_nu_mu = 0.5 * (penalty[[nu, mu]] + penalty[[mu, nu]]);
                    rhs[offset + nu * width + oc] = lambda_smooth * s_nu_mu;
                }
                let solved = cache.schur_inverse_apply(rhs.view())?;
                trace += solved[offset + mu * width + oc];
            }
        }
    }
    Ok(trace)
}
/// Half the sum, over rows and assignment slots, of the latent-block inverse diagonal times
/// the assignment prior's log-strength Hessian diagonal.
///
/// Returns 0 when the Hessian diagonal is empty. With a recorded row layout, slot `pos` of a
/// row pairs with `hdiag[row * k_atoms + atom]` for the active atom at that position;
/// otherwise slot `free_idx` pairs with `hdiag[row * k_atoms + free_idx]`.
///
/// # Errors
/// Returns an error string when the prior diagonal or the inverse diagonal cannot be computed.
fn assignment_log_strength_hessian_trace(
    &self,
    rho: &SaeManifoldRho,
    cache: &ArrowFactorCache,
) -> Result<f64, String> {
    let hdiag = assignment_prior_log_strength_hdiag(&self.assignment, rho)?;
    if hdiag.is_empty() {
        return Ok(0.0);
    }
    let inv_diag = cache
        .latent_block_inverse_diagonal()
        .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?;
    let k_atoms = self.k_atoms();
    let assignment_dim = self.assignment.assignment_coord_dim();
    let layout = self.last_row_layout.as_ref();
    let mut accumulated = 0.0_f64;
    for row in 0..self.n_obs() {
        let row_base = cache.row_offsets[row];
        let prior_base = row * k_atoms;
        if let Some(layout) = layout {
            for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
                accumulated += inv_diag[row_base + pos] * hdiag[prior_base + atom];
            }
        } else {
            for free_idx in 0..assignment_dim {
                accumulated += inv_diag[row_base + free_idx] * hdiag[prior_base + free_idx];
            }
        }
    }
    Ok(0.5 * accumulated)
}
/// Enumerates the border channels matching the border layout of `cache`.
///
/// Frames are treated as active only when `last_frames_active` is set and `cache.k` equals
/// the factored border dimension. Each atom then uses its frame output matrix, or a
/// `p × p` identity when frames are inactive. One channel is emitted per
/// `(basis_col, frame column)` pair, carrying the first `p` entries of that frame column as
/// `output` and the flat border index `offsets[atom] + basis_col * r + channel`.
///
/// # Errors
/// Returns an error string when the number of channels differs from `cache.k`.
fn border_channels_for_cache(
    &self,
    cache: &ArrowFactorCache,
) -> Result<Vec<SaeBorderChannel>, String> {
    let p = self.output_dim();
    let frames_active = self.last_frames_active && cache.k == self.factored_border_dim();
    let offsets = if frames_active {
        self.factored_beta_offsets()
    } else {
        self.beta_offsets()
    };
    let mut channels = Vec::with_capacity(cache.k);
    for (atom_idx, atom) in self.atoms.iter().enumerate() {
        let basis = atom.basis_size();
        let frame = if frames_active {
            self.frame_output_matrix(atom_idx)
        } else {
            Array2::<f64>::eye(p)
        };
        let width = frame.ncols();
        for basis_col in 0..basis {
            for channel in 0..width {
                let output: Vec<f64> = (0..p).map(|out_col| frame[[out_col, channel]]).collect();
                channels.push(SaeBorderChannel {
                    atom: atom_idx,
                    basis_col,
                    index: offsets[atom_idx] + basis_col * width + channel,
                    output,
                });
            }
        }
    }
    if channels.len() == cache.k {
        Ok(channels)
    } else {
        Err(format!(
            "border channel layout has {} entries but cache border has {}",
            channels.len(),
            cache.k
        ))
    }
}
/// Labels every local variable slot of cache row `row` as a logit or a coordinate axis.
///
/// With a recorded row layout, position `pos` holds the logit of the active atom at that
/// position and the atom's coordinate axes start at `layout.coord_starts[row][pos]`.
/// Without a layout, the first `assignment_coord_dim()` slots are logits and each atom's
/// axes start at its `coord_offsets()` entry.
///
/// # Errors
/// Returns an error string naming the first slot that was left unlabelled.
fn row_vars_for_cache_row(
    &self,
    row: usize,
    cache: &ArrowFactorCache,
) -> Result<Vec<SaeLocalRowVar>, String> {
    let q_row = cache.row_dims[row];
    let mut slots: Vec<Option<SaeLocalRowVar>> = vec![None; q_row];
    if let Some(layout) = self.last_row_layout.as_ref() {
        for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
            slots[pos] = Some(SaeLocalRowVar::Logit { atom });
            let start = layout.coord_starts[row][pos];
            let dim = self.assignment.coords[atom].latent_dim();
            for axis in 0..dim {
                slots[start + axis] = Some(SaeLocalRowVar::Coord { atom, axis });
            }
        }
    } else {
        let assignment_dim = self.assignment.assignment_coord_dim();
        let coord_offsets = self.assignment.coord_offsets();
        for atom in 0..assignment_dim {
            slots[atom] = Some(SaeLocalRowVar::Logit { atom });
        }
        for atom in 0..self.k_atoms() {
            let start = coord_offsets[atom];
            let dim = self.assignment.coords[atom].latent_dim();
            for axis in 0..dim {
                slots[start + axis] = Some(SaeLocalRowVar::Coord { atom, axis });
            }
        }
    }
    let mut resolved = Vec::with_capacity(q_row);
    for (idx, slot) in slots.into_iter().enumerate() {
        match slot {
            Some(var) => resolved.push(var),
            None => {
                return Err(format!(
                    "row_vars_for_cache_row: row {row} position {idx} was not mapped"
                ));
            }
        }
    }
    Ok(resolved)
}
/// Evaluates the basis second jet of every atom at its current latent coordinates.
///
/// An atom's dedicated `basis_second_jet` is preferred; otherwise its `basis_evaluator` must
/// exist and expose a second jet through `second_jet_dyn`. Every jet must have shape
/// `(n_obs, basis_size, latent_dim, latent_dim)`.
///
/// # Errors
/// Returns an error string when an atom has neither source of second jets, when the
/// evaluation itself fails, or when the resulting shape is not the expected one.
fn atom_second_jets(&self) -> Result<Vec<Array4<f64>>, String> {
    self.atoms
        .iter()
        .enumerate()
        .map(|(atom_idx, atom)| -> Result<Array4<f64>, String> {
            let coords = self.assignment.coords[atom_idx].as_matrix();
            let jet = match atom.basis_second_jet.as_ref() {
                Some(second) => second.second_jet(coords.view())?,
                None => {
                    let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
                        format!(
                            "logdet_theta_adjoint: atom '{}' has no basis evaluator for second jets",
                            atom.name
                        )
                    })?;
                    evaluator
                        .second_jet_dyn(coords.view())
                        .ok_or_else(|| {
                            format!(
                                "logdet_theta_adjoint: atom '{}' basis does not expose analytic second jets",
                                atom.name
                            )
                        })??
                }
            };
            let expected = (
                atom.n_obs(),
                atom.basis_size(),
                atom.latent_dim,
                atom.latent_dim,
            );
            if jet.dim() == expected {
                Ok(jet)
            } else {
                Err(format!(
                    "logdet_theta_adjoint: atom '{}' second jet shape {:?}, expected {:?}",
                    atom.name,
                    jet.dim(),
                    expected
                ))
            }
        })
        .collect()
}
/// First and second derivatives of the gate values with respect to a row's logit variables.
///
/// Returns `(dz, d2z)` where `dz[a][k]` is the derivative of gate `k` with respect to local
/// variable `a`, and `d2z[a][b][k]` the mixed second derivative. Entries belonging to
/// non-logit variables stay zero.
///
/// * `Softmax`: `dz = z_k (δ_kj − z_j) / τ` and
///   `d2z = z_k [(δ_kl − z_l)(δ_kj − z_j) − z_j (δ_jl − z_l)] / τ²`.
/// * `IBPMap`: only the variable's own atom is filled, from
///   `sae_sigmoid_derivatives_from_value` with the stick-breaking prior weight of that atom.
/// * `JumpReLU`: as for `IBPMap` but with weight `1.0`, and atoms whose logit is at or below
///   `threshold` are left at zero.
fn gate_derivatives_for_row(
    &self,
    row: usize,
    assignments: ArrayView1<'_, f64>,
    vars: &[SaeLocalRowVar],
) -> Result<(Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>), String> {
    let k_atoms = self.k_atoms();
    let q = vars.len();
    let mut dz = vec![vec![0.0_f64; k_atoms]; q];
    let mut d2z = vec![vec![vec![0.0_f64; k_atoms]; q]; q];
    // (local position, atom) for every logit variable of this row, in variable order.
    let logit_sites: Vec<(usize, usize)> = vars
        .iter()
        .enumerate()
        .filter_map(|(pos, var)| match *var {
            SaeLocalRowVar::Logit { atom } => Some((pos, atom)),
            _ => None,
        })
        .collect();
    match self.assignment.mode {
        AssignmentMode::Softmax { temperature, .. } => {
            let inv_tau = 1.0 / temperature;
            let delta = |lhs: usize, rhs: usize| if lhs == rhs { 1.0 } else { 0.0 };
            for &(a_idx, j) in &logit_sites {
                for k in 0..k_atoms {
                    dz[a_idx][k] = assignments[k] * (delta(k, j) - assignments[j]) * inv_tau;
                }
                for &(b_idx, l) in &logit_sites {
                    for k in 0..k_atoms {
                        let ikl = delta(k, l);
                        let ikj = delta(k, j);
                        let ijl = delta(j, l);
                        d2z[a_idx][b_idx][k] = assignments[k]
                            * ((ikl - assignments[l]) * (ikj - assignments[j])
                                - assignments[j] * (ijl - assignments[l]))
                            * inv_tau
                            * inv_tau;
                    }
                }
            }
        }
        AssignmentMode::IBPMap {
            temperature, alpha, ..
        } => {
            let prior = ibp_stick_breaking_prior(k_atoms, alpha);
            let inv_tau = 1.0 / temperature;
            for &(idx, atom) in &logit_sites {
                let (_, first, second) =
                    sae_sigmoid_derivatives_from_value(assignments[atom], inv_tau, prior[atom]);
                dz[idx][atom] = first;
                d2z[idx][idx][atom] = second;
            }
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let inv_tau = 1.0 / temperature;
            let logits = self.assignment.logits.row(row);
            for &(idx, atom) in &logit_sites {
                if logits[atom] > threshold {
                    let (_, first, second) =
                        sae_sigmoid_derivatives_from_value(assignments[atom], inv_tau, 1.0);
                    dz[idx][atom] = first;
                    d2z[idx][idx][atom] = second;
                }
            }
        }
    }
    Ok((dz, d2z))
}
/// Writes the decoded second derivative of one row for the axis pair `(axis_a, axis_b)`.
///
/// `out` is zeroed, then for each basis column the second-jet entry
/// `second_jet[[row, basis_col, axis_a, axis_b]]` scales that column's decoder coefficients
/// into the first `atom.output_dim()` entries of `out`. Exactly-zero jet entries are skipped.
fn decoded_second_row(
    atom: &SaeManifoldAtom,
    second_jet: &Array4<f64>,
    row: usize,
    axis_a: usize,
    axis_b: usize,
    out: &mut [f64],
) {
    out.fill(0.0);
    for basis_col in 0..atom.basis_size() {
        let weight = second_jet[[row, basis_col, axis_a, axis_b]];
        if weight != 0.0 {
            let width = atom.output_dim();
            for (out_col, slot) in out[..width].iter_mut().enumerate() {
                *slot += weight * atom.decoder_coefficients[[basis_col, out_col]];
            }
        }
    }
}
/// Builds the per-row jets (`SaeRowJets`) used by the log-determinant adjoint.
///
/// All quantities are output-space vectors of length `output_dim()` scaled by the square root
/// of the row's loss weight (1.0 when no weights are set):
///
/// * `first[a]`: derivative of the gated decoded output with respect to local variable `a`.
/// * `second[a][b]`: mixed second derivative with respect to local variables `a` and `b`.
/// * `beta[c]`: contribution of border channel `c`, i.e.
///   `assignments[atom] * basis_value * channel.output`.
/// * `beta_deriv[a][c]` / `beta_l_deriv[a][c]`: derivative of `beta[c]` with respect to
///   local variable `a`.
///
/// `vars` is moved into the returned struct unchanged.
///
/// # Errors
/// Propagates a failure from `gate_derivatives_for_row`.
fn row_jets_for_logdet(
&self,
row: usize,
vars: Vec<SaeLocalRowVar>,
assignments: ArrayView1<'_, f64>,
second_jets: &[Array4<f64>],
border: &[SaeBorderChannel],
) -> Result<SaeRowJets, String> {
let p = self.output_dim();
let q = vars.len();
let k_atoms = self.k_atoms();
// Square root of the row loss weight; defaults to 1.0 without weights.
let sqrt_row_w = self
.row_loss_weights
.as_deref()
.map_or(1.0, |w| w[row].sqrt());
let (dz, d2z) = self.gate_derivatives_for_row(row, assignments, &vars)?;
// Per-atom decoded output (`decoded`), its first derivative per latent axis (`d1`) and
// its second derivative per axis pair (`d2`) at this row.
let mut decoded = vec![vec![0.0_f64; p]; k_atoms];
let mut d1: Vec<Vec<Vec<f64>>> = self
.atoms
.iter()
.map(|atom| vec![vec![0.0_f64; p]; atom.latent_dim])
.collect();
let mut d2: Vec<Vec<Vec<Vec<f64>>>> = self
.atoms
.iter()
.map(|atom| vec![vec![vec![0.0_f64; p]; atom.latent_dim]; atom.latent_dim])
.collect();
let mut scratch = vec![0.0_f64; p];
for k in 0..k_atoms {
self.atoms[k].fill_decoded_row(row, &mut decoded[k]);
for axis in 0..self.atoms[k].latent_dim {
self.atoms[k].fill_decoded_derivative_row(row, axis, &mut d1[k][axis]);
}
for axis_a in 0..self.atoms[k].latent_dim {
for axis_b in 0..self.atoms[k].latent_dim {
Self::decoded_second_row(
&self.atoms[k],
&second_jets[k],
row,
axis_a,
axis_b,
&mut scratch,
);
d2[k][axis_a][axis_b].clone_from_slice(&scratch);
}
}
}
// First derivatives: a logit variable moves every gate (sum over atoms of dz * decoded);
// a coordinate variable moves only its own atom's decoded output.
let mut first = vec![vec![0.0_f64; p]; q];
for (idx, var) in vars.iter().enumerate() {
match *var {
SaeLocalRowVar::Logit { .. } => {
for k in 0..k_atoms {
let coeff = dz[idx][k] * sqrt_row_w;
if coeff == 0.0 {
continue;
}
for out_col in 0..p {
first[idx][out_col] += coeff * decoded[k][out_col];
}
}
}
SaeLocalRowVar::Coord { atom, axis } => {
let coeff = assignments[atom] * sqrt_row_w;
for out_col in 0..p {
first[idx][out_col] = coeff * d1[atom][axis][out_col];
}
}
}
}
// Second derivatives by variable-kind pair. Coordinate pairs from different atoms fall
// through to the wildcard arm and stay zero.
let mut second = vec![vec![vec![0.0_f64; p]; q]; q];
for a in 0..q {
for b in 0..q {
match (vars[a], vars[b]) {
(SaeLocalRowVar::Logit { .. }, SaeLocalRowVar::Logit { .. }) => {
for k in 0..k_atoms {
let coeff = d2z[a][b][k] * sqrt_row_w;
if coeff == 0.0 {
continue;
}
for out_col in 0..p {
second[a][b][out_col] += coeff * decoded[k][out_col];
}
}
}
(SaeLocalRowVar::Logit { .. }, SaeLocalRowVar::Coord { atom, axis }) => {
let coeff = dz[a][atom] * sqrt_row_w;
for out_col in 0..p {
second[a][b][out_col] = coeff * d1[atom][axis][out_col];
}
}
(SaeLocalRowVar::Coord { atom, axis }, SaeLocalRowVar::Logit { .. }) => {
let coeff = dz[b][atom] * sqrt_row_w;
for out_col in 0..p {
second[a][b][out_col] = coeff * d1[atom][axis][out_col];
}
}
(
SaeLocalRowVar::Coord {
atom: atom_a,
axis: axis_a,
},
SaeLocalRowVar::Coord {
atom: atom_b,
axis: axis_b,
},
) if atom_a == atom_b => {
let coeff = assignments[atom_a] * sqrt_row_w;
for out_col in 0..p {
second[a][b][out_col] = coeff * d2[atom_a][axis_a][axis_b][out_col];
}
}
_ => {}
}
}
}
// Border channel contributions and their derivatives with respect to each local variable.
let mut beta = vec![vec![0.0_f64; p]; border.len()];
let mut beta_deriv = vec![vec![vec![0.0_f64; p]; border.len()]; q];
let mut beta_l_deriv = vec![vec![vec![0.0_f64; p]; border.len()]; q];
for (beta_pos, channel) in border.iter().enumerate() {
let atom = channel.atom;
let phi = self.atoms[atom].basis_values[[row, channel.basis_col]];
let base = assignments[atom] * phi * sqrt_row_w;
for out_col in 0..p {
beta[beta_pos][out_col] = base * channel.output[out_col];
}
for (var_idx, var) in vars.iter().enumerate() {
// Logit: gate derivative times basis value. Coordinate of the same atom: gate value
// times basis Jacobian. Anything else does not touch this channel.
let scalar = match *var {
SaeLocalRowVar::Logit { .. } => dz[var_idx][atom] * phi * sqrt_row_w,
SaeLocalRowVar::Coord {
atom: coord_atom,
axis,
} if coord_atom == atom => {
assignments[atom]
* self.atoms[atom].basis_jacobian[[row, channel.basis_col, axis]]
* sqrt_row_w
}
_ => 0.0,
};
if scalar != 0.0 {
for out_col in 0..p {
beta_deriv[var_idx][beta_pos][out_col] = scalar * channel.output[out_col];
}
}
// NOTE(review): `scalar_l` is computed by expressions equivalent to `scalar` above
// (`phi` is the same `basis_values` entry), so `beta_l_deriv` currently equals
// `beta_deriv` — confirm whether the two are meant to differ.
let scalar_l = match *var {
SaeLocalRowVar::Logit { .. } => {
dz[var_idx][atom]
* self.atoms[atom].basis_values[[row, channel.basis_col]]
* sqrt_row_w
}
SaeLocalRowVar::Coord {
atom: coord_atom,
axis,
} if coord_atom == atom => {
assignments[atom]
* self.atoms[atom].basis_jacobian[[row, channel.basis_col, axis]]
* sqrt_row_w
}
_ => 0.0,
};
if scalar_l != 0.0 {
for out_col in 0..p {
beta_l_deriv[var_idx][beta_pos][out_col] =
scalar_l * channel.output[out_col];
}
}
}
}
Ok(SaeRowJets {
vars,
first,
second,
beta,
beta_deriv,
beta_l_deriv,
})
}
/// Derivative of one diagonal entry of the assignment-prior Hessian with respect to a local
/// variable.
///
/// Only logit variables contribute; any other `wrt` yields 0.
///
/// * `Softmax`: differentiates `scale * a_k * term_k` with
///   `term_k = (1 − 2 a_k)(mean − l_k) + a_k − 1`, `l_k = ln(max(a_k, 1e-300)) + 1`,
///   `mean = Σ a_k l_k` and `scale = lambda_sparse * sparsity / τ²`.
/// * `JumpReLU`: non-zero only when `diag_atom == wrt_atom` and the logit passes
///   `jumprelu_in_optimization_band`; then `2 λ s² (1 − 2σ) / τ³` with `s = σ (1 − σ)`.
/// * `IBPMap`: non-zero only when `diag_atom == wrt_atom`; read from
///   `ibp_channels.local_logit_third` when channels are supplied, 0 otherwise.
fn assignment_prior_hdiag_derivative_entry(
    &self,
    rho: &SaeManifoldRho,
    row: usize,
    diag_atom: usize,
    wrt: SaeLocalRowVar,
    ibp_channels: Option<&IbpHessianDiagThirdChannels>,
) -> f64 {
    let wrt_atom = match wrt {
        SaeLocalRowVar::Logit { atom } => atom,
        _ => return 0.0,
    };
    match self.assignment.mode {
        AssignmentMode::Softmax {
            temperature,
            sparsity,
        } => {
            let assignments = self.assignment.assignments_row(row);
            let inv_tau = 1.0 / temperature;
            let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
            let k_atoms = assignments.len();
            let log_plus_one: Vec<f64> = (0..k_atoms)
                .map(|k| assignments[k].max(1.0e-300).ln() + 1.0)
                .collect();
            let mean = (0..k_atoms)
                .fold(0.0_f64, |acc, k| acc + assignments[k] * log_plus_one[k]);
            // Softmax Jacobian column for the differentiation variable.
            let da: Vec<f64> = (0..k_atoms)
                .map(|k| {
                    let indicator = if k == wrt_atom { 1.0 } else { 0.0 };
                    assignments[k] * (indicator - assignments[wrt_atom]) * inv_tau
                })
                .collect();
            let dmean: f64 = (0..k_atoms).map(|k| da[k] * log_plus_one[k]).sum();
            let a_diag = assignments[diag_atom];
            let da_diag = da[diag_atom];
            let gap = mean - log_plus_one[diag_atom];
            let term = (1.0 - 2.0 * a_diag) * gap + a_diag - 1.0;
            let dl_diag = da_diag / a_diag.max(1.0e-300);
            let dterm =
                -2.0 * da_diag * gap + (1.0 - 2.0 * a_diag) * (dmean - dl_diag) + da_diag;
            scale * (da_diag * term + a_diag * dterm)
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            if diag_atom != wrt_atom {
                return 0.0;
            }
            let logit = self.assignment.logits[[row, diag_atom]];
            if jumprelu_in_optimization_band(logit, threshold, temperature) {
                let inv_tau = 1.0 / temperature;
                let activation =
                    crate::linalg::utils::stable_logistic((logit - threshold) * inv_tau);
                let slope = activation * (1.0 - activation);
                2.0 * rho.lambda_sparse()
                    * slope
                    * slope
                    * (1.0 - 2.0 * activation)
                    * inv_tau
                    * inv_tau
                    * inv_tau
            } else {
                0.0
            }
        }
        AssignmentMode::IBPMap { .. } => {
            if diag_atom != wrt_atom {
                return 0.0;
            }
            ibp_channels.map_or(0.0, |ch| ch.local_logit_third[row * ch.k_max + diag_atom])
        }
    }
}
/// Derivative, with respect to the coordinate itself, of the clamped ARD prior curvature for
/// one `(row, atom, axis)` entry.
///
/// Returns 0 when the atom has no ARD parameters, when the prior curvature at the current
/// coordinate is not strictly positive, or when the axis is non-periodic. For a periodic axis
/// with `kappa = TAU / period` the value is `-alpha * kappa * sin(kappa * t)`.
fn ard_majorized_hessian_derivative(
    &self,
    rho: &SaeManifoldRho,
    row: usize,
    atom: usize,
    axis: usize,
) -> f64 {
    let log_precisions = &rho.log_ard[atom];
    if log_precisions.is_empty() {
        return 0.0;
    }
    let coord = &self.assignment.coords[atom];
    let alpha = SaeManifoldRho::stable_exp_strength(log_precisions[axis]);
    let periods = coord.effective_axis_periods();
    let t = coord.row(row)[axis];
    let period = periods[axis];
    if ArdAxisPrior::eval(alpha, t, period).hess <= 0.0 {
        return 0.0;
    }
    period.map_or(0.0, |p| {
        let kappa = std::f64::consts::TAU / p;
        -alpha * kappa * (kappa * t).sin()
    })
}
/// Right-hand side of the implicit-function-theorem solve for outer coordinate `j` of `rho`.
///
/// The flat `rho` layout is read as: index 0 fills the assignment-prior gradient into the
/// latent block, index 1 fills `lambda_smooth * sym(S) * B` into the border block (using the
/// frame-projected decoder when frames are active for this cache), and indices from 2 onward
/// enumerate `(atom, axis)` ARD entries in order, filling the prior gradient at the matching
/// coordinate slot of every row that has one. All other entries stay zero.
///
/// # Errors
/// Returns an error string when `j` is outside the flattened `rho`, or when the assignment
/// prior gradient or the frame projection fails.
pub fn outer_rho_gradient_ift_rhs(
    &self,
    rho: &SaeManifoldRho,
    j: usize,
    cache: &ArrowFactorCache,
) -> Result<SaeArrowVector, String> {
    let n_params = rho.to_flat().len();
    if j >= n_params {
        return Err(format!(
            "outer_rho_gradient_ift_rhs: coordinate {j} outside rho dim {n_params}"
        ));
    }
    let mut t = Array1::<f64>::zeros(cache.delta_t_len());
    let mut beta = Array1::<f64>::zeros(cache.k);
    match j {
        0 => {
            let (assignment_grad, _) = assignment_prior_grad_hdiag(&self.assignment, rho)?;
            let k_atoms = self.k_atoms();
            let assignment_dim = self.assignment.assignment_coord_dim();
            let layout = self.last_row_layout.as_ref();
            for row in 0..self.n_obs() {
                let base = cache.row_offsets[row];
                let prior_base = row * k_atoms;
                if let Some(layout) = layout {
                    for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
                        t[base + pos] = assignment_grad[prior_base + atom];
                    }
                } else {
                    for free_idx in 0..assignment_dim {
                        t[base + free_idx] = assignment_grad[prior_base + free_idx];
                    }
                }
            }
        }
        1 => {
            let lambda = rho.lambda_smooth();
            let frames_active = self.last_frames_active && cache.k == self.factored_border_dim();
            let offsets = if frames_active {
                self.factored_beta_offsets()
            } else {
                self.beta_offsets()
            };
            for (atom_idx, atom) in self.atoms.iter().enumerate() {
                let basis = atom.basis_size();
                let coeffs = match (frames_active, atom.decoder_frame.as_ref()) {
                    (true, Some(frame)) => {
                        frame.project_decoder(atom.decoder_coefficients.view())?
                    }
                    _ => atom.decoder_coefficients.clone(),
                };
                let width = coeffs.ncols();
                let offset = offsets[atom_idx];
                for mu in 0..basis {
                    for channel in 0..width {
                        let mut acc = 0.0_f64;
                        for nu in 0..basis {
                            let s_sym = 0.5
                                * (atom.smooth_penalty[[mu, nu]] + atom.smooth_penalty[[nu, mu]]);
                            acc += s_sym * coeffs[[nu, channel]];
                        }
                        beta[offset + mu * width + channel] = lambda * acc;
                    }
                }
            }
        }
        _ => {
            // Flat index `j - 2` selects the (atom, axis) ARD entry in row-major order.
            let selected = (0..rho.log_ard.len())
                .flat_map(|atom| (0..rho.log_ard[atom].len()).map(move |axis| (atom, axis)))
                .nth(j - 2);
            if let Some((atom, axis)) = selected {
                let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
                let periods = self.assignment.coords[atom].effective_axis_periods();
                for row in 0..self.n_obs() {
                    let row_t = self.assignment.coords[atom].row(row);
                    let prior = ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
                    let slot = sae_coord_penalty_offset(
                        self.last_row_layout.as_ref(),
                        self.assignment.coord_offsets()[atom] + axis,
                        row,
                        atom,
                    );
                    if let Some(pos) = slot {
                        t[cache.row_offsets[row] + pos] = prior.grad;
                    }
                }
            }
        }
    }
    Ok(SaeArrowVector { t, beta })
}
/// Builds the vector `(gamma_t, gamma_beta)` obtained by contracting selected
/// blocks of the factored system's inverse with first-order changes of the
/// Hessian entries reconstructed from per-row jets.
///
/// NOTE(review): judging by the name, the result is presumably the derivative of
/// the log-determinant term with respect to the inner variables θ = (t, β) —
/// confirm against the derivation before relying on that reading.
///
/// The `t` part has length `cache.delta_t_len()` and the `beta` part has length
/// `cache.k`.
///
/// # Errors
///
/// Returns `Err` when any of the helpers fails: atom second jets, border channel
/// lookup, the Schur inverse block, the IBP third-order channels, the row
/// variable / assignment lookups, the row jets, or a selected inverse solve.
pub fn logdet_theta_adjoint(
&self,
rho: &SaeManifoldRho,
cache: &ArrowFactorCache,
) -> Result<SaeArrowVector, String> {
let n = self.n_obs();
let total_t = cache.delta_t_len();
let mut gamma_t = Array1::<f64>::zeros(total_t);
let mut gamma_beta = Array1::<f64>::zeros(cache.k);
let second_jets = self.atom_second_jets()?;
let border = self.border_channels_for_cache(cache)?;
// Full β×β block of the Schur inverse; an empty matrix when there is no border.
let schur_inv = if cache.k > 0 {
cache
.schur_inverse_block(0..cache.k)
.map_err(|err| format!("logdet_theta_adjoint: Schur inverse: {err}"))?
} else {
Array2::<f64>::zeros((0, 0))
};
let ibp_channels = ibp_assignment_third_channels(&self.assignment, rho)?;
let k_atoms = self.k_atoms();
// (row, atom, flat t index, inverse diagonal entry) for every logit variable;
// only filled when IBP channels are present and consumed after the row loop.
let mut ibp_logit_sites: Vec<(usize, usize, usize, f64)> = Vec::new();
for row in 0..n {
let q = cache.row_dims[row];
let base = cache.row_offsets[row];
let vars = self.row_vars_for_cache_row(row, cache)?;
let assignments = self.assignment.try_assignments_row(row)?;
let jets =
self.row_jets_for_logdet(row, vars, assignments.view(), &second_jets, &border)?;
// Selected inverse blocks for this row: one full inverse application per local
// column, each with a unit right-hand side at `base + col`.
let mut inv_vv = Array2::<f64>::zeros((q, q));
let mut inv_vbeta = Array2::<f64>::zeros((q, cache.k));
for col in 0..q {
let mut rhs_t = Array1::<f64>::zeros(total_t);
let rhs_beta = Array1::<f64>::zeros(cache.k);
rhs_t[base + col] = 1.0;
let (sol_t, sol_beta) = cache
.full_inverse_apply(rhs_t.view(), rhs_beta.view())
.map_err(|err| {
format!("logdet_theta_adjoint: selected inverse solve: {err}")
})?;
for r in 0..q {
inv_vv[[r, col]] = sol_t[base + r];
}
for b in 0..cache.k {
inv_vbeta[[col, b]] = sol_beta[b];
}
}
if ibp_channels.is_some() {
for (pos, var) in jets.vars.iter().enumerate() {
if let SaeLocalRowVar::Logit { atom } = *var {
ibp_logit_sites.push((row, atom, base + pos, inv_vv[[pos, pos]]));
}
}
}
// Contribution for each local row variable `w`.
for w in 0..q {
let mut gamma = 0.0_f64;
// Row-row block: change of the (a, b) entry with respect to `w`, plus the
// diagonal prior terms (assignment prior on logits, ARD on coordinates).
for a in 0..q {
for b in 0..q {
let mut dh = sae_dot(&jets.second[a][w], &jets.first[b])
+ sae_dot(&jets.first[a], &jets.second[b][w]);
if a == b {
dh += match jets.vars[a] {
SaeLocalRowVar::Logit { atom } => self
.assignment_prior_hdiag_derivative_entry(
rho,
row,
atom,
jets.vars[w],
ibp_channels.as_ref(),
),
SaeLocalRowVar::Coord { atom, axis } if a == w => {
self.ard_majorized_hessian_derivative(rho, row, atom, axis)
}
_ => 0.0,
};
}
gamma += inv_vv[[b, a]] * dh;
}
}
// Row-border block; the factor 2 accounts for both off-diagonal positions.
for a in 0..q {
for (beta_pos, channel) in border.iter().enumerate() {
let dh = sae_dot(&jets.second[a][w], &jets.beta[beta_pos])
+ sae_dot(&jets.first[a], &jets.beta_deriv[w][beta_pos]);
gamma += 2.0 * inv_vbeta[[a, channel.index]] * dh;
}
}
// Border-border block, weighted by the Schur inverse.
for (beta_i, channel_i) in border.iter().enumerate() {
for (beta_j, channel_j) in border.iter().enumerate() {
let dh = sae_dot(&jets.beta_deriv[w][beta_i], &jets.beta[beta_j])
+ sae_dot(&jets.beta[beta_i], &jets.beta_deriv[w][beta_j]);
gamma += schur_inv[[channel_i.index, channel_j.index]] * dh;
}
}
gamma_t[base + w] = gamma;
}
// Contribution for each border channel; accumulated across rows.
for (w_beta_pos, w_channel) in border.iter().enumerate() {
let mut gamma = 0.0_f64;
for a in 0..q {
for b in 0..q {
let dh = sae_dot(&jets.beta_l_deriv[a][w_beta_pos], &jets.first[b])
+ sae_dot(&jets.first[a], &jets.beta_l_deriv[b][w_beta_pos]);
gamma += inv_vv[[b, a]] * dh;
}
}
for a in 0..q {
for (beta_pos, channel) in border.iter().enumerate() {
let dh = sae_dot(&jets.beta_l_deriv[a][w_beta_pos], &jets.beta[beta_pos]);
gamma += 2.0 * inv_vbeta[[a, channel.index]] * dh;
}
}
gamma_beta[w_channel.index] += gamma;
}
}
// Cross-row IBP pass: sum `inv_diag * m_channel` per atom over all recorded logit
// sites, then scatter that per-atom coefficient back through `z_jac`.
if let Some(channels) = ibp_channels.as_ref() {
let mut col_coeff = vec![0.0_f64; k_atoms];
for &(row, atom, _t_index, inv_diag) in &ibp_logit_sites {
col_coeff[atom] += inv_diag * channels.m_channel[row * k_atoms + atom];
}
for &(row, atom, t_index, _inv_diag) in &ibp_logit_sites {
gamma_t[t_index] += col_coeff[atom] * channels.z_jac[row * k_atoms + atom];
}
}
Ok(SaeArrowVector {
t: gamma_t,
beta: gamma_beta,
})
}
/// Assembles the separate pieces of the analytic outer ρ-gradient.
///
/// Every returned array has one entry per flattened ρ coordinate
/// (`rho.to_flat().len()`). Slot 0 receives the assignment-prior strength
/// terms, slot 1 the decoder smoothness terms, and slots 2.. the ARD terms in
/// atom-major, axis-minor order. The third-order correction for a coordinate
/// is `-0.5 * <gamma, H⁻¹ rhs>`, where `gamma` comes from
/// `logdet_theta_adjoint` and `rhs` from `outer_rho_gradient_ift_rhs`.
///
/// # Errors
///
/// Propagates any helper failure as a `String`.
pub fn analytic_outer_rho_gradient_components(
    &self,
    rho: &SaeManifoldRho,
    loss: &SaeManifoldLoss,
    cache: &ArrowFactorCache,
) -> Result<SaeOuterRhoGradientComponents, String> {
    let n_params = rho.to_flat().len();
    let blank = || Array1::<f64>::zeros(n_params);
    let mut explicit = blank();
    let mut logdet_trace = blank();
    let mut occam = blank();
    let mut third_order_correction = blank();

    // Slot 0: assignment-prior log-strength.
    explicit[0] = assignment_prior_log_strength_derivative(&self.assignment, rho);
    logdet_trace[0] = self.assignment_log_strength_hessian_trace(rho, cache)?;

    // Slot 1: decoder smoothness.
    explicit[1] = loss.smoothness;
    let smooth_edf = self
        .decoder_smoothness_effective_dof(cache, rho.lambda_smooth())
        .map_err(|err| format!("analytic_outer_rho_gradient_components: {err}"))?;
    logdet_trace[1] = 0.5 * smooth_edf;
    occam[1] = -self.reml_occam_log_lambda_smooth_derivative()?;

    // Slots 2..: ARD log-precisions, walked in the same order as `rho.log_ard`.
    let ard_explicit = self.ard_log_precision_explicit_derivatives(rho)?;
    let ard_trace = self
        .ard_log_precision_hessian_trace(rho, cache)
        .map_err(|err| format!("analytic_outer_rho_gradient_components: {err}"))?;
    let mut slot = 2usize;
    for (atom, axes) in rho.log_ard.iter().enumerate() {
        for axis in 0..axes.len() {
            explicit[slot] = ard_explicit[atom][axis];
            logdet_trace[slot] = ard_trace[atom][axis];
            slot += 1;
        }
    }

    // Third-order correction: one inverse application per ρ coordinate.
    let gamma = self.logdet_theta_adjoint(rho, cache)?;
    for coord in 0..n_params {
        let rhs = self.outer_rho_gradient_ift_rhs(rho, coord, cache)?;
        let (sol_t, sol_beta) = cache
            .full_inverse_apply(rhs.t.view(), rhs.beta.view())
            .map_err(|err| {
                format!("analytic_outer_rho_gradient_components: full_inverse_apply: {err}")
            })?;
        // Single running accumulator, t-part first, so the summation order is fixed.
        let t_part = (0..gamma.t.len())
            .fold(0.0_f64, |acc, idx| acc + gamma.t[idx] * sol_t[idx]);
        let inner = (0..gamma.beta.len())
            .fold(t_part, |acc, idx| acc + gamma.beta[idx] * sol_beta[idx]);
        third_order_correction[coord] = -0.5 * inner;
    }

    Ok(SaeOuterRhoGradientComponents {
        explicit,
        logdet_trace,
        occam,
        third_order_correction,
        third_order_correction_available: true,
    })
}
/// Evaluates the criterion at `rho` and returns it split into its value and
/// gradient pieces.
///
/// Runs `reml_criterion_with_cache`, reads the log-determinant from the
/// resulting factor cache, adds the registry's extra penalty energy (if a
/// registry is supplied) to the loss total, and hands everything, together
/// with the analytic outer gradient components, to `SaeCriterion::assemble`.
///
/// # Errors
///
/// Returns `Err` if the inner criterion evaluation fails, if the cache cannot
/// yield a log-determinant, or if any helper fails.
pub fn criterion_as_atoms(
    &mut self,
    target: ArrayView2<'_, f64>,
    rho: &SaeManifoldRho,
    registry: Option<&AnalyticPenaltyRegistry>,
    inner_max_iter: usize,
    learning_rate: f64,
    ridge_ext_coord: f64,
    ridge_beta: f64,
) -> Result<SaeCriterion, String> {
    let (_value, loss, cache) = self.reml_criterion_with_cache(
        target,
        rho,
        registry,
        inner_max_iter,
        learning_rate,
        ridge_ext_coord,
        ridge_beta,
    )?;

    let Some(log_det) = arrow_log_det_from_cache(&cache) else {
        return Err("criterion_as_atoms: arrow_log_det_from_cache returned None".to_string());
    };
    let occam = self.reml_occam_term(rho)?;

    let extra_penalty_energy = if let Some(reg) = registry {
        self.reml_extra_penalty_value_total(reg)
            .map_err(|err| format!("SaeManifoldTerm::criterion_as_atoms: {err}"))?
    } else {
        0.0
    };
    let data_fit_priors_value = loss.total() + extra_penalty_energy;

    let parts = self.analytic_outer_rho_gradient_components(rho, &loss, &cache)?;
    Ok(SaeCriterion::assemble(
        data_fit_priors_value,
        log_det,
        occam,
        parts.explicit,
        parts.logdet_trace,
        parts.occam,
        parts.third_order_correction,
    ))
}
/// Estimates the reconstruction dispersion `φ̂ = RSS / residual_dof`.
///
/// `RSS` is `2 * loss.data_fit`. The residual degrees of freedom are
/// `n_obs * output_dim` minus the decoder EDF (raw decoder dimension minus the
/// smoothness EDF, floored at 0) minus the coordinate EDF (per atom and axis
/// `n_active - alpha * trace`, clamped to `[0, n_active]`; atoms without ARD
/// parameters count `n_active * latent_dim`), floored at 1.
///
/// `f64::max` and `f64::clamp` do not propagate NaN the way ordinary
/// arithmetic does, so a NaN EDF term would otherwise collapse the residual
/// dof to its floor and produce a finite but meaningless `φ̂`. NaN terms are
/// therefore rejected explicitly.
///
/// # Errors
///
/// Returns `Err` when a helper fails, when ρ / trace shapes disagree with the
/// atom list, when an EDF term is NaN, or when the resulting `φ̂` is non-finite
/// or negative.
fn reconstruction_dispersion(
    &self,
    loss: &SaeManifoldLoss,
    cache: &ArrowFactorCache,
    rho: &SaeManifoldRho,
) -> Result<f64, String> {
    let n = self.n_obs();
    let p = self.output_dim();
    let n_scalar = (n * p) as f64;
    let rss = 2.0 * loss.data_fit;
    let smooth_edf = self
        .decoder_smoothness_effective_dof(cache, rho.lambda_smooth())
        .map_err(|e| format!("reconstruction_dispersion: smooth edf: {e}"))?;
    // `(x - NaN).max(0.0)` evaluates to 0.0, which would hide the failure.
    if smooth_edf.is_nan() {
        return Err("reconstruction_dispersion: decoder smoothness EDF is NaN".to_string());
    }
    let raw_decoder_dof = if self.frames_active() {
        (self.factored_border_dim() + self.grassmann_evidence_dimension()) as f64
    } else {
        self.beta_dim() as f64
    };
    let beta_edf = (raw_decoder_dof - smooth_edf).max(0.0);
    let traces = self
        .ard_inverse_traces(cache)
        .map_err(|e| format!("reconstruction_dispersion: ARD traces: {e}"))?;
    if rho.log_ard.len() != self.atoms.len() {
        return Err(format!(
            "reconstruction_dispersion: ρ has {} ARD atoms but term has {}",
            rho.log_ard.len(),
            self.atoms.len()
        ));
    }
    // Without this check `traces[k]` below would panic instead of reporting the mismatch.
    if traces.len() != self.atoms.len() {
        return Err(format!(
            "reconstruction_dispersion: ARD traces cover {} atoms but term has {}",
            traces.len(),
            self.atoms.len()
        ));
    }
    let mut coord_edf = 0.0_f64;
    for (k, atom) in self.atoms.iter().enumerate() {
        let d_k = atom.latent_dim;
        if traces[k].len() != d_k {
            return Err(format!(
                "reconstruction_dispersion: trace shape mismatch at atom {k} \
                 (traces={}, d_k={d_k})",
                traces[k].len()
            ));
        }
        let ard_len = rho.log_ard[k].len();
        if ard_len != 0 && ard_len != d_k {
            return Err(format!(
                "reconstruction_dispersion: ARD shape mismatch at atom {k} \
                 (log_ard={ard_len}, d_k={d_k})"
            ));
        }
        // Number of rows in which this atom is active; every row when no layout is cached.
        let n_active_k = match self.last_row_layout {
            Some(ref layout) => layout
                .active_atoms
                .iter()
                .filter(|active| active.contains(&k))
                .count() as f64,
            None => n as f64,
        };
        if ard_len == 0 {
            coord_edf += n_active_k * d_k as f64;
            continue;
        }
        for j in 0..d_k {
            let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[k][j]);
            let raw_edf = n_active_k - alpha * traces[k][j];
            // `clamp` hands a NaN straight through, and the later `.max(1.0)` would
            // then silently replace the poisoned residual dof with 1.
            if raw_edf.is_nan() {
                return Err(format!(
                    "reconstruction_dispersion: NaN coordinate EDF at atom {k} axis {j} \
                     (alpha={alpha}, trace={})",
                    traces[k][j]
                ));
            }
            let edf_kj = raw_edf.clamp(0.0, n_active_k);
            coord_edf += edf_kj;
        }
    }
    let resid_dof = (n_scalar - beta_edf - coord_edf).max(1.0);
    let phi = rss / resid_dof;
    if !phi.is_finite() || phi < 0.0 {
        return Err(format!(
            "reconstruction_dispersion: non-finite/negative φ̂={phi} \
             (RSS={rss}, resid_dof={resid_dof}, beta_edf={beta_edf}, coord_edf={coord_edf})"
        ));
    }
    Ok(phi.max(f64::MIN_POSITIVE))
}
/// Builds per-atom decoder shape uncertainty from the factor cache.
///
/// For every atom this takes the atom's block of the Schur inverse, scales it
/// by `dispersion`, and evaluates a pointwise band on a strided subset of the
/// atom's rows (at most roughly `SHAPE_BAND_MAX_POINTS` rows): the band
/// coordinates are the atom's latent coordinates, the band mean is the decoded
/// row, and the band standard deviation is the square root of the
/// (floored-at-zero) output variance.
///
/// The dense covariance is returned alongside the band unless the atom is
/// framed and the dense `(m*p) x (m*p)` payload would exceed
/// `SAE_DECODER_COV_PAYLOAD_MAX_ENTRIES`, in which case `decoder_covariance`
/// is `None` and the variances are computed from the factored block directly.
///
/// # Errors
///
/// Returns `Err` if a Schur inverse block cannot be obtained for an atom.
pub fn assemble_shape_uncertainty(
&self,
cache: &ArrowFactorCache,
dispersion: f64,
) -> Result<SaeShapeUncertainty, String> {
let p = self.output_dim();
let frames_active = self.frames_active();
let frame_projection = FrameProjection::new(self);
// Per-atom index ranges into the border: factored ranges when frames are active,
// the dense beta block offsets otherwise.
let block_ranges = if frames_active {
(0..self.k_atoms())
.map(|k| frame_projection.atom_border_range(k))
.collect::<Vec<_>>()
} else {
self.beta_block_offsets().to_vec()
};
let mut atoms = Vec::with_capacity(self.k_atoms());
for (k, atom) in self.atoms.iter().enumerate() {
let m = atom.basis_size();
let cov_block = cache
.schur_inverse_block(block_ranges[k].clone())
.map_err(|e| format!("assemble_shape_uncertainty: atom {k}: {e}"))?;
let n_rows = atom.n_obs();
let d = atom.latent_dim;
// Subsample rows with a uniform stride so the band has a bounded number of points.
let stride = n_rows.div_ceil(SHAPE_BAND_MAX_POINTS).max(1);
let eval_rows: Vec<usize> = (0..n_rows).step_by(stride).collect();
let g = eval_rows.len();
let coords_mat = self.assignment.coords[k].as_matrix();
let mut band_coords = Array2::<f64>::zeros((g, d));
let mut band_mean = Array2::<f64>::zeros((g, p));
let mut band_sd = Array2::<f64>::zeros((g, p));
let mut decoded = vec![0.0_f64; p];
for (gi, &row) in eval_rows.iter().enumerate() {
for axis in 0..d {
band_coords[[gi, axis]] = coords_mat[[row, axis]];
}
atom.fill_decoded_row(row, &mut decoded);
for c in 0..p {
band_mean[[gi, c]] = decoded[c];
}
}
let framed = frames_active && atom.decoder_frame.is_some();
let dense_entries = (m * p).saturating_mul(m * p);
// Both branches fill `band_sd`; only the second one keeps a dense covariance.
let cov = if framed && dense_entries > SAE_DECODER_COV_PAYLOAD_MAX_ENTRIES {
let mut cov_c = cov_block;
cov_c.mapv_inplace(|v| v * dispersion);
for (gi, &row) in eval_rows.iter().enumerate() {
let basis = atom.basis_values.row(row);
for c in 0..p {
let var = frame_projection.output_variance(k, cov_c.view(), basis, c);
band_sd[[gi, c]] = var.max(0.0).sqrt();
}
}
None
} else {
let mut cov = if framed {
frame_projection.lift_block(k, cov_block.view())
} else {
cov_block
};
cov.mapv_inplace(|v| v * dispersion);
for (gi, &row) in eval_rows.iter().enumerate() {
for c in 0..p {
let var = frame_projection.full_output_variance(
k,
cov.view(),
atom.basis_values.row(row),
c,
);
band_sd[[gi, c]] = var.max(0.0).sqrt();
}
}
Some(cov)
};
atoms.push(SaeAtomShapeUncertainty {
decoder_covariance: cov,
band_coords,
band_mean,
band_sd,
});
}
Ok(SaeShapeUncertainty { dispersion, atoms })
}
/// Adds the gradient and curvature contributions of every registered analytic
/// penalty to `sys`, dispatching on the penalty's tier and kind.
///
/// * ARD penalties are skipped here.
/// * Psi tier: IBP / softmax assignment penalties go to the logit slots;
///   nuclear-norm penalties go to the beta block; everything else is treated
///   as a per-atom coordinate penalty. Isometry penalties are first rebuilt per
///   atom via `corrected_isometry_penalty` and additionally contribute to the
///   beta block.
/// * Beta tier: routed to `add_sae_beta_penalty`.
/// * Rho tier: contributes nothing here.
///
/// All registry ρ coordinates are evaluated at zero (`rho_global` is a zero
/// vector of length `registry.total_rho_count()`).
///
/// Returns the record of which beta-curvature contributions were made.
///
/// # Errors
///
/// Propagates the error from `corrected_isometry_penalty`.
///
/// # Panics
///
/// Panics if a Psi-tier penalty that reaches the coordinate branch is not
/// accepted by `sae_penalty_is_row_block_supported`.
fn add_sae_analytic_penalty_contributions(
&self,
sys: &mut ArrowSchurSystem,
registry: &AnalyticPenaltyRegistry,
penalty_scale: f64,
row_layout: Option<&SaeRowLayout>,
dense_beta_curvature: bool,
factored_row_projection: Option<&FrameProjection>,
) -> Result<SaeBetaPenaltyAssembly, ArrowSchurError> {
let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
let layout = registry.rho_layout();
let logits_flat = flat_logits(self.assignment.logits.view());
let beta = self.flatten_beta();
let mut beta_assembly = SaeBetaPenaltyAssembly::default();
for (penalty, (rho_slice, tier, name)) in registry.penalties.iter().zip(layout.iter()) {
let rho_local = rho_global.slice(s![rho_slice.clone()]);
// ARD is not assembled through this path.
if matches!(penalty, AnalyticPenaltyKind::Ard(_)) {
continue;
}
match tier {
PenaltyTier::Psi => {
if matches!(
penalty,
AnalyticPenaltyKind::IBPAssignment(_)
| AnalyticPenaltyKind::SoftmaxAssignmentSparsity(_)
) {
// Assignment penalties act on the logits.
self.add_sae_logit_penalty(
sys,
penalty,
logits_flat.view(),
rho_local,
row_layout,
);
} else if matches!(penalty, AnalyticPenaltyKind::NuclearNorm(_)) {
// Nuclear norm is registered on the Psi tier but acts on decoder coefficients.
if self.add_sae_beta_penalty(
sys,
penalty,
beta.view(),
rho_local,
penalty_scale,
dense_beta_curvature,
) {
beta_assembly.record_curvature(dense_beta_curvature);
}
} else {
assert!(
sae_penalty_is_row_block_supported(penalty),
"validate_analytic_penalty_registry should have \
refused non-row-block Psi-tier penalty {:?} \
(registry layout name {name:?})",
penalty.name()
);
let offsets = self.assignment.coord_offsets();
for atom_idx in 0..self.k_atoms() {
let off = offsets[atom_idx];
let coord = &self.assignment.coords[atom_idx];
if let AnalyticPenaltyKind::Isometry(iso) = penalty {
// Rebuild the isometry penalty for this atom before using it.
let corrected_kind =
self.corrected_isometry_penalty(iso, atom_idx, coord)?;
self.add_sae_coord_penalty(
sys,
atom_idx,
off,
coord,
&corrected_kind,
rho_local,
row_layout,
factored_row_projection,
);
if let AnalyticPenaltyKind::Isometry(corrected) = &corrected_kind {
self.add_sae_isometry_beta_penalty(
sys,
atom_idx,
coord,
corrected,
rho_local,
dense_beta_curvature,
);
beta_assembly.record_curvature(dense_beta_curvature);
}
} else {
self.add_sae_coord_penalty(
sys,
atom_idx,
off,
coord,
penalty,
rho_local,
row_layout,
factored_row_projection,
);
}
}
}
}
PenaltyTier::Beta => {
if self.add_sae_beta_penalty(
sys,
penalty,
beta.view(),
rho_local,
penalty_scale,
dense_beta_curvature,
) {
beta_assembly.record_curvature(dense_beta_curvature);
}
}
PenaltyTier::Rho => {}
}
}
Ok(beta_assembly)
}
/// Clones the registered isometry penalty and specialises it to one atom.
///
/// The clone gets the atom's decoder output width as `p_out`, the row metric's
/// weight field when a gauge-driving row metric is installed, and refreshed
/// Jacobian caches. If the cache refresh does not install a second jet, one is
/// built from the atom's basis evaluator by contracting the basis second
/// derivatives with the decoder coefficients.
///
/// # Errors
///
/// Returns `ArrowSchurError::SchurFactorFailed` when the row metric's output
/// width disagrees with the atom, when the cache refresh fails, or when no
/// (correctly shaped) analytic second jet is available.
fn corrected_isometry_penalty(
    &self,
    iso: &Arc<IsometryPenalty>,
    atom_idx: usize,
    coord: &LatentCoordValues,
) -> Result<AnalyticPenaltyKind, ArrowSchurError> {
    let fail = |reason: String| ArrowSchurError::SchurFactorFailed { reason };

    let atom = &self.atoms[atom_idx];
    let p = atom.decoder_coefficients.ncols();
    let mut corrected: IsometryPenalty = (**iso).clone();
    corrected.p_out = p;

    // A gauge-driving row metric overrides the penalty's weight field.
    let gauge_metric = self
        .row_metric
        .as_ref()
        .filter(|metric| metric.drives_gauge());
    if let Some(metric) = gauge_metric {
        if metric.p_out() != p {
            return Err(fail(format!(
                "corrected_isometry_penalty: RowMetric p_out {} disagrees with atom \
                 {} decoder output dim {p}; the gauge metric must match the likelihood \
                 metric",
                metric.p_out(),
                atom_idx
            )));
        }
        corrected.weight = metric.to_weight_field();
    }

    let coords_mat = coord.as_matrix();
    let second_jet_installed =
        refresh_isometry_caches_from_atom(&corrected, atom, coords_mat.view()).map_err(fail)?;
    if second_jet_installed {
        return Ok(AnalyticPenaltyKind::Isometry(Arc::new(corrected)));
    }

    // Fall back to the evaluator's analytic second jet.
    let evaluator_jet = atom
        .basis_evaluator
        .as_ref()
        .and_then(|e| e.second_jet_dyn(coords_mat.view()));
    let hess = match evaluator_jet {
        Some(Ok(hess)) => hess,
        Some(Err(reason)) => return Err(fail(reason)),
        None => {
            return Err(fail(format!(
                "IsometryPenalty requested for SAE atom '{}' (basis kind {:?}) but \
                 this evaluator does not expose an analytic second jet; use \
                 AffineCoordinateEvaluator, SphereChartEvaluator, \
                 PeriodicHarmonicEvaluator, or TorusHarmonicEvaluator for \
                 SAE-Isometry",
                atom.name, atom.basis_kind
            )));
        }
    };

    let n_obs = coords_mat.nrows();
    let d = atom.latent_dim;
    let m = atom.basis_size();
    if hess.dim() != (n_obs, m, d, d) {
        return Err(fail(format!(
            "SAE Isometry atom '{}': second_jet_dyn returned shape {:?}, \
             expected ({n_obs}, {m}, {d}, {d})",
            atom.name,
            hess.dim()
        )));
    }

    // jac2[n, (i, a, c)] = Σ_mm hess[n, mm, a, c] * B[mm, i]
    let decoder = &atom.decoder_coefficients;
    let mut jac2 = Array2::<f64>::zeros((n_obs, p * d * d));
    for n in 0..n_obs {
        for i in 0..p {
            for a in 0..d {
                for c in 0..d {
                    jac2[[n, (i * d + a) * d + c]] = (0..m)
                        .fold(0.0_f64, |acc, mm| acc + hess[[n, mm, a, c]] * decoder[[mm, i]]);
                }
            }
        }
    }
    corrected.set_jacobian_second_cache(Some(Arc::new(jac2)));

    Ok(AnalyticPenaltyKind::Isometry(Arc::new(corrected)))
}
/// Adds a logit-targeted penalty's gradient, and its diagonal majorizer when
/// the penalty provides one, to the per-row blocks of `sys`.
///
/// `target` is the flattened logits with `k_atoms` entries per row. With a row
/// layout, the local slot `pos` of a row maps to the atom listed at
/// `layout.active_atoms[row][pos]`; without one, the first
/// `assignment_coord_dim` slots map to the same-numbered logit of the row.
fn add_sae_logit_penalty(
    &self,
    sys: &mut ArrowSchurSystem,
    penalty: &AnalyticPenaltyKind,
    target: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
    row_layout: Option<&SaeRowLayout>,
) {
    let n = self.n_obs();
    let k = self.k_atoms();
    let assignment_dim = self.assignment.assignment_coord_dim();

    // Gradient pass.
    let grad = penalty.grad_target(target, rho_local);
    match row_layout {
        Some(layout) => {
            for row in 0..n {
                for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
                    sys.rows[row].gt[pos] += grad[row * k + atom];
                }
            }
        }
        None => {
            for row in 0..n {
                for free_idx in 0..assignment_dim {
                    sys.rows[row].gt[free_idx] += grad[row * k + free_idx];
                }
            }
        }
    }

    // Curvature pass; nothing to add when no diagonal majorizer is exposed.
    let Some(diag) = penalty.psd_majorizer_diag(target, rho_local) else {
        return;
    };
    match row_layout {
        Some(layout) => {
            for row in 0..n {
                for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
                    sys.rows[row].htt[[pos, pos]] += diag[row * k + atom];
                }
            }
        }
        None => {
            for row in 0..n {
                for free_idx in 0..assignment_dim {
                    sys.rows[row].htt[[free_idx, free_idx]] += diag[row * k + free_idx];
                }
            }
        }
    }
}
/// Adds one atom's coordinate-penalty gradient and curvature to the per-row
/// blocks of `sys`.
///
/// Three curvature paths exist, tried in this order:
/// 1. origin-anchored magnitude penalties on a coordinate with a Euclidean
///    restriction are evaluated on the compacted Euclidean axes only (gradient
///    and optional diagonal), then the function returns;
/// 2. isometry penalties delegate their curvature to
///    `add_sae_isometry_metric_gn_blocks`;
/// 3. otherwise the diagonal majorizer is used when available, and failing
///    that the per-row block is probed column by column through
///    `psd_majorizer_hvp`.
///
/// Rows for which `sae_coord_penalty_offset` yields no slot are skipped.
fn add_sae_coord_penalty(
    &self,
    sys: &mut ArrowSchurSystem,
    atom_idx: usize,
    dense_off: usize,
    coord: &LatentCoordValues,
    penalty: &AnalyticPenaltyKind,
    rho_local: ArrayView1<'_, f64>,
    row_layout: Option<&SaeRowLayout>,
    factored_row_projection: Option<&FrameProjection>,
) {
    let n = coord.n_obs();
    let d = coord.latent_dim();
    let slot_of = |row: usize| sae_coord_penalty_offset(row_layout, dense_off, row, atom_idx);

    // Path 1: restriction to the Euclidean axes.
    if sae_coord_penalty_is_origin_anchored_magnitude(penalty) {
        if let Some((euclidean_axes, compacted)) = sae_coord_penalty_euclidean_restriction(coord)
        {
            let width = euclidean_axes.len();
            let grad = penalty.grad_target(compacted.view(), rho_local);
            let diag = penalty.psd_majorizer_diag(compacted.view(), rho_local);
            for row in 0..n {
                let Some(row_off) = slot_of(row) else {
                    continue;
                };
                for (j, &axis) in euclidean_axes.iter().enumerate() {
                    let slot = row_off + axis;
                    let flat = row * width + j;
                    sys.rows[row].gt[slot] += grad[flat];
                    if let Some(diag) = &diag {
                        sys.rows[row].htt[[slot, slot]] += diag[flat];
                    }
                }
            }
            return;
        }
    }

    // Gradient on the full coordinate vector.
    let target = coord.as_flat().view();
    let grad = penalty.grad_target(target, rho_local);
    for row in 0..n {
        let Some(row_off) = slot_of(row) else {
            continue;
        };
        for axis in 0..d {
            sys.rows[row].gt[row_off + axis] += grad[row * d + axis];
        }
    }

    // Path 2: isometry curvature lives in its own routine.
    if let AnalyticPenaltyKind::Isometry(corrected) = penalty {
        self.add_sae_isometry_metric_gn_blocks(
            sys,
            atom_idx,
            dense_off,
            coord,
            corrected,
            rho_local,
            row_layout,
            factored_row_projection,
        );
        return;
    }

    // Path 3: diagonal majorizer, or column probing when none is exposed.
    match penalty.psd_majorizer_diag(target, rho_local) {
        Some(diag) => {
            for row in 0..n {
                let Some(row_off) = slot_of(row) else {
                    continue;
                };
                for axis in 0..d {
                    let slot = row_off + axis;
                    sys.rows[row].htt[[slot, slot]] += diag[row * d + axis];
                }
            }
        }
        None => {
            let mut probe = Array1::<f64>::zeros(n * d);
            for axis in 0..d {
                // Unit probe along `axis` in every row at once.
                probe.fill(0.0);
                for row in 0..n {
                    probe[row * d + axis] = 1.0;
                }
                let hv = penalty.psd_majorizer_hvp(target, rho_local, probe.view());
                for row in 0..n {
                    let Some(row_off) = slot_of(row) else {
                        continue;
                    };
                    for b in 0..d {
                        sys.rows[row].htt[[row_off + b, row_off + axis]] += hv[row * d + b];
                    }
                }
            }
        }
    }
}
/// Adds the isometry penalty's curvature blocks for one atom to `sys`.
///
/// For every row with a slot, two Jacobians of the `d*d` metric entries are
/// formed: `metric_coord_jac` (`d*d x d`, with respect to the row's
/// coordinates) and, when the manifold preserves cross-block coherence,
/// `metric_beta_jac` (`d*d x m*p`, with respect to the atom's decoder
/// coefficients). The row block `htt` receives
/// `mu * metric_coord_jacᵀ metric_coord_jac`; the cross block `htbeta` receives
/// `mu * metric_coord_jacᵀ metric_beta_jac`, routed through the frame
/// projection when one is supplied.
///
/// The function returns without writing anything when the Jacobian caches are
/// missing or mis-shaped, when `mu` is not finite and positive, or when the
/// weighted Jacobian row cannot be formed.
fn add_sae_isometry_metric_gn_blocks(
&self,
sys: &mut ArrowSchurSystem,
atom_idx: usize,
dense_off: usize,
coord: &LatentCoordValues,
corrected: &Arc<IsometryPenalty>,
rho_local: ArrayView1<'_, f64>,
row_layout: Option<&SaeRowLayout>,
factored_row_projection: Option<&FrameProjection>,
) {
let n_obs = coord.n_obs();
let d = coord.latent_dim();
let atom = &self.atoms[atom_idx];
let p = atom.decoder_coefficients.ncols();
let m = atom.basis_size();
let Some(jac) = corrected.jacobian_cache() else {
return;
};
if jac.dim() != (n_obs, p * d) {
return;
}
let Some(jac2) = corrected.jacobian_second_cache() else {
return;
};
if jac2.dim() != (n_obs, p * d * d) {
return;
}
let beta_off = self.beta_offsets()[atom_idx];
let beta_block = m * p;
let jet = &atom.basis_jacobian;
let mu = resolve_learnable_weight(corrected.scalar_weight, rho_local[corrected.rho_index]);
if !(mu.is_finite() && mu > 0.0) {
return;
}
let couple_cross_block = coord.manifold().preserves_isometry_cross_block_coherence();
// Scratch buffers reused across rows.
let mut metric_coord_jac = Array2::<f64>::zeros((d * d, d));
let mut metric_beta_jac = Array2::<f64>::zeros((d * d, beta_block));
let mut wrote_dense_cross = false;
for row in 0..n_obs {
let Some(row_off) = sae_coord_penalty_offset(row_layout, dense_off, row, atom_idx)
else {
continue;
};
let Some(wj) = Self::sae_isometry_weighted_jacobian_row(corrected, &jac, row, p, d)
else {
return;
};
// Change of metric entry (a, b) with respect to coordinate c, from the product
// rule over the two Jacobian factors.
metric_coord_jac.fill(0.0);
for a in 0..d {
for b in 0..d {
let metric_row = a * d + b;
for c in 0..d {
let mut acc = 0.0;
for i in 0..p {
acc += jac2[[row, (i * d + a) * d + c]] * wj[[i, b]];
acc += wj[[i, a]] * jac2[[row, (i * d + b) * d + c]];
}
metric_coord_jac[[metric_row, c]] = acc;
}
}
}
if couple_cross_block {
// Change of metric entry (a, b) with respect to decoder coefficient
// (basis_col, output).
metric_beta_jac.fill(0.0);
for a in 0..d {
for b in 0..d {
let metric_row = a * d + b;
for basis_col in 0..m {
let jet_a = jet[[row, basis_col, a]];
let jet_b = jet[[row, basis_col, b]];
for output in 0..p {
metric_beta_jac[[metric_row, basis_col * p + output]] =
jet_a * wj[[output, b]] + wj[[output, a]] * jet_b;
}
}
}
}
}
for c in 0..d {
for e in 0..d {
let mut acc = 0.0;
for metric_row in 0..(d * d) {
acc +=
metric_coord_jac[[metric_row, c]] * metric_coord_jac[[metric_row, e]];
}
sys.rows[row].htt[[row_off + c, row_off + e]] += mu * acc;
}
if !couple_cross_block {
continue;
}
for beta_col in 0..beta_block {
let mut acc = 0.0;
for metric_row in 0..(d * d) {
acc += metric_coord_jac[[metric_row, c]]
* metric_beta_jac[[metric_row, beta_col]];
}
if let Some(projection) = factored_row_projection {
// Factored border: map the dense (basis_col, output) entry through the projection.
let basis_col = beta_col / p;
let output = beta_col % p;
let c_base = projection.border_offsets[atom_idx]
+ basis_col * projection.ranks[atom_idx];
let mut hrow = sys.rows[row].htbeta.row_mut(row_off + c);
let hrow_slice = hrow.as_slice_mut().expect("htbeta row is contiguous");
projection.accumulate_output_project(
atom_idx,
c_base,
output,
mu * acc,
hrow_slice,
);
} else {
sys.rows[row].htbeta[[row_off + c, beta_off + beta_col]] += mu * acc;
}
// Set for both the projected and the dense write path.
wrote_dense_cross = true;
}
}
}
if wrote_dense_cross {
sys.activate_dense_htbeta_supplement();
}
}
/// Returns the `p x d` weighted Jacobian of one row.
///
/// With an identity weight this is the row of `jac` reshaped to `(p, d)`. With
/// a factored weight `u` (row-major `p x rank` per row) it is `U (Uᵀ J)`.
/// Returns `None` when the factored weight's dimensions do not match `p` or
/// the number of rows in `jac`.
fn sae_isometry_weighted_jacobian_row(
    corrected: &IsometryPenalty,
    jac: &Array2<f64>,
    row: usize,
    p: usize,
    d: usize,
) -> Option<Array2<f64>> {
    match &corrected.weight {
        WeightField::Identity => Some(Array2::from_shape_fn((p, d), |(i, a)| {
            jac[[row, i * d + a]]
        })),
        WeightField::Factored { u, rank, p_out } => {
            let rank = *rank;
            let shapes_agree = *p_out == p && u.nrows() == jac.nrows() && u.ncols() == p * rank;
            if !shapes_agree {
                return None;
            }
            // projected = Uᵀ J  (rank x d)
            let projected = Array2::from_shape_fn((rank, d), |(weight_axis, a)| {
                (0..p).fold(0.0_f64, |acc, i| {
                    acc + u[[row, i * rank + weight_axis]] * jac[[row, i * d + a]]
                })
            });
            // out = U projected  (p x d)
            let weighted = Array2::from_shape_fn((p, d), |(i, a)| {
                (0..rank).fold(0.0_f64, |acc, weight_axis| {
                    acc + u[[row, i * rank + weight_axis]] * projected[[weight_axis, a]]
                })
            });
            Some(weighted)
        }
    }
}
/// Adds one atom's isometry-penalty contributions to the decoder (beta) block.
///
/// The gradient is always added: the penalty's Jacobian-space gradient is
/// pulled back through the basis Jacobian `jet` into `sys.gb`. When
/// `dense_beta_curvature` is set, a dense curvature block is also added to
/// `sys.hbb` by probing with every unit decoder coefficient.
///
/// The function returns early, leaving the curvature untouched, when the
/// gradient or Jacobian cache has an unexpected shape, when the weighted
/// Jacobian rows cannot be formed, or when the resolved weight `mu` is not
/// finite and positive. The last check mirrors the one in
/// `add_sae_isometry_metric_gn_blocks`, so a degenerate weight cannot write
/// NaN or negative curvature into `hbb`.
///
/// NOTE(review): `beta_off` comes from the dense `beta_offsets()`; whether this
/// routine is reached with a factored border is not visible here — confirm.
fn add_sae_isometry_beta_penalty(
    &self,
    sys: &mut ArrowSchurSystem,
    atom_idx: usize,
    coord: &LatentCoordValues,
    corrected: &Arc<IsometryPenalty>,
    rho_local: ArrayView1<'_, f64>,
    dense_beta_curvature: bool,
) {
    let atom = &self.atoms[atom_idx];
    let d = coord.latent_dim();
    let p = atom.decoder_coefficients.ncols();
    let m = atom.basis_size();
    let n_obs = coord.n_obs();
    let grad_jac = corrected.grad_jacobian(coord.as_flat().view(), rho_local);
    if grad_jac.dim() != (n_obs, p * d) {
        return;
    }
    let jet = &atom.basis_jacobian;
    let beta_off = self.beta_offsets()[atom_idx];

    // Gradient: gb[basis_col, i] += Σ_{n, a} grad_jac[n, (i, a)] * jet[n, basis_col, a]
    for basis_col in 0..m {
        for i in 0..p {
            let mut acc = 0.0;
            for n in 0..n_obs {
                for a in 0..d {
                    acc += grad_jac[[n, i * d + a]] * jet[[n, basis_col, a]];
                }
            }
            sys.gb[beta_off + basis_col * p + i] += acc;
        }
    }
    if !dense_beta_curvature {
        return;
    }

    let Some(jac) = corrected.jacobian_cache() else {
        return;
    };
    if jac.dim() != (n_obs, p * d) {
        return;
    }
    let mut weighted_jacobian_rows = Vec::with_capacity(n_obs);
    for n in 0..n_obs {
        let Some(wj) = Self::sae_isometry_weighted_jacobian_row(corrected, &jac, n, p, d)
        else {
            return;
        };
        weighted_jacobian_rows.push(wj);
    }
    let mu = resolve_learnable_weight(corrected.scalar_weight, rho_local[corrected.rho_index]);
    // Same acceptance rule as the coordinate-side curvature blocks.
    if !(mu.is_finite() && mu > 0.0) {
        return;
    }

    let mut metric_jvp = Array2::<f64>::zeros((d, d));
    let mut jac_hvp = Array2::<f64>::zeros((p, d));
    let mut beta_hvp = Array2::<f64>::zeros((m, p));
    // One column of the block per unit probe (probe_basis_col, probe_output).
    for probe_basis_col in 0..m {
        for probe_output in 0..p {
            beta_hvp.fill(0.0);
            for n in 0..n_obs {
                let wj = &weighted_jacobian_rows[n];
                // Metric change induced by the probe.
                metric_jvp.fill(0.0);
                for a in 0..d {
                    let probe_jet_a = jet[[n, probe_basis_col, a]];
                    for b in 0..d {
                        metric_jvp[[a, b]] = probe_jet_a * wj[[probe_output, b]]
                            + wj[[probe_output, a]] * jet[[n, probe_basis_col, b]];
                    }
                }
                // Pull the metric change back to Jacobian space.
                jac_hvp.fill(0.0);
                for i in 0..p {
                    for c in 0..d {
                        let mut acc = 0.0;
                        for b in 0..d {
                            acc += metric_jvp[[c, b]] * wj[[i, b]];
                        }
                        for a in 0..d {
                            acc += metric_jvp[[a, c]] * wj[[i, a]];
                        }
                        jac_hvp[[i, c]] = mu * acc;
                    }
                }
                // And from Jacobian space to decoder coefficients.
                for basis_row in 0..m {
                    for i in 0..p {
                        let mut acc = 0.0;
                        for a in 0..d {
                            acc += jac_hvp[[i, a]] * jet[[n, basis_row, a]];
                        }
                        beta_hvp[[basis_row, i]] += acc;
                    }
                }
            }
            let beta_col = beta_off + probe_basis_col * p + probe_output;
            for basis_row in 0..m {
                for i in 0..p {
                    sys.hbb[[beta_off + basis_row * p + i, beta_col]] +=
                        beta_hvp[[basis_row, i]];
                }
            }
        }
    }
}
/// Specialises the registered decoder-incoherence penalty to the current fit.
///
/// The clone carries the atoms' basis sizes, the output dimension, a target
/// slice spanning all decoder coefficients, and the mean gate co-activation
/// matrix `coactivation[j, k] = (1/n) Σ_row g[row, j] * g[row, k]`.
/// Returns `None` when fewer than two atoms exist.
fn live_decoder_incoherence_penalty(
    &self,
    base: &Arc<DecoderIncoherencePenalty>,
) -> Option<DecoderIncoherencePenalty> {
    let k_atoms = self.k_atoms();
    if k_atoms < 2 {
        return None;
    }
    let p = self.output_dim();
    let block_sizes: Vec<usize> = self.atoms.iter().map(|atom| atom.basis_size()).collect();
    let m_total: usize = block_sizes.iter().sum();

    let gates = self.assignment.assignments();
    let n = gates.nrows();
    // Zero rows yield an all-zero co-activation matrix instead of a division by zero.
    let inv_n = if n > 0 { 1.0 / n as f64 } else { 0.0 };
    let coactivation = Array2::<f64>::from_shape_fn((k_atoms, k_atoms), |(j, k)| {
        let joint = (0..n).fold(0.0_f64, |acc, row| acc + gates[[row, j]] * gates[[row, k]]);
        joint * inv_n
    });

    let mut per_fit: DecoderIncoherencePenalty = (**base).clone();
    per_fit.block_sizes = block_sizes;
    per_fit.p_out = p;
    per_fit.target = PsiSlice {
        range: 0..m_total * p,
        latent_dim: Some(m_total),
    };
    per_fit.coactivation = coactivation;
    Some(per_fit)
}
/// Produces one mechanism-sparsity penalty per atom, each retargeted to that
/// atom's decoder block.
///
/// Every tuple is `(penalty, start, end)` where `start..end` is the atom's
/// range in the flattened decoder vector and the penalty's `latent_dim` is the
/// atom's basis size.
fn live_mechanism_sparsity_penalties(
    &self,
    base: &Arc<MechanismSparsityPenalty>,
) -> Vec<(MechanismSparsityPenalty, usize, usize)> {
    let beta_offsets = self.beta_offsets();
    let p = self.output_dim();
    self.atoms
        .iter()
        .enumerate()
        .map(|(atom_idx, atom)| {
            let m = atom.basis_size();
            let start = beta_offsets[atom_idx];
            let end = start + m * p;
            let mut per_atom: MechanismSparsityPenalty = (**base).clone();
            per_atom.target = PsiSlice {
                range: start..end,
                latent_dim: Some(m),
            };
            (per_atom, start, end)
        })
        .collect()
}
/// Produces one nuclear-norm penalty per atom, each retargeted to that atom's
/// decoder block.
///
/// Every tuple is `(penalty, start, end)` where `start..end` is the atom's
/// range in the flattened decoder vector. The clone's `n_eff` is set to the
/// atom's basis size and its `latent_dim` to the output dimension.
fn live_nuclear_norm_penalties(
    &self,
    base: &Arc<NuclearNormPenalty>,
) -> Vec<(NuclearNormPenalty, usize, usize)> {
    let beta_offsets = self.beta_offsets();
    let p = self.output_dim();
    self.atoms
        .iter()
        .enumerate()
        .map(|(atom_idx, atom)| {
            let m = atom.basis_size();
            let start = beta_offsets[atom_idx];
            let end = start + m * p;
            let mut per_atom: NuclearNormPenalty = (**base).clone();
            per_atom.n_eff = m;
            per_atom.target = PsiSlice {
                range: start..end,
                latent_dim: Some(p),
            };
            (per_atom, start, end)
        })
        .collect()
}
/// Adds a decoder-targeted penalty's gradient, and optionally its dense
/// curvature, to the beta block of `sys`, scaled by `penalty_scale`.
///
/// Decoder-incoherence, mechanism-sparsity and nuclear-norm penalties are first
/// specialised to the live fit; every other kind is applied as registered.
/// Curvature is written only when `dense_beta_curvature` is set.
///
/// Returns `true` when a contribution was made. The only `false` outcomes are
/// a decoder-incoherence penalty with fewer than two atoms and per-atom
/// penalties when there are no atoms.
fn add_sae_beta_penalty(
    &self,
    sys: &mut ArrowSchurSystem,
    penalty: &AnalyticPenaltyKind,
    target_beta: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
    penalty_scale: f64,
    dense_beta_curvature: bool,
) -> bool {
    match penalty {
        AnalyticPenaltyKind::DecoderIncoherence(base) => {
            let Some(per_fit) = self.live_decoder_incoherence_penalty(base) else {
                return false;
            };
            let beta_dim = self.beta_dim();
            let grad = per_fit.grad_target(target_beta, rho_local);
            for j in 0..beta_dim {
                sys.gb[j] += penalty_scale * grad[j];
            }
            if dense_beta_curvature {
                // Column-by-column probing with unit vectors.
                let mut unit = Array1::<f64>::zeros(beta_dim);
                for col in 0..beta_dim {
                    unit.fill(0.0);
                    unit[col] = 1.0;
                    let hv = per_fit.psd_majorizer_hvp(target_beta, rho_local, unit.view());
                    for row in 0..beta_dim {
                        sys.hbb[[row, col]] += penalty_scale * hv[row];
                    }
                }
            }
            true
        }
        AnalyticPenaltyKind::MechanismSparsity(base) => {
            let mut any = false;
            for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
                let contributed = self.add_sae_mech_sparsity_atom(
                    sys,
                    &per_atom,
                    target_beta,
                    rho_local,
                    start,
                    end,
                    penalty_scale,
                    dense_beta_curvature,
                );
                any = any || contributed;
            }
            any
        }
        AnalyticPenaltyKind::NuclearNorm(base) => {
            let mut any = false;
            for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
                let contributed = self.add_sae_nuclear_norm_atom(
                    sys,
                    &per_atom,
                    target_beta,
                    rho_local,
                    start,
                    end,
                    penalty_scale,
                    dense_beta_curvature,
                );
                any = any || contributed;
            }
            any
        }
        _ => {
            let k = self.beta_dim();
            let grad = penalty.grad_target(target_beta, rho_local);
            for j in 0..k {
                sys.gb[j] += penalty_scale * grad[j];
            }
            if !dense_beta_curvature {
                return true;
            }
            match penalty.psd_majorizer_diag(target_beta, rho_local) {
                Some(diag) => {
                    for j in 0..k {
                        sys.hbb[[j, j]] += penalty_scale * diag[j];
                    }
                }
                None => {
                    // No diagonal available: probe the majorizer column by column.
                    let mut unit = Array1::<f64>::zeros(k);
                    for col in 0..k {
                        unit.fill(0.0);
                        unit[col] = 1.0;
                        let hv = penalty.psd_majorizer_hvp(target_beta, rho_local, unit.view());
                        for row in 0..k {
                            sys.hbb[[row, col]] += penalty_scale * hv[row];
                        }
                    }
                }
            }
            true
        }
    }
}
/// Adds one atom's mechanism-sparsity gradient, and optionally curvature, to
/// the `start..end` range of the beta block.
///
/// The penalty is evaluated on the full decoder vector; only the entries in
/// `start..end` are written. Curvature is probed with full-length unit vectors
/// and only the atom's own square block of `hbb` is updated. Always returns
/// `true`.
fn add_sae_mech_sparsity_atom(
    &self,
    sys: &mut ArrowSchurSystem,
    per_atom: &MechanismSparsityPenalty,
    target_beta: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
    start: usize,
    end: usize,
    penalty_scale: f64,
    dense_beta_curvature: bool,
) -> bool {
    let grad = per_atom.grad_target(target_beta, rho_local);
    for idx in start..end {
        sys.gb[idx] += penalty_scale * grad[idx];
    }

    if dense_beta_curvature {
        let mut unit = Array1::<f64>::zeros(self.beta_dim());
        for col in start..end {
            unit.fill(0.0);
            unit[col] = 1.0;
            let hv = per_atom.psd_majorizer_hvp(target_beta, rho_local, unit.view());
            for row in start..end {
                sys.hbb[[row, col]] += penalty_scale * hv[row];
            }
        }
    }
    true
}
/// Adds one atom's nuclear-norm gradient, and optionally curvature, to the
/// `start..end` range of the beta block.
///
/// Unlike the mechanism-sparsity path, the penalty is evaluated on the atom's
/// own slice of the decoder vector, so all penalty outputs are indexed locally
/// and shifted by `start` when written. Always returns `true`.
fn add_sae_nuclear_norm_atom(
    &self,
    sys: &mut ArrowSchurSystem,
    per_atom: &NuclearNormPenalty,
    target_beta: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
    start: usize,
    end: usize,
    penalty_scale: f64,
    dense_beta_curvature: bool,
) -> bool {
    let block = target_beta.slice(s![start..end]);
    let block_len = end - start;

    let grad = per_atom.grad_target(block, rho_local);
    for offset in 0..block_len {
        sys.gb[start + offset] += penalty_scale * grad[offset];
    }

    if dense_beta_curvature {
        let mut unit = Array1::<f64>::zeros(block_len);
        for col in 0..block_len {
            unit.fill(0.0);
            unit[col] = 1.0;
            let hv = per_atom.psd_majorizer_hvp(block, rho_local, unit.view());
            for row in 0..block_len {
                sys.hbb[[start + row, start + col]] += penalty_scale * hv[row];
            }
        }
    }
    true
}
/// Assembles the arrow system at the current state and solves it for one
/// Newton step.
///
/// Returns `(delta_t, delta_beta)`; the solver's diagnostics are discarded.
///
/// # Errors
///
/// An assembly failure is reported as `ArrowSchurError::SchurFactorFailed`
/// carrying the assembly message; solver errors are passed through unchanged.
pub fn solve_newton_step(
    &mut self,
    target: ArrayView2<'_, f64>,
    rho: &SaeManifoldRho,
    analytic_penalties: Option<&AnalyticPenaltyRegistry>,
    ridge_ext_coord: f64,
    ridge_beta: f64,
) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
    let sys = match self.assemble_arrow_schur(target, rho, analytic_penalties) {
        Ok(sys) => sys,
        Err(reason) => return Err(ArrowSchurError::SchurFactorFailed { reason }),
    };
    let (delta_t, delta_beta, _diag) =
        sys.solve_with_lm_escalation(ridge_ext_coord, ridge_beta)?;
    Ok((delta_t, delta_beta))
}
/// Applies a Newton step scaled by `step_size`.
///
/// Forwards to `apply_newton_step_impl` with its final flag set to `true`.
/// NOTE(review): judging by the sibling
/// `apply_newton_step_external_basis_refresh`, that flag presumably requests an
/// internal basis refresh after the update — confirm in the implementation.
///
/// # Errors
///
/// Returns whatever error `apply_newton_step_impl` reports.
pub fn apply_newton_step(
&mut self,
delta_ext_coord: ArrayView1<'_, f64>,
delta_beta: ArrayView1<'_, f64>,
step_size: f64,
) -> Result<(), String> {
self.apply_newton_step_impl(delta_ext_coord, delta_beta, step_size, true)
}
/// Applies a Newton step scaled by `step_size`, for callers that handle the
/// basis refresh themselves.
///
/// Forwards to `apply_newton_step_impl` with its final flag set to `false`.
/// NOTE(review): the name suggests the caller is expected to refresh the basis
/// afterwards — confirm in the implementation.
///
/// # Errors
///
/// Returns whatever error `apply_newton_step_impl` reports.
pub fn apply_newton_step_external_basis_refresh(
&mut self,
delta_ext_coord: ArrayView1<'_, f64>,
delta_beta: ArrayView1<'_, f64>,
step_size: f64,
) -> Result<(), String> {
self.apply_newton_step_impl(delta_ext_coord, delta_beta, step_size, false)
}
/// Captures a deep copy of everything `restore_mutable_state` puts back: per
/// atom the basis values, basis Jacobian, decoder coefficients and smoothness
/// penalty, plus the assignment logits, the latent coordinates and the cached
/// row layout.
fn snapshot_mutable_state(&self) -> SaeManifoldMutableState {
    SaeManifoldMutableState {
        atoms: self
            .atoms
            .iter()
            .map(|atom| {
                let basis_values = atom.basis_values.clone();
                let basis_jacobian = atom.basis_jacobian.clone();
                let decoder = atom.decoder_coefficients.clone();
                let smooth_penalty = atom.smooth_penalty.clone();
                (basis_values, basis_jacobian, decoder, smooth_penalty)
            })
            .collect(),
        logits: self.assignment.logits.clone(),
        coords: self.assignment.coords.clone(),
        last_row_layout: self.last_row_layout.clone(),
    }
}
/// Writes a snapshot taken by `snapshot_mutable_state` back into the term.
///
/// Atoms are paired positionally with the saved tuples; array contents are
/// copied in place so existing allocations are reused.
fn restore_mutable_state(&mut self, snapshot: &SaeManifoldMutableState) {
    let saved_atoms = snapshot.atoms.iter();
    for (atom, saved) in self.atoms.iter_mut().zip(saved_atoms) {
        let (phi, jacobian, decoder, smooth) = saved;
        atom.basis_values.assign(phi);
        atom.basis_jacobian.assign(jacobian);
        atom.decoder_coefficients.assign(decoder);
        atom.smooth_penalty.assign(smooth);
    }
    let assignment = &mut self.assignment;
    assignment.logits.assign(&snapshot.logits);
    assignment.coords.clone_from(&snapshot.coords);
    self.last_row_layout.clone_from(&snapshot.last_row_layout);
}
/// Re-evaluates every atom's cached basis at the coordinates currently stored
/// in the assignment, stopping at the first atom whose refresh fails.
fn refresh_basis_from_current_coords(&mut self) -> Result<(), String> {
    let atom_count = self.k_atoms();
    for idx in 0..atom_count {
        let current = self.assignment.coords[idx].as_matrix();
        let atom = &mut self.atoms[idx];
        atom.refresh_basis(current.view())?;
    }
    Ok(())
}
/// Runs the affine gauge canonicalization on every atom whose basis kind is
/// `EuclideanPatch` or `Duchon`; all other atoms are left untouched.
fn canonicalize_affine_gauge_after_accept(&mut self) -> Result<(), String> {
    for atom_idx in 0..self.k_atoms() {
        let has_affine_gauge = matches!(
            self.atoms[atom_idx].basis_kind,
            SaeAtomBasisKind::EuclideanPatch | SaeAtomBasisKind::Duchon
        );
        if has_affine_gauge {
            self.canonicalize_atom_affine_gauge(atom_idx)?;
        }
    }
    Ok(())
}
/// Re-centres and re-scales one atom's latent coordinates to weighted zero
/// mean and unit RMS per axis, transporting the decoder so the fitted values
/// stay the same.
///
/// The method is a no-op (returns `Ok(())` without touching state) when:
/// there are no rows or latent axes, the atom has no basis evaluator, the
/// gauge weights do not sum to a finite positive value, the coordinates are
/// already canonical to within `1e-12`, the evaluator declines to produce an
/// affine-transformed evaluator, or the transported decoder fails to
/// reproduce the old fit to a relative tolerance of `1e-8`.
///
/// # Errors
/// Propagates errors from the weight computation, the evaluator calls and
/// `solve_basis_transport`, and reports a shape mismatch between the old and
/// the transformed basis.
fn canonicalize_atom_affine_gauge(&mut self, atom_idx: usize) -> Result<(), String> {
let n = self.n_obs();
let d = self.assignment.coords[atom_idx].latent_dim();
if n == 0 || d == 0 {
return Ok(());
}
let Some(evaluator) = self.atoms[atom_idx].basis_evaluator.as_ref() else {
return Ok(());
};
let coords = self.assignment.coords[atom_idx].as_matrix();
let weights = self.atom_affine_gauge_weights(atom_idx)?;
let weight_sum: f64 = weights.iter().sum();
if !(weight_sum.is_finite() && weight_sum > 0.0) {
return Ok(());
}
// Weighted mean of the coordinates along each axis.
let mut shift = vec![0.0_f64; d];
for row in 0..n {
let w = weights[row];
for axis in 0..d {
shift[axis] += w * coords[[row, axis]];
}
}
for value in &mut shift {
*value /= weight_sum;
}
// Weighted RMS spread about the mean; an axis whose RMS is non-finite or
// not above 1e-12 keeps scale 1 so the division below stays well defined.
let mut scale = vec![1.0_f64; d];
let mut changed = false;
for axis in 0..d {
let mut var = 0.0_f64;
for row in 0..n {
let centered = coords[[row, axis]] - shift[axis];
var += weights[row] * centered * centered;
}
let rms = (var / weight_sum).sqrt();
if rms.is_finite() && rms > 1.0e-12 {
scale[axis] = rms;
}
if shift[axis].abs() > 1.0e-12 || (scale[axis] - 1.0).abs() > 1.0e-12 {
changed = true;
}
}
if !changed {
return Ok(());
}
// `Ok(None)` from the evaluator means it cannot represent this transform;
// the current gauge is kept in that case.
let Some(new_evaluator) = evaluator.affine_transformed_evaluator(
&shift,
&scale,
self.atoms[atom_idx].basis_size(),
)?
else {
return Ok(());
};
// Canonical coordinates: (t - shift) / scale, axis by axis.
let mut new_coords = coords.clone();
for row in 0..n {
for axis in 0..d {
new_coords[[row, axis]] = (coords[[row, axis]] - shift[axis]) / scale[axis];
}
}
// Evaluate the transformed basis at the atom's current homotopy eta; the
// plain `evaluate` path is used only when eta is exactly 1.
let (new_phi, new_jet) = if self.atoms[atom_idx].homotopy_eta == 1.0 {
new_evaluator.evaluate(new_coords.view())?
} else {
let evaluated = new_evaluator
.evaluate_phi_eta(new_coords.view(), self.atoms[atom_idx].homotopy_eta)?;
(evaluated.phi, evaluated.jet)
};
let old_phi = self.atoms[atom_idx].basis_values.clone();
if new_phi.dim() != old_phi.dim() {
return Err(format!(
"SaeManifoldTerm::canonicalize_atom_affine_gauge: transformed basis shape {:?} != {:?}",
new_phi.dim(),
old_phi.dim()
));
}
// Map the decoder into the new basis and verify that old and new fitted
// values agree before committing anything.
// NOTE(review): `solve_basis_transport` is defined elsewhere; presumably it
// solves new_phi * T ~= old_phi for T -- confirm against its definition.
let transport = solve_basis_transport(new_phi.view(), old_phi.view())?;
let old_decoder = self.atoms[atom_idx].decoder_coefficients.clone();
let new_decoder = fast_ab(&transport, &old_decoder);
let old_fit = fast_ab(&old_phi, &old_decoder);
let new_fit = fast_ab(&new_phi, &new_decoder);
// Tolerance scale: largest absolute fitted entry, floored at 1.
let fit_scale = old_fit
.iter()
.chain(new_fit.iter())
.fold(1.0_f64, |acc, &v| acc.max(v.abs()));
let max_abs = old_fit
.iter()
.zip(new_fit.iter())
.fold(0.0_f64, |acc, (&a, &b)| acc.max((a - b).abs()));
// A transport that does not reproduce the fit is abandoned silently: the
// state has not been modified up to this point.
if max_abs > 1.0e-8 * fit_scale {
return Ok(());
}
// Commit: coordinates, cached basis, decoder and both evaluator handles.
let flat = Array1::from_iter(new_coords.iter().copied());
self.assignment.coords[atom_idx].set_flat(flat.view());
let atom = &mut self.atoms[atom_idx];
atom.basis_values = new_phi;
atom.basis_jacobian = new_jet;
atom.decoder_coefficients = new_decoder;
let base: Arc<dyn SaeBasisEvaluator> = new_evaluator.clone();
atom.basis_evaluator = Some(base);
atom.basis_second_jet = Some(new_evaluator);
atom.refresh_intrinsic_smooth_penalty();
Ok(())
}
/// Per-row weights used to centre and scale one atom's coordinates.
///
/// Each weight is the atom's assignment for that row clamped at zero,
/// multiplied by the clamped row loss weight when row weights are set. Any
/// non-finite product is replaced by zero.
fn atom_affine_gauge_weights(&self, atom_idx: usize) -> Result<Array1<f64>, String> {
    let n_rows = self.n_obs();
    let row_weights = self.row_loss_weights.as_ref();
    let mut out = Array1::<f64>::zeros(n_rows);
    for (row, slot) in out.iter_mut().enumerate() {
        let assignments = self.assignment.try_assignments_row(row)?;
        let mass = assignments[atom_idx].max(0.0);
        let weight = match row_weights {
            Some(per_row) => mass * per_row[row].max(0.0),
            None => mass,
        };
        // Slots start at zero, so only finite weights need to be written.
        if weight.is_finite() {
            *slot = weight;
        }
    }
    Ok(out)
}
/// Squared norm of a Newton step after projecting out the gauge directions
/// returned by `dense_step_gauge_vectors`.
///
/// The gauge vectors are orthonormalized one at a time (Gram-Schmidt against
/// the ones already accepted); each accepted direction is then removed from
/// the stacked `(delta_ext_coord, delta_beta)` vector. Directions whose
/// remaining squared norm is non-finite or at most `1e-24` are skipped.
///
/// Returns `raw_step_norm_sq` unchanged when the input lengths do not match
/// the dense layout or the projected norm is non-finite; otherwise the
/// projected value clamped to `[0, raw_step_norm_sq]`.
fn quotient_newton_step_norm_sq(
    &self,
    delta_ext_coord: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
    raw_step_norm_sq: f64,
) -> Result<f64, String> {
    let expected_coord_len = self.n_obs() * self.assignment.row_block_dim();
    let expected_beta_len = self.beta_dim();
    let lengths_match =
        delta_ext_coord.len() == expected_coord_len && delta_beta.len() == expected_beta_len;
    if !lengths_match {
        return Ok(raw_step_norm_sq);
    }
    // Stack the step as [coordinate part | beta part].
    let mut remainder =
        Array1::from_iter(delta_ext_coord.iter().chain(delta_beta.iter()).copied());
    let mut accepted: Vec<Array1<f64>> = Vec::new();
    for mut direction in self.dense_step_gauge_vectors()? {
        for prior in &accepted {
            let overlap = direction.dot(prior);
            for (entry, &basis_entry) in direction.iter_mut().zip(prior.iter()) {
                *entry -= overlap * basis_entry;
            }
        }
        let norm_sq = direction.iter().map(|v| v * v).sum::<f64>();
        let usable = norm_sq > 1.0e-24 && norm_sq.is_finite();
        if !usable {
            continue;
        }
        let inv_norm = norm_sq.sqrt().recip();
        direction.iter_mut().for_each(|v| *v *= inv_norm);
        let overlap = remainder.dot(&direction);
        for (entry, &gauge_entry) in remainder.iter_mut().zip(direction.iter()) {
            *entry -= overlap * gauge_entry;
        }
        accepted.push(direction);
    }
    let quotient = remainder.iter().map(|v| v * v).sum::<f64>();
    if quotient.is_finite() {
        Ok(quotient.max(0.0).min(raw_step_norm_sq))
    } else {
        Ok(raw_step_norm_sq)
    }
}
/// Collects the dense-layout gauge directions of every atom.
///
/// For each atom a set of coordinate fields is turned into a gauge vector by
/// `dense_step_gauge_vector_from_field`:
/// * `EuclideanPatch`, and `Duchon` with a one-dimensional latent: one
///   constant (translation) field per axis, followed by one field per axis
///   equal to the current coordinate along that axis (dilation);
/// * `Periodic` / `Torus`: the per-axis constant fields only;
/// * every other basis kind contributes nothing.
///
/// Fields for which the helper returns `None` are skipped. The output is
/// ordered by atom, translations before dilations, axes ascending.
///
/// When the output dimension is zero the result is empty. That case is now
/// decided before any per-atom work; previously the fields were assembled and
/// evaluated first and the result discarded at the end, which also meant an
/// error raised during that wasted work could surface.
fn dense_step_gauge_vectors(&self) -> Result<Vec<Array1<f64>>, String> {
    if self.output_dim() == 0 {
        return Ok(Vec::new());
    }
    let n = self.n_obs();
    let q = self.assignment.row_block_dim();
    let coord_offsets = self.assignment.coord_offsets();
    let beta_offsets = self.beta_offsets();
    let total_len = n * q + self.beta_dim();
    let mut out = Vec::new();
    for atom_idx in 0..self.k_atoms() {
        let d = self.assignment.coords[atom_idx].latent_dim();
        let (translations, dilations) = match self.atoms[atom_idx].basis_kind {
            SaeAtomBasisKind::EuclideanPatch => (true, true),
            SaeAtomBasisKind::Duchon if d == 1 => (true, true),
            SaeAtomBasisKind::Periodic | SaeAtomBasisKind::Torus => (true, false),
            _ => (false, false),
        };
        if translations {
            for axis in 0..d {
                let mut field = Array2::<f64>::zeros((n, d));
                field.column_mut(axis).fill(1.0);
                // `extend` with an `Option` appends the vector only when present.
                out.extend(self.dense_step_gauge_vector_from_field(
                    atom_idx,
                    field.view(),
                    &coord_offsets,
                    &beta_offsets,
                    total_len,
                )?);
            }
        }
        if dilations {
            // The coordinate matrix is only needed for the dilation fields.
            let coords = self.assignment.coords[atom_idx].as_matrix();
            for axis in 0..d {
                let mut field = Array2::<f64>::zeros((n, d));
                for row in 0..n {
                    field[[row, axis]] = coords[[row, axis]];
                }
                out.extend(self.dense_step_gauge_vector_from_field(
                    atom_idx,
                    field.view(),
                    &coord_offsets,
                    &beta_offsets,
                    total_len,
                )?);
            }
        }
    }
    Ok(out)
}
/// Turns one coordinate field of an atom into a full-length step vector.
///
/// `field` is an `(n, d)` array of per-row coordinate displacements. The
/// method builds
/// * `design[row, col] = a * basis_values[row, col]`, and
/// * `motion[row, :] = sum over axis, col of
///   a * field[row, axis] * basis_jacobian[row, col, axis] * decoder[col, :]`,
/// where `a` is the row's assignment to the atom. It then asks
/// `solve_design_least_squares` for the decoder change paired with
/// `-motion`. The returned vector holds `field` in the atom's coordinate
/// slots of every row block and that decoder change in the atom's beta slots.
///
/// Returns `Ok(None)` when the squared norm of `motion` is non-finite or not
/// above `f64::MIN_POSITIVE`.
///
/// # Errors
/// A field of the wrong shape, or any error from the assignment lookup or the
/// least-squares solve.
fn dense_step_gauge_vector_from_field(
    &self,
    atom_idx: usize,
    field: ArrayView2<'_, f64>,
    coord_offsets: &[usize],
    beta_offsets: &[usize],
    total_len: usize,
) -> Result<Option<Array1<f64>>, String> {
    let n = self.n_obs();
    let q = self.assignment.row_block_dim();
    let p = self.output_dim();
    let atom = &self.atoms[atom_idx];
    let m = atom.basis_size();
    let d = self.assignment.coords[atom_idx].latent_dim();
    let field_shape = field.dim();
    if field_shape != (n, d) {
        return Err(format!(
            "dense_step_gauge_vector_from_field: field shape {:?} != ({n}, {d})",
            field_shape
        ));
    }
    let mut design = Array2::<f64>::zeros((n, m));
    let mut motion = Array2::<f64>::zeros((n, p));
    for row in 0..n {
        let assignments = self.assignment.try_assignments_row(row)?;
        let mass = assignments[atom_idx];
        // Rows the atom does not touch leave both matrices at zero.
        if mass == 0.0 {
            continue;
        }
        for col in 0..m {
            design[[row, col]] = mass * atom.basis_values[[row, col]];
        }
        for axis in 0..d {
            let displacement = field[[row, axis]];
            if displacement != 0.0 {
                for col in 0..m {
                    let weight = mass * displacement * atom.basis_jacobian[[row, col, axis]];
                    if weight != 0.0 {
                        for out_col in 0..p {
                            motion[[row, out_col]] +=
                                weight * atom.decoder_coefficients[[col, out_col]];
                        }
                    }
                }
            }
        }
    }
    let motion_norm_sq = motion.iter().map(|v| v * v).sum::<f64>();
    if motion_norm_sq <= f64::MIN_POSITIVE || !motion_norm_sq.is_finite() {
        return Ok(None);
    }
    motion.mapv_inplace(|v| -v);
    let delta_b = solve_design_least_squares(design.view(), motion.view())?;
    let mut gauge = Array1::<f64>::zeros(total_len);
    let coord_offset = coord_offsets[atom_idx];
    for row in 0..n {
        let slot_base = row * q + coord_offset;
        for axis in 0..d {
            gauge[slot_base + axis] = field[[row, axis]];
        }
    }
    // The beta section starts after the n * q coordinate entries.
    let beta_base = n * q + beta_offsets[atom_idx];
    for col in 0..m {
        let col_base = beta_base + col * p;
        for out_col in 0..p {
            gauge[col_base + out_col] = delta_b[[col, out_col]];
        }
    }
    Ok(Some(gauge))
}
/// The collapse events recorded on this term, oldest first.
pub fn collapse_events(&self) -> &[CollapseEvent] {
    self.collapse_events.as_slice()
}
/// Appends `event` to the end of the term's collapse-event log.
pub fn record_collapse_event(&mut self, event: CollapseEvent) {
    let log = &mut self.collapse_events;
    log.push(event);
}
/// Sets the homotopy parameter on every atom.
///
/// Only the stored `homotopy_eta` fields are written; cached bases are not
/// re-evaluated here.
///
/// # Errors
/// Rejects a value that is non-finite or outside `[0, 1]`; no atom is
/// modified in that case.
pub fn set_homotopy_eta(&mut self, eta: f64) -> Result<(), String> {
    let valid = eta.is_finite() && (0.0..=1.0).contains(&eta);
    if valid {
        self.atoms
            .iter_mut()
            .for_each(|atom| atom.homotopy_eta = eta);
        return Ok(());
    }
    Err(format!(
        "SaeManifoldTerm::set_homotopy_eta: η must be finite in [0, 1]; got {eta}"
    ))
}
/// The stored curvature-walk report, or `None` if none has been set.
pub fn curvature_walk_report(&self) -> Option<&CurvatureWalkReport> {
    match &self.curvature_walk_report {
        Some(report) => Some(report),
        None => None,
    }
}
/// Stores `report`, replacing any previously stored curvature-walk report.
pub fn set_curvature_walk_report(&mut self, report: CurvatureWalkReport) {
    let stored = Some(report);
    self.curvature_walk_report = stored;
}
/// Returns `fitted - target` element-wise, where `fitted` comes from
/// `try_fitted`.
///
/// # Errors
/// Propagates `try_fitted` failures and rejects a `target` whose shape
/// differs from the fitted values.
fn reconstruction_residual(&self, target: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    let fitted = self.try_fitted()?;
    let (fitted_shape, target_shape) = (fitted.dim(), target.dim());
    if fitted_shape == target_shape {
        return Ok(&fitted - &target);
    }
    Err(format!(
        "SaeManifoldTerm::reconstruction_residual: fitted {:?} != target {:?}",
        fitted_shape, target_shape
    ))
}
/// Builds, per atom, an `(n, basis_size)` matrix that is zero except in the
/// columns the evaluator's `phi_eta_split` lists as curved; those columns are
/// copied from the basis returned by `evaluate` at the current coordinates.
///
/// Atoms without an evaluator, or with no curved columns, get an all-zero
/// matrix.
/// NOTE(review): the name suggests this is d(phi)/d(eta); that interpretation
/// relies on how the evaluator blends curved columns with eta -- confirm there.
fn curvature_basis_eta_derivatives(&self) -> Result<Vec<Array2<f64>>, String> {
    let n = self.n_obs();
    let mut derivatives = Vec::with_capacity(self.k_atoms());
    for (atom_idx, atom) in self.atoms.iter().enumerate() {
        let m = atom.basis_size();
        let mut dphi = Array2::<f64>::zeros((n, m));
        match atom.basis_evaluator.as_ref() {
            None => {}
            Some(evaluator) => {
                let split = evaluator.phi_eta_split(m)?;
                if !split.curved_cols.is_empty() {
                    let coords = self.assignment.coords[atom_idx].as_matrix();
                    let (phi_raw, _jet) = evaluator.evaluate(coords.view())?;
                    for &col in &split.curved_cols {
                        for row in 0..n {
                            dphi[[row, col]] = phi_raw[[row, col]];
                        }
                    }
                }
            }
        }
        derivatives.push(dphi);
    }
    Ok(derivatives)
}
/// Accumulates, over all rows and atoms, the flat beta-layout vector
///
/// `out[off_k + mu * p + c] += a_k * (dphi[row, mu] * residual[row, c]
///                                    + phi[row, mu] * dfitted[row, c])`
///
/// where `dphi` comes from `curvature_basis_eta_derivatives`, `residual` from
/// `reconstruction_residual`, and
/// `dfitted[row, c] = sum_k sum_mu a_k * dphi_k[row, mu] * decoder_k[mu, c]`.
///
/// The per-row `dfitted` values only feed that same row's accumulation, so
/// both are computed in a single pass over the rows.
fn curvature_beta_gradient_eta_derivative(
    &self,
    target: ArrayView2<'_, f64>,
) -> Result<Array1<f64>, String> {
    let n = self.n_obs();
    let p = self.output_dim();
    let offsets = self.beta_offsets();
    let residual = self.reconstruction_residual(target)?;
    let dphi_deta = self.curvature_basis_eta_derivatives()?;
    let mut gradient = Array1::<f64>::zeros(self.beta_dim());
    let mut dfitted_row = vec![0.0_f64; p];
    for row in 0..n {
        let a = self.assignment.try_assignments_row(row)?;
        // First: this row's fitted-value derivative, summed over atoms.
        dfitted_row.fill(0.0);
        for (atom_idx, atom) in self.atoms.iter().enumerate() {
            let a_k = a[atom_idx];
            if a_k == 0.0 {
                continue;
            }
            let atom_dphi = &dphi_deta[atom_idx];
            for mu in 0..atom.basis_size() {
                let dphi = atom_dphi[[row, mu]];
                if dphi == 0.0 {
                    continue;
                }
                let w = a_k * dphi;
                for (c, slot) in dfitted_row.iter_mut().enumerate() {
                    *slot += w * atom.decoder_coefficients[[mu, c]];
                }
            }
        }
        // Second: scatter the row's contribution into each atom's beta block.
        for (atom_idx, atom) in self.atoms.iter().enumerate() {
            let a_k = a[atom_idx];
            if a_k == 0.0 {
                continue;
            }
            let off = offsets[atom_idx];
            let atom_dphi = &dphi_deta[atom_idx];
            for mu in 0..atom.basis_size() {
                let dphi = atom_dphi[[row, mu]];
                let phi = atom.basis_values[[row, mu]];
                let base = off + mu * p;
                for c in 0..p {
                    gradient[base + c] +=
                        a_k * (dphi * residual[[row, c]] + phi * dfitted_row[c]);
                }
            }
        }
    }
    Ok(gradient)
}
/// Detects atoms whose largest assignment over all rows is below
/// `SAE_ATOM_ACTIVE_MASS_FLOOR` and reacts to each one.
///
/// While an atom has fewer than `SAE_ATOM_COLLAPSE_RESEED_BUDGET` recorded
/// `Reseeded` events its logits are reseeded and a `Reseeded` event is logged.
/// Once the budget is spent a single `Terminal` event is logged for the atom
/// and later calls record nothing further for it.
fn enforce_active_mass_guard(&mut self, iteration: usize) -> Result<(), String> {
    let n = self.n_obs();
    let k = self.k_atoms();
    if n == 0 || k == 0 {
        return Ok(());
    }
    let mut peak_mass = vec![0.0_f64; k];
    for row in 0..n {
        let a = self
            .assignment
            .try_assignments_row(row)
            .map_err(|e| format!("SaeManifoldTerm::enforce_active_mass_guard: {e}"))?;
        for (atom, peak) in peak_mass.iter_mut().enumerate() {
            if a[atom] > *peak {
                *peak = a[atom];
            }
        }
    }
    for (atom, &peak) in peak_mass.iter().enumerate() {
        if peak >= SAE_ATOM_ACTIVE_MASS_FLOOR {
            continue;
        }
        let prior_reseeds = self
            .collapse_events
            .iter()
            .filter(|e| e.atom == atom && e.action == CollapseAction::Reseeded)
            .count();
        let action = if prior_reseeds < SAE_ATOM_COLLAPSE_RESEED_BUDGET {
            self.reseed_collapsed_atom_logits(atom);
            CollapseAction::Reseeded
        } else {
            let already_terminal = self
                .collapse_events
                .iter()
                .any(|e| e.atom == atom && e.action == CollapseAction::Terminal);
            if already_terminal {
                continue;
            }
            CollapseAction::Terminal
        };
        self.collapse_events.push(CollapseEvent {
            iteration,
            atom,
            max_active_mass: peak,
            floor: SAE_ATOM_ACTIVE_MASS_FLOOR,
            action,
        });
    }
    Ok(())
}
/// Overwrites one atom's logit in every row with a mode-specific seed value.
///
/// * `Softmax`: the row's current maximum logit (zero if that maximum is not
///   finite), followed by `canonicalize_softmax_logits` on the whole matrix.
/// * `IBPMap`: zero.
/// * `JumpReLU`: `threshold + temperature`.
fn reseed_collapsed_atom_logits(&mut self, atom: usize) {
    let n = self.n_obs();
    match self.assignment.mode {
        AssignmentMode::Softmax { .. } => {
            for row in 0..n {
                let mut row_max = f64::NEG_INFINITY;
                for &logit in self.assignment.logits.row(row).iter() {
                    row_max = f64::max(row_max, logit);
                }
                let seed = if row_max.is_finite() { row_max } else { 0.0 };
                self.assignment.logits[[row, atom]] = seed;
            }
            canonicalize_softmax_logits(&mut self.assignment.logits);
        }
        AssignmentMode::IBPMap { .. } => {
            for row in 0..n {
                self.assignment.logits[[row, atom]] = 0.0;
            }
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let seed = threshold + temperature;
            for row in 0..n {
                self.assignment.logits[[row, atom]] = seed;
            }
        }
    }
}
fn apply_newton_step_impl(
&mut self,
delta_ext_coord: ArrayView1<'_, f64>,
delta_beta: ArrayView1<'_, f64>,
step_size: f64,
refresh_basis: bool,
) -> Result<(), String> {
if !(step_size.is_finite() && step_size > 0.0) {
return Err(format!(
"SaeManifoldTerm::apply_newton_step: step_size must be finite and positive; got {step_size}"
));
}
let n = self.n_obs();
let q = self.assignment.row_block_dim();
let k_atoms = self.k_atoms();
let assignment_dim = self.assignment.assignment_coord_dim();
let expected_delta_len = if self.last_frames_active {
self.factored_border_dim()
} else {
self.beta_dim()
};
if delta_beta.len() != expected_delta_len {
return Err(format!(
"SaeManifoldTerm::apply_newton_step: delta_beta length {} != expected {}",
delta_beta.len(),
expected_delta_len
));
}
if let Some(ref layout) = self.last_row_layout.clone() {
let total_len: usize = (0..n).map(|row| layout.row_q_active(row)).sum();
if delta_ext_coord.len() != total_len {
return Err(format!(
"SaeManifoldTerm::apply_newton_step: compact delta_ext_coord length {} != expected {}",
delta_ext_coord.len(),
total_len
));
}
let mut full_delta = vec![0.0_f64; n * q];
let mut compact_off = 0usize;
for row in 0..n {
let q_active = layout.row_q_active(row);
let compact_row: Vec<f64> = delta_ext_coord
.slice(ndarray::s![compact_off..compact_off + q_active])
.iter()
.copied()
.collect();
layout.expand_row(row, &compact_row, &mut full_delta[row * q..(row + 1) * q]);
compact_off += q_active;
}
let logit_step_cap =
SAE_ASSIGNMENT_LOGIT_STEP_CAP_TAUS * self.assignment.mode.temperature();
for row in 0..n {
let row_base = row * q;
for atom_idx in 0..assignment_dim {
self.assignment.logits[[row, atom_idx]] += (step_size
* full_delta[row_base + atom_idx])
.clamp(-logit_step_cap, logit_step_cap);
}
}
let coord_offsets = self.assignment.coord_offsets();
for atom_idx in 0..k_atoms {
let d = self.assignment.coords[atom_idx].latent_dim();
let mut delta_coord = Array1::<f64>::zeros(n * d);
for row in 0..n {
let row_base = row * q + coord_offsets[atom_idx];
for axis in 0..d {
delta_coord[row * d + axis] = step_size * full_delta[row_base + axis];
}
}
self.assignment.coords[atom_idx].retract_flat_delta(delta_coord.view());
if refresh_basis {
let coords = self.assignment.coords[atom_idx].as_matrix();
self.atoms[atom_idx].refresh_basis(coords.view())?;
}
}
} else {
if delta_ext_coord.len() != n * q {
return Err(format!(
"SaeManifoldTerm::apply_newton_step: delta_ext_coord length {} != expected {}",
delta_ext_coord.len(),
n * q
));
}
let coord_offsets = self.assignment.coord_offsets();
let logit_step_cap =
SAE_ASSIGNMENT_LOGIT_STEP_CAP_TAUS * self.assignment.mode.temperature();
for row in 0..n {
let row_base = row * q;
for atom_idx in 0..assignment_dim {
self.assignment.logits[[row, atom_idx]] += (step_size
* delta_ext_coord[row_base + atom_idx])
.clamp(-logit_step_cap, logit_step_cap);
}
}
for atom_idx in 0..k_atoms {
let d = self.assignment.coords[atom_idx].latent_dim();
let mut delta_coord = Array1::<f64>::zeros(n * d);
for row in 0..n {
let row_base = row * q + coord_offsets[atom_idx];
for axis in 0..d {
delta_coord[row * d + axis] = step_size * delta_ext_coord[row_base + axis];
}
}
self.assignment.coords[atom_idx].retract_flat_delta(delta_coord.view());
if refresh_basis {
let coords = self.assignment.coords[atom_idx].as_matrix();
self.atoms[atom_idx].refresh_basis(coords.view())?;
}
}
}
if matches!(self.assignment.mode, AssignmentMode::Softmax { .. }) {
canonicalize_softmax_logits(&mut self.assignment.logits);
}
let mut beta = self.flatten_beta();
if self.last_frames_active {
let delta_b = FrameProjection::new(self).lift_border_vec(delta_beta);
for idx in 0..beta.len() {
beta[idx] += step_size * delta_b[idx];
}
} else {
for idx in 0..beta.len() {
beta[idx] += step_size * delta_beta[idx];
}
}
self.set_flat_beta(beta.view())
}
/// Solves one row's ridged system for the step `delta`, via
/// `sae_cholesky_solve_neg_gradient(H + ridge * I, g)`.
///
/// The ridge starts at `max(base_ridge, SAE_MANIFOLD_ROW_RIDGE_FLOOR)` and is
/// multiplied by `SAE_MANIFOLD_ROW_RIDGE_GROWTH` after every failed solve, for
/// at most `SAE_MANIFOLD_ROW_RIDGE_MAX_ATTEMPTS` attempts.
///
/// # Errors
/// Mismatched shapes between `h` and `g`, or no successful solve within the
/// attempt budget (the last solver error is included in the message).
fn solve_fixed_decoder_row_step(
    h: ArrayView2<'_, f64>,
    g: ArrayView1<'_, f64>,
    base_ridge: f64,
) -> Result<Array1<f64>, String> {
    let d = h.nrows();
    let shapes_agree = h.ncols() == d && g.len() == d;
    if !shapes_agree {
        return Err(format!(
            "SaeManifoldTerm::solve_fixed_decoder_row_step: shape mismatch H={:?}, g={}",
            h.dim(),
            g.len()
        ));
    }
    if d == 0 {
        return Ok(Array1::<f64>::zeros(0));
    }
    let mut ridge = base_ridge.max(SAE_MANIFOLD_ROW_RIDGE_FLOOR);
    let mut last_err = String::new();
    for _attempt in 0..SAE_MANIFOLD_ROW_RIDGE_MAX_ATTEMPTS {
        // Fresh copy each attempt so ridges do not accumulate on the diagonal.
        let mut ridged = h.to_owned();
        ridged.diag_mut().iter_mut().for_each(|v| *v += ridge);
        match sae_cholesky_solve_neg_gradient(ridged.view(), g) {
            Ok(delta) => return Ok(delta),
            Err(err) => {
                last_err = err;
                ridge *= SAE_MANIFOLD_ROW_RIDGE_GROWTH;
            }
        }
    }
    Err(format!(
        "SaeManifoldTerm::solve_fixed_decoder_row_step: row Hessian did not factor after LM escalation; last error: {last_err}"
    ))
}
/// Solves every row block of `sys` independently (decoder held fixed) and
/// concatenates the per-row steps into one vector laid out by
/// `sys.row_offsets`.
///
/// # Errors
/// Any row solve failure, or a row step whose length differs from the span
/// `row_offsets[i + 1] - row_offsets[i]`.
fn fixed_decoder_step_from_rows(
    sys: &ArrowSchurSystem,
    ridge_ext_coord: f64,
) -> Result<Array1<f64>, String> {
    let total = sys.row_offsets[sys.rows.len()];
    let mut delta = Array1::<f64>::zeros(total);
    for (row_idx, row) in sys.rows.iter().enumerate() {
        let row_delta =
            Self::solve_fixed_decoder_row_step(row.htt.view(), row.gt.view(), ridge_ext_coord)?;
        let (start, end) = (sys.row_offsets[row_idx], sys.row_offsets[row_idx + 1]);
        let span = end - start;
        if row_delta.len() != span {
            return Err(format!(
                "SaeManifoldTerm::fixed_decoder_step_from_rows: row {row_idx} delta len {} != row span {}",
                row_delta.len(),
                span
            ));
        }
        delta.slice_mut(s![start..end]).assign(&row_delta);
    }
    Ok(delta)
}
/// Returns a permutation of `0..n` giving the order in which rows are visited.
///
/// Without a row metric, or when the diagnostic metric cannot be built, this
/// is the identity order. Otherwise rows are taken in the order produced by
/// `RowMeasure::enrichment_order`, keeping only the first occurrence of each
/// in-range row; rows that were never drawn follow in ascending order.
fn enrichment_visit_order(&self) -> Vec<usize> {
    let n = self.n_obs();
    let identity = || (0..n).collect::<Vec<usize>>();
    if self.row_metric.is_none() {
        return identity();
    }
    let Ok(metric) = self.diagnostic_metric() else {
        return identity();
    };
    let measure = crate::inference::row_measure::RowMeasure::from_metric(&metric);
    let drawn = measure.enrichment_order(n, n as u64);
    let mut seen = vec![false; n];
    let mut order = Vec::with_capacity(n);
    for row in drawn {
        if row < n && !seen[row] {
            seen[row] = true;
            order.push(row);
        }
    }
    order.extend(
        seen.iter()
            .enumerate()
            .filter_map(|(row, &was_seen)| (!was_seen).then_some(row)),
    );
    order
}
/// Seeds each atom's latent coordinates by nearest-neighbour search on a grid.
///
/// For every atom that offers a projection seed grid and has a basis
/// evaluator, the grid points are decoded (`phi(grid) * decoder`) and each
/// row of `target` is assigned the grid point whose decoded value has the
/// smallest squared distance to it (the first such point on ties). The atom's
/// coordinates are replaced and its basis refreshed. Atoms without a grid,
/// without an evaluator, or with an empty grid are skipped.
///
/// # Errors
/// A `target` of the wrong shape, a grid or grid basis with an unexpected
/// column count, or failures from evaluation / basis refresh.
pub fn seed_coords_by_decoder_projection(
    &mut self,
    target: ArrayView2<'_, f64>,
    resolution: usize,
) -> Result<(), String> {
    let n = self.n_obs();
    let p = self.output_dim();
    if target.dim() != (n, p) {
        return Err(format!(
            "SaeManifoldTerm::seed_coords_by_decoder_projection: target shape {:?} != ({n}, {p})",
            target.dim()
        ));
    }
    let visit_order = self.enrichment_visit_order();
    for atom_idx in 0..self.k_atoms() {
        let d = self.atoms[atom_idx].latent_dim;
        let Some(grid) = self.atoms[atom_idx]
            .basis_kind
            .projection_seed_grid(d, resolution)
        else {
            continue;
        };
        let Some(evaluator) = self.atoms[atom_idx].basis_evaluator.clone() else {
            continue;
        };
        if grid.ncols() != d {
            return Err(format!(
                "SaeManifoldTerm::seed_coords_by_decoder_projection: atom {atom_idx} grid has {} columns but latent_dim is {d}",
                grid.ncols()
            ));
        }
        let grid_len = grid.nrows();
        if grid_len == 0 {
            continue;
        }
        let (phi_grid, _jet) = evaluator.evaluate(grid.view())?;
        let expected_cols = self.atoms[atom_idx].basis_size();
        if phi_grid.ncols() != expected_cols {
            return Err(format!(
                "SaeManifoldTerm::seed_coords_by_decoder_projection: atom {atom_idx} grid Φ has {} columns but decoder expects {}",
                phi_grid.ncols(),
                expected_cols
            ));
        }
        let decoded = phi_grid.dot(&self.atoms[atom_idx].decoder_coefficients);
        // Index of the grid point whose decoded value is closest to `row`;
        // strict `<` keeps the earliest point on ties.
        let nearest_grid_point = |row: usize| -> usize {
            let mut best = (0usize, f64::INFINITY);
            for grid_idx in 0..grid_len {
                let mut err = 0.0_f64;
                for col in 0..p {
                    let diff = target[[row, col]] - decoded[[grid_idx, col]];
                    err += diff * diff;
                }
                if err < best.1 {
                    best = (grid_idx, err);
                }
            }
            best.0
        };
        let mut seeded = Array2::<f64>::zeros((n, d));
        for &row in &visit_order {
            let hit = nearest_grid_point(row);
            for axis in 0..d {
                seeded[[row, axis]] = grid[[hit, axis]];
            }
        }
        let flat = Array1::from_iter(seeded.iter().copied());
        self.assignment.coords[atom_idx].set_flat(flat.view());
        let coords = self.assignment.coords[atom_idx].as_matrix();
        self.atoms[atom_idx].refresh_basis(coords.view())?;
    }
    Ok(())
}
/// Runs up to `max_iter` damped Newton iterations on the per-row variables
/// (logits and latent coordinates) while holding the decoder fixed.
///
/// Each iteration advances the temperature schedule, assembles the
/// Arrow-Schur system, solves every row block independently, and accepts the
/// step through a backtracking line search on the penalized objective using
/// the Armijo condition with `SAE_MANIFOLD_ARMIJO_C1` (initial step
/// `step_size`, at most `SAE_MANIFOLD_MAX_LINESEARCH_HALVINGS` halvings). The
/// loop stops early, restoring the pre-step state, when the direction is not
/// a sufficient descent direction or no trial step is accepted.
///
/// Returns the loss at the last accepted state.
///
/// The zero decoder step is sized per iteration with the same rule
/// `apply_newton_step_impl` validates against (`factored_border_dim()` when
/// `last_frames_active`, otherwise `beta_dim()`). It used to be sized
/// `beta_dim()` unconditionally; with frames active and a different border
/// dimension, every trial step was rejected with a length error that the line
/// search swallows, so the routine returned without making progress.
///
/// # Errors
/// Invalid `step_size`, `max_iter == 0`, or failures from loss evaluation,
/// system assembly, the row solves or the temperature schedule.
pub fn run_fixed_decoder_arrow_schur(
    &mut self,
    target: ArrayView2<'_, f64>,
    rho: &mut SaeManifoldRho,
    analytic_penalties: Option<&AnalyticPenaltyRegistry>,
    max_iter: usize,
    step_size: f64,
    ridge_ext_coord: f64,
) -> Result<SaeManifoldLoss, String> {
    if !(step_size.is_finite() && step_size > 0.0) {
        return Err(format!(
            "SaeManifoldTerm::run_fixed_decoder_arrow_schur: step_size must be finite and positive; got {step_size}"
        ));
    }
    if max_iter == 0 {
        return Err(
            "SaeManifoldTerm::run_fixed_decoder_arrow_schur: max_iter must be positive".into(),
        );
    }
    let mut last_loss = self.loss(target, rho)?;
    for _ in 0..max_iter {
        self.advance_temperature_schedule()?;
        let pre_step_loss = self.loss(target, rho)?;
        let sys = self
            .assemble_arrow_schur(target, rho, analytic_penalties)
            .map_err(|err| format!("SaeManifoldTerm::run_fixed_decoder_arrow_schur: {err}"))?;
        // Size the zero decoder step after assembly so it reflects the frame
        // state the apply step will validate against.
        let border_len = if self.last_frames_active {
            self.factored_border_dim()
        } else {
            self.beta_dim()
        };
        let beta_zero = Array1::<f64>::zeros(border_len);
        let pre_step_total =
            self.penalized_objective_total(target, rho, analytic_penalties, 1.0)?;
        let delta_ext_coord = Self::fixed_decoder_step_from_rows(&sys, ridge_ext_coord)?;
        let directional_decrease = sae_manifold_newton_directional_decrease(
            &sys,
            delta_ext_coord.view(),
            beta_zero.view(),
        );
        let grad_norm_sq: f64 = sys
            .rows
            .iter()
            .flat_map(|row| row.gt.iter())
            .map(|&v| v * v)
            .sum();
        let step_norm_sq: f64 = delta_ext_coord.iter().map(|&v| v * v).sum();
        // Relative floor: the predicted decrease must not be negligible
        // compared with |gradient| * |step|.
        let directional_decrease_floor = SAE_MANIFOLD_DIRECTIONAL_DECREASE_REL_FLOOR
            * grad_norm_sq.sqrt()
            * step_norm_sq.sqrt();
        let snapshot = self.snapshot_mutable_state();
        if !(pre_step_total.is_finite()
            && directional_decrease.is_finite()
            && directional_decrease > 0.0
            && directional_decrease > directional_decrease_floor)
        {
            self.restore_mutable_state(&snapshot);
            last_loss = pre_step_loss;
            break;
        }
        let mut trial_step_size = step_size;
        let mut accepted_loss: Option<SaeManifoldLoss> = None;
        for halving in 0..=SAE_MANIFOLD_MAX_LINESEARCH_HALVINGS {
            // Every trial after the first starts again from the pre-step state.
            if halving > 0 {
                self.restore_mutable_state(&snapshot);
            }
            let trial_result = self
                .apply_newton_step(delta_ext_coord.view(), beta_zero.view(), trial_step_size)
                .and_then(|()| {
                    self.penalized_objective_total(target, rho, analytic_penalties, 1.0)
                });
            // A failed trial is treated like a rejected one: halve and retry.
            if let Ok(post_step_total) = trial_result {
                let armijo_bound = pre_step_total
                    - SAE_MANIFOLD_ARMIJO_C1 * trial_step_size * directional_decrease;
                if post_step_total.is_finite() && post_step_total <= armijo_bound {
                    accepted_loss = Some(self.loss(target, rho)?);
                    break;
                }
            }
            trial_step_size *= 0.5;
        }
        match accepted_loss {
            Some(loss) => last_loss = loss,
            None => {
                self.restore_mutable_state(&snapshot);
                last_loss = pre_step_loss;
                break;
            }
        }
    }
    Ok(last_loss)
}
/// Runs up to `max_iter` joint Newton iterations over the per-row variables
/// and the decoder, using the Arrow-Schur solver.
///
/// Before iterating, the atom bases are refreshed at the current coordinates,
/// decoder frames are ensured for the current decoder, the collapse-event log
/// is cleared, the active-mass guard runs once, and the pre-fit decoder
/// identifiability audit is performed.
///
/// Each iteration: advance the temperature schedule, assemble and solve the
/// system, then accept the step by backtracking line search on the penalized
/// objective (Armijo condition, at most
/// `SAE_MANIFOLD_MAX_LINESEARCH_HALVINGS` halvings). If no trial is accepted
/// the proximal-correction solver is tried as a fallback. The loop stops
/// early, with the pre-step state restored, when the direction is not a
/// sufficient descent direction or the fallback fails to decrease the
/// objective. After an accepted step the affine gauge is canonicalized, the
/// active-mass guard runs, and active frames are refreshed from the data.
///
/// Returns the loss at the final state.
///
/// # Errors
/// Invalid `step_size`, or failures from setup, assembly, the linear solve,
/// objective evaluation, gauge canonicalization, the mass guard or the frame
/// refresh.
pub fn run_joint_fit_arrow_schur(
&mut self,
target: ArrayView2<'_, f64>,
rho: &mut SaeManifoldRho,
analytic_penalties: Option<&AnalyticPenaltyRegistry>,
max_iter: usize,
step_size: f64,
ridge_ext_coord: f64,
ridge_beta: f64,
) -> Result<SaeManifoldLoss, String> {
if !(step_size.is_finite() && step_size > 0.0) {
return Err(format!(
"SaeManifoldTerm::run_joint_fit_arrow_schur: step_size must be finite and positive; got {step_size}"
));
}
self.refresh_basis_from_current_coords()
.map_err(|err| format!("SaeManifoldTerm::run_joint_fit_arrow_schur: {err}"))?;
self.ensure_decoder_frames_active_for_current_decoder()
.map_err(|err| format!("SaeManifoldTerm::run_joint_fit_arrow_schur: {err}"))?;
// Collapse events are per-fit: start from an empty log.
self.collapse_events.clear();
// NOTE(review): this pre-loop guard and the guard after the first outer
// iteration both record iteration index 0 -- confirm that is intended.
self.enforce_active_mass_guard(0)?;
{
// Pre-fit audit of each atom's weighted design Gram matrix.
let mut grams = self.empty_decoder_gram_accumulator();
self.accumulate_decoder_gram(&mut grams);
self.finalize_decoder_identifiability_audit(&grams, self.n_obs())?;
}
for outer_iteration in 0..max_iter {
self.advance_temperature_schedule()?;
let sys = self
.assemble_arrow_schur(target, rho, analytic_penalties)
.map_err(|err| format!("SaeManifoldTerm::run_joint_fit_arrow_schur: {err}"))?;
let (delta_ext_coord, delta_beta, _diag) = sys
.solve_with_lm_escalation(ridge_ext_coord, ridge_beta)
.map_err(|err| format!("SaeManifoldTerm::run_joint_fit_arrow_schur: {err}"))?;
let directional_decrease = sae_manifold_newton_directional_decrease(
&sys,
delta_ext_coord.view(),
delta_beta.view(),
);
// Squared gradient norm over the row blocks (first `row_dims[i]` entries
// of each row gradient) plus the first `sys.k` entries of `gb`.
let mut grad_norm_sq = 0.0;
for (row_idx, row) in sys.rows.iter().enumerate() {
let di = sys.row_dims[row_idx];
for axis in 0..di {
grad_norm_sq += row.gt[axis] * row.gt[axis];
}
}
for idx in 0..sys.k {
grad_norm_sq += sys.gb[idx] * sys.gb[idx];
}
let mut step_norm_sq = 0.0;
for &v in delta_ext_coord.iter() {
step_norm_sq += v * v;
}
for &v in delta_beta.iter() {
step_norm_sq += v * v;
}
// Relative floor: the predicted decrease must not be negligible compared
// with |gradient| * |step|.
let directional_decrease_floor = SAE_MANIFOLD_DIRECTIONAL_DECREASE_REL_FLOOR
* grad_norm_sq.sqrt()
* step_norm_sq.sqrt();
// The snapshot is taken after assembly, so restoring it returns to the
// post-assembly, pre-step state.
let snapshot = self.snapshot_mutable_state();
let pre_step_total =
self.penalized_objective_total(target, rho, analytic_penalties, 1.0)?;
if !(pre_step_total.is_finite()
&& directional_decrease.is_finite()
&& directional_decrease > 0.0
&& directional_decrease > directional_decrease_floor)
{
self.restore_mutable_state(&snapshot);
break;
}
// Backtracking line search along the Newton direction.
let mut trial_step_size = step_size;
let mut accepted = false;
for halving in 0..=SAE_MANIFOLD_MAX_LINESEARCH_HALVINGS {
// Every trial after the first starts again from the pre-step state.
if halving > 0 {
self.restore_mutable_state(&snapshot);
}
let trial_result = self
.apply_newton_step(delta_ext_coord.view(), delta_beta.view(), trial_step_size)
.and_then(|()| {
self.penalized_objective_total(target, rho, analytic_penalties, 1.0)
});
// A failed trial is treated like a rejected one: halve and retry.
if let Ok(post_step_total) = trial_result {
let armijo_bound = pre_step_total
- SAE_MANIFOLD_ARMIJO_C1 * trial_step_size * directional_decrease;
if post_step_total.is_finite() && post_step_total <= armijo_bound {
accepted = true;
break;
}
}
trial_step_size *= 0.5;
}
if !accepted {
// Fallback: let the proximal-correction solver propose steps; each
// proposal is evaluated from the restored pre-step state at unit step.
self.restore_mutable_state(&snapshot);
let correction = ArrowProximalCorrectionOptions {
initial_ridge: ridge_ext_coord
.max(ridge_beta)
.max(SAE_MANIFOLD_ROW_RIDGE_FLOOR),
armijo_c1: SAE_MANIFOLD_ARMIJO_C1,
..ArrowProximalCorrectionOptions::default()
};
let accepted_step = match solve_arrow_newton_step_with_proximal_correction(
&sys,
ridge_ext_coord,
ridge_beta,
pre_step_total,
&ArrowSolveOptions::automatic(sys.k),
&correction,
|trial_delta_t, trial_delta_beta| {
self.restore_mutable_state(&snapshot);
// A trial that cannot be applied or evaluated is reported as +inf.
self.apply_newton_step(trial_delta_t, trial_delta_beta, 1.0)
.and_then(|()| {
self.penalized_objective_total(target, rho, analytic_penalties, 1.0)
})
.unwrap_or(f64::INFINITY)
},
) {
Ok(step) => step,
Err(_err) => {
self.restore_mutable_state(&snapshot);
break;
}
};
// NOTE(review): on success the term holds whatever the solver's last
// callback invocation applied; this assumes that is the accepted step --
// confirm against `solve_arrow_newton_step_with_proximal_correction`.
if !(accepted_step.trial_objective_value.is_finite()
&& accepted_step.trial_objective_value < pre_step_total)
{
self.restore_mutable_state(&snapshot);
break;
}
}
// Post-accept maintenance on the new state.
self.canonicalize_affine_gauge_after_accept()?;
self.enforce_active_mass_guard(outer_iteration)?;
if self.frames_active() {
self.refresh_active_frames_from_data(target)
.map_err(|err| format!("SaeManifoldTerm::run_joint_fit_arrow_schur: {err}"))?;
}
}
self.loss(target, rho)
}
/// Allocates one zeroed `basis_size x basis_size` Gram matrix per atom, in
/// atom order.
fn empty_decoder_gram_accumulator(&self) -> Vec<Array2<f64>> {
    let mut grams = Vec::with_capacity(self.atoms.len());
    for atom in self.atoms.iter() {
        let m = atom.basis_size();
        grams.push(Array2::<f64>::zeros((m, m)));
    }
    grams
}
/// Adds each atom's assignment-weighted basis Gram matrix into `grams`.
///
/// For atom `k` the contribution is `sum over rows of (a_k * phi_row)^T
/// (a_k * phi_row)`. Atoms with an empty basis are skipped.
///
/// When a GPU runtime is available the per-atom products are dispatched via
/// `scatter_batched` / `try_fast_xt_diag_x` with weights `a_k^2`; atoms the
/// device path declines are computed on the CPU. If the batched dispatch as
/// a whole returns `None`, the device results are discarded and every atom is
/// computed on the CPU. Without a runtime everything runs on the CPU.
fn accumulate_decoder_gram(&self, grams: &mut [Array2<f64>]) {
    let n = self.n_obs();
    let assignments = self.assignment.assignments();
    // CPU path for a single atom: rank-one updates, skipping zero rows/entries.
    let accumulate_on_cpu = |atom_idx: usize, gram: &mut Array2<f64>| {
        let atom = &self.atoms[atom_idx];
        let m = atom.basis_size();
        let assign_col = assignments.column(atom_idx);
        let mut scaled_row = vec![0.0_f64; m];
        for row in 0..n {
            let a_k = assign_col[row];
            if a_k == 0.0 {
                continue;
            }
            for col in 0..m {
                scaled_row[col] = a_k * atom.basis_values[[row, col]];
            }
            for i in 0..m {
                let wi = scaled_row[i];
                if wi == 0.0 {
                    continue;
                }
                for j in 0..m {
                    gram[[i, j]] += wi * scaled_row[j];
                }
            }
        }
    };
    // CPU path for every atom that has a non-empty basis.
    let accumulate_all_on_cpu = |grams: &mut [Array2<f64>]| {
        for atom_idx in 0..self.atoms.len() {
            if self.atoms[atom_idx].basis_size() == 0 {
                continue;
            }
            accumulate_on_cpu(atom_idx, &mut grams[atom_idx]);
        }
    };
    let Some(rt) = crate::gpu::runtime::GpuRuntime::global() else {
        accumulate_all_on_cpu(&mut *grams);
        return;
    };
    // Squared assignments are only needed as the diagonal of the device product.
    let weights: Vec<Array1<f64>> = (0..self.atoms.len())
        .map(|atom_idx| {
            let col = assignments.column(atom_idx);
            col.mapv(|a| a * a)
        })
        .collect();
    let mut items: Vec<usize> = (0..self.atoms.len())
        .filter(|&i| self.atoms[i].basis_size() > 0)
        .collect();
    let device_grams: std::sync::Mutex<Vec<(usize, Array2<f64>)>> =
        std::sync::Mutex::new(Vec::with_capacity(items.len()));
    let declined: std::sync::Mutex<Vec<usize>> = std::sync::Mutex::new(Vec::new());
    let atoms_ref = &self.atoms;
    let weights_ref = &weights;
    let dispatched = crate::gpu::pool::scatter_batched(rt, &mut items, |_ordinal, slice| {
        for &atom_idx in slice.iter() {
            let phi = atoms_ref[atom_idx].basis_values.view();
            let w = weights_ref[atom_idx].view();
            match crate::gpu::linalg::try_fast_xt_diag_x(phi, w) {
                Some(g) => device_grams
                    .lock()
                    .expect("device_grams mutex poisoned")
                    .push((atom_idx, g)),
                None => declined
                    .lock()
                    .expect("declined mutex poisoned")
                    .push(atom_idx),
            }
        }
        Some(())
    });
    if dispatched.is_none() {
        // Nothing has been merged yet, so a full CPU pass cannot double count.
        accumulate_all_on_cpu(&mut *grams);
        return;
    }
    for (atom_idx, g) in device_grams
        .into_inner()
        .expect("device_grams mutex poisoned")
    {
        grams[atom_idx] += &g;
    }
    for atom_idx in declined.into_inner().expect("declined mutex poisoned") {
        accumulate_on_cpu(atom_idx, &mut grams[atom_idx]);
    }
}
/// Checks the rank of each atom's accumulated Gram matrix before fitting.
///
/// Atoms with an empty basis are skipped. A full-rank Gram passes silently, a
/// rank-deficient but non-zero rank is logged at info level, and rank zero is
/// an error.
///
/// # Errors
/// A failed rank computation, or any atom whose weighted design has rank 0.
fn finalize_decoder_identifiability_audit(
    &self,
    grams: &[Array2<f64>],
    n_total: usize,
) -> Result<(), String> {
    for (atom_idx, atom) in self.atoms.iter().enumerate() {
        let m = atom.basis_size();
        if m == 0 {
            continue;
        }
        let rank = match crate::solver::identifiability_audit::rank_of_gram(
            &grams[atom_idx],
            n_total,
        ) {
            Ok(rank) => rank,
            Err(e) => {
                return Err(format!(
                    "SaeManifoldTerm: pre-fit decoder audit (atom '{}'): \
                     Gram eigendecomposition failed: {e}",
                    atom.name,
                ));
            }
        };
        if rank >= m {
            continue;
        }
        if rank == 0 {
            return Err(format!(
                "SaeManifoldTerm: pre-fit identifiability audit: decoder atom '{}' has \
                 rank-0 weighted design (n={n_total}, M_k={m}); all assignment weights \
                 vanish or the basis is degenerate, so the Arrow-Schur Newton system for \
                 this block is singular",
                atom.name,
            ));
        }
        let dropped = m - rank;
        log::info!(
            "[SAE-AUDIT] decoder atom '{}' weighted design is rank-deficient \
             (rank={rank}/{m}, {dropped} weakly-identified column(s), n={n_total}); the \
             Arrow-Schur ridge will regularise the deficient directions",
            atom.name,
        );
    }
    Ok(())
}
/// Builds a chunk-sized `SaeManifoldTerm` from per-chunk logits and coordinates.
///
/// Each atom's basis and jet are re-evaluated at the chunk coordinates through
/// the atom's `basis_evaluator`. Decoder coefficients, the raw smoothness
/// penalty, the evaluator, the second jet and the decoder frame are cloned from
/// the parent atom. The chunk assignment reuses the parent's assignment mode and
/// each coordinate block's manifold, and the parent's temperature schedule is
/// copied onto the new term.
///
/// # Errors
/// Fails when `chunk_logits` / `chunk_coords` do not match the atom count or the
/// per-atom latent dimension, when an atom has no basis evaluator, when the
/// evaluator returns arrays of unexpected shape, or when constructing the chunk
/// atoms, assignment or term fails.
pub fn materialize_chunk(
    &self,
    chunk_logits: Array2<f64>,
    chunk_coords: Vec<Array2<f64>>,
) -> Result<SaeManifoldTerm, String> {
    let k_atoms = self.k_atoms();
    let logit_cols = chunk_logits.ncols();
    if logit_cols != k_atoms {
        return Err(format!(
            "SaeManifoldTerm::materialize_chunk: chunk_logits has {} cols but K={k_atoms}",
            logit_cols
        ));
    }
    let coord_blocks = chunk_coords.len();
    if coord_blocks != k_atoms {
        return Err(format!(
            "SaeManifoldTerm::materialize_chunk: chunk_coords has {} atoms but K={k_atoms}",
            coord_blocks
        ));
    }

    let n_chunk = chunk_logits.nrows();
    let mut atoms = Vec::with_capacity(k_atoms);
    for (atom_idx, atom) in self.atoms.iter().enumerate() {
        let coords = &chunk_coords[atom_idx];
        let latent_dim = atom.latent_dim;
        if coords.dim() != (n_chunk, latent_dim) {
            return Err(format!(
                "SaeManifoldTerm::materialize_chunk: atom {atom_idx} coords shape {:?} != ({n_chunk}, {})",
                coords.dim(),
                latent_dim
            ));
        }
        let Some(evaluator) = atom.basis_evaluator.as_ref() else {
            return Err(format!(
                "SaeManifoldTerm::materialize_chunk: atom '{}' has no basis evaluator; a \
                 streaming fit must re-evaluate Φ(t) at each chunk's coordinates",
                atom.name
            ));
        };
        let (phi, jet) = evaluator.evaluate(coords.view())?;
        let m = atom.basis_size();
        if phi.dim() != (n_chunk, m) {
            return Err(format!(
                "SaeManifoldTerm::materialize_chunk: atom '{}' evaluator returned Φ {:?}, expected ({n_chunk}, {m})",
                atom.name,
                phi.dim()
            ));
        }
        if jet.dim() != (n_chunk, m, latent_dim) {
            return Err(format!(
                "SaeManifoldTerm::materialize_chunk: atom '{}' evaluator returned jet {:?}, expected ({n_chunk}, {m}, {})",
                atom.name,
                jet.dim(),
                latent_dim
            ));
        }
        let mut chunk_atom = SaeManifoldAtom::new(
            atom.name.clone(),
            atom.basis_kind.clone(),
            latent_dim,
            phi,
            jet,
            atom.decoder_coefficients.clone(),
            atom.smooth_penalty_raw.clone(),
        )?;
        // State that the constructor does not take is copied over from the parent atom.
        chunk_atom.decoder_frame = atom.decoder_frame.clone();
        chunk_atom.basis_second_jet = atom.basis_second_jet.clone();
        chunk_atom.basis_evaluator = atom.basis_evaluator.clone();
        atoms.push(chunk_atom);
    }

    let mut coord_values: Vec<LatentCoordValues> = Vec::with_capacity(coord_blocks);
    for (c, src) in chunk_coords.iter().zip(self.assignment.coords.iter()) {
        coord_values.push(LatentCoordValues::from_matrix_with_manifold(
            c.view(),
            LatentIdMode::None,
            src.manifold().clone(),
        ));
    }
    let assignment = SaeAssignment::with_mode(chunk_logits, coord_values, self.assignment.mode)?;

    let mut term = SaeManifoldTerm::new(atoms, assignment)?;
    term.temperature_schedule = self.temperature_schedule.clone();
    Ok(term)
}
/// Streaming (chunked) Arrow-Schur fit of the shared border coefficients.
///
/// Rows `[0, n_total)` are visited in windows of at most `chunk_size` rows.
/// `chunk_init(start, end)` must return `(logits, coords, z_chunk)` for the
/// window and is invoked again on every sweep: the pre-fit audit, each Newton
/// iteration, and each line-search trial.
///
/// Each of the `max_iter` iterations advances the temperature schedule,
/// accumulates every chunk's reduced Schur contribution and border gradient,
/// adds `ridge_beta` to the reduced diagonal, solves for `delta_beta`, and then
/// backtracks (halving from `step_size`) until the streamed penalized objective
/// satisfies the Armijo bound. The step is applied through `set_flat_beta`.
///
/// Returns the loss of the last accepted or restored state. With
/// `max_iter == 0` the all-zero initial `SaeManifoldLoss` is returned.
///
/// # Errors
/// Fails on a non-finite or non-positive `step_size`, a zero `chunk_size` or
/// `n_total`, a failed identifiability audit, a Z slice of the wrong shape, or
/// any error propagated from `chunk_init`, chunk materialisation, assembly or
/// the reduced solve.
pub fn run_joint_fit_arrow_schur_streaming<F>(
&mut self,
n_total: usize,
chunk_size: usize,
rho: &mut SaeManifoldRho,
analytic_penalties: Option<&AnalyticPenaltyRegistry>,
max_iter: usize,
step_size: f64,
ridge_ext_coord: f64,
ridge_beta: f64,
mut chunk_init: F,
) -> Result<SaeManifoldLoss, String>
where
F: FnMut(usize, usize) -> Result<(Array2<f64>, Vec<Array2<f64>>, Array2<f64>), String>,
{
// A single negated test rejects NaN, infinities, zero and negative values.
if !(step_size.is_finite() && step_size > 0.0) {
return Err(format!(
"SaeManifoldTerm::run_joint_fit_arrow_schur_streaming: step_size must be finite and positive; got {step_size}"
));
}
if chunk_size == 0 {
return Err(
"SaeManifoldTerm::run_joint_fit_arrow_schur_streaming: chunk_size must be positive"
.to_string(),
);
}
if n_total == 0 {
return Err(
"SaeManifoldTerm::run_joint_fit_arrow_schur_streaming: n_total must be positive"
.to_string(),
);
}
self.ensure_decoder_frames_active_for_current_decoder()
.map_err(|err| {
format!("SaeManifoldTerm::run_joint_fit_arrow_schur_streaming: {err}")
})?;
// The reduced system lives in the factored border space when frames are
// active, otherwise in the full flat-beta space.
let frames_engaged = self.frames_active();
let border_dim = if frames_engaged {
self.factored_border_dim()
} else {
self.beta_dim()
};
// Pre-fit pass: stream every chunk once to accumulate the decoder Grams and
// audit them. The Z slice returned by `chunk_init` is not used here.
{
let mut grams = self.empty_decoder_gram_accumulator();
let mut start = 0usize;
while start < n_total {
let end = (start + chunk_size).min(n_total);
let (logits, coords, _z_chunk) = chunk_init(start, end)?;
let chunk = self.materialize_chunk(logits, coords)?;
chunk.accumulate_decoder_gram(&mut grams);
start = end;
}
self.finalize_decoder_identifiability_audit(&grams, n_total)?;
}
let mut last_loss = SaeManifoldLoss {
data_fit: 0.0,
assignment_sparsity: 0.0,
smoothness: 0.0,
ard: 0.0,
};
for _ in 0..max_iter {
self.advance_temperature_schedule()?;
let options = ArrowSolveOptions::automatic(border_dim);
// Accumulators for this iteration's sweep over all chunks.
let mut s_acc = Array2::<f64>::zeros((border_dim, border_dim));
let mut rhs_acc = Array1::<f64>::zeros(border_dim);
let mut gb_acc = Array1::<f64>::zeros(border_dim);
let mut pre_step_total = 0.0_f64;
let mut chunk_ranges: Vec<(usize, usize)> = Vec::new();
let mut start = 0usize;
while start < n_total {
let end = (start + chunk_size).min(n_total);
let n_chunk = end - start;
// The chunk's share of the rows; handed to the chunk objective and
// assembly as their penalty scale.
let penalty_scale = n_chunk as f64 / n_total as f64;
let (logits, coords, z_chunk) = chunk_init(start, end)?;
if z_chunk.dim() != (n_chunk, self.output_dim()) {
return Err(format!(
"SaeManifoldTerm::run_joint_fit_arrow_schur_streaming: chunk [{start}, {end}) \
Z slice shape {:?} != ({n_chunk}, {})",
z_chunk.dim(),
self.output_dim()
));
}
let mut chunk = self.materialize_chunk(logits, coords)?;
// Forward the matching window of the per-row loss weights, if any.
if let Some(w) = self.row_loss_weights.as_deref() {
chunk.row_loss_weights = Some(w[start..end].to_vec());
}
// Remember the window so the line search replays exactly these ranges.
chunk_ranges.push((start, end));
pre_step_total += chunk.penalized_objective_total(
z_chunk.view(),
rho,
analytic_penalties,
penalty_scale,
)?;
let sys = chunk
.assemble_arrow_schur_scaled(
z_chunk.view(),
rho,
analytic_penalties,
penalty_scale,
)
.map_err(|err| {
format!("SaeManifoldTerm::run_joint_fit_arrow_schur_streaming: {err}")
})?;
for j in 0..border_dim {
gb_acc[j] += sys.gb[j];
}
Self::accumulate_chunk_reduced_schur(
&sys,
ridge_ext_coord,
&options,
&mut s_acc,
&mut rhs_acc,
)?;
start = end;
}
// Finish the reduced system: ridge on the diagonal, and subtract the summed
// border gradient from the accumulated right-hand side.
for j in 0..border_dim {
s_acc[[j, j]] += ridge_beta;
rhs_acc[j] -= gb_acc[j];
}
let delta_beta =
solve_streaming_reduced_beta(&s_acc, &rhs_acc, &options).map_err(|err| {
format!("SaeManifoldTerm::run_joint_fit_arrow_schur_streaming: {err}")
})?;
let beta0 = self.flatten_beta();
// rhs · delta_beta is the decrease term used in the Armijo bound below.
let mut directional_decrease = 0.0_f64;
for j in 0..border_dim {
directional_decrease += rhs_acc[j] * delta_beta[j];
}
// Stop with the current loss when the objective or the decrease term is
// non-finite, or when the decrease term is not strictly positive.
if !(pre_step_total.is_finite()
&& directional_decrease.is_finite()
&& directional_decrease > 0.0)
{
last_loss = self.streaming_loss(&chunk_ranges, rho, n_total, &mut chunk_init)?;
break;
}
// With frames engaged the solve ran in the factored border space; lift the
// step before it is added to the flat beta vector.
let delta_b: Array1<f64> = if frames_engaged {
FrameProjection::new(self).lift_border_vec(delta_beta.view())
} else {
delta_beta.clone()
};
let mut trial_step = step_size;
let mut accepted_loss: Option<SaeManifoldLoss> = None;
// Backtracking line search; the inclusive range allows the initial step plus
// SAE_MANIFOLD_MAX_LINESEARCH_HALVINGS halvings.
for _ in 0..=SAE_MANIFOLD_MAX_LINESEARCH_HALVINGS {
let mut trial_beta = beta0.clone();
for j in 0..self.beta_dim() {
trial_beta[j] += trial_step * delta_b[j];
}
self.set_flat_beta(trial_beta.view())?;
let (trial_loss, trial_total) = self.streaming_loss_and_penalized_objective_total(
&chunk_ranges,
rho,
analytic_penalties,
n_total,
&mut chunk_init,
)?;
let armijo_bound =
pre_step_total - SAE_MANIFOLD_ARMIJO_C1 * trial_step * directional_decrease;
if trial_total.is_finite() && trial_total <= armijo_bound {
accepted_loss = Some(trial_loss);
break;
}
trial_step *= 0.5;
}
match accepted_loss {
Some(loss) => {
last_loss = loss;
}
None => {
// No trial met the Armijo bound: restore the pre-step beta, report its
// loss, and stop iterating.
self.set_flat_beta(beta0.view())?;
last_loss =
self.streaming_loss(&chunk_ranges, rho, n_total, &mut chunk_init)?;
break;
}
}
}
Ok(last_loss)
}
/// Runs the streaming Arrow-Schur fit against data already held in memory.
///
/// The current assignment logits and latent coordinates are copied once and
/// sliced per chunk; `target` supplies the matching rows. The chunk size comes
/// from `streaming_plan()`, capped at `max(n_obs, 1)`.
///
/// # Errors
/// Fails when `target` is not `(n_obs, output_dim)` or when the streaming fit
/// itself reports an error.
pub fn fit_streaming_in_memory(
    &mut self,
    target: ArrayView2<'_, f64>,
    rho: &mut SaeManifoldRho,
    analytic_penalties: Option<&AnalyticPenaltyRegistry>,
    max_iter: usize,
    step_size: f64,
    ridge_ext_coord: f64,
    ridge_beta: f64,
) -> Result<SaeManifoldLoss, String> {
    let n_total = self.n_obs();
    let out_dim = self.output_dim();
    if target.dim() != (n_total, out_dim) {
        return Err(format!(
            "SaeManifoldTerm::fit_streaming_in_memory: target must be ({}, {}); got {:?}",
            n_total,
            out_dim,
            target.dim()
        ));
    }

    let row_cap = n_total.max(1);
    let chunk_size = self.streaming_plan().chunk_size.min(row_cap);

    // Owned snapshots, so the chunk provider does not borrow `self` while the
    // fit mutates it.
    let seed_logits = self.assignment.logits.clone();
    let mut seed_coords: Vec<Array2<f64>> = Vec::new();
    for coord in self.assignment.coords.iter() {
        seed_coords.push(coord.as_matrix().to_owned());
    }

    let chunk_init = move |start: usize, end: usize| {
        let logits = seed_logits.slice(s![start..end, ..]).to_owned();
        let mut coords: Vec<Array2<f64>> = Vec::with_capacity(seed_coords.len());
        for coord in &seed_coords {
            coords.push(coord.slice(s![start..end, ..]).to_owned());
        }
        let z_chunk = target.slice(s![start..end, ..]).to_owned();
        Ok((logits, coords, z_chunk))
    };

    self.run_joint_fit_arrow_schur_streaming(
        n_total,
        chunk_size,
        rho,
        analytic_penalties,
        max_iter,
        step_size,
        ridge_ext_coord,
        ridge_beta,
        chunk_init,
    )
}
/// Adds one chunk's reduced Schur matrix and right-hand side into the running
/// accumulators.
///
/// A `StreamingArrowSchur` is built over `sys`, reset with `0.0`, run over all
/// of the system's rows as a single chunk, and its accumulators are then added
/// into the leading `sys.k` entries of `s_acc` / `rhs_acc`.
///
/// # Errors
/// Propagates (as strings) failures from resetting or accumulating.
fn accumulate_chunk_reduced_schur(
    sys: &ArrowSchurSystem,
    ridge_ext_coord: f64,
    options: &ArrowSolveOptions,
    s_acc: &mut Array2<f64>,
    rhs_acc: &mut Array1<f64>,
) -> Result<(), String> {
    let border = sys.k;
    let n_rows = sys.rows.len();

    let mut reducer = StreamingArrowSchur::from_system(sys, n_rows.max(1));
    if let Err(e) = reducer.reset_accumulator(0.0) {
        return Err(e.to_string());
    }
    if let Err(e) = reducer.accumulate_chunk(0, n_rows, ridge_ext_coord, options.mode) {
        return Err(e.to_string());
    }

    let (contrib_s, contrib_rhs) = reducer.take_accumulators();
    for i in 0..border {
        for j in 0..border {
            s_acc[[i, j]] += contrib_s[[i, j]];
        }
        rhs_acc[i] += contrib_rhs[i];
    }
    Ok(())
}
/// Sums the scaled loss components over the given row windows.
///
/// Every window is re-materialised through `chunk_init` and `materialize_chunk`,
/// receives the matching slice of the per-row loss weights (when present), and
/// is evaluated with `loss_scaled` using the window's share of `n_total` as the
/// penalty scale.
///
/// # Errors
/// Propagates errors from `chunk_init`, `materialize_chunk` and `loss_scaled`.
fn streaming_loss<F>(
    &self,
    chunk_ranges: &[(usize, usize)],
    rho: &SaeManifoldRho,
    n_total: usize,
    chunk_init: &mut F,
) -> Result<SaeManifoldLoss, String>
where
    F: FnMut(usize, usize) -> Result<(Array2<f64>, Vec<Array2<f64>>, Array2<f64>), String>,
{
    let mut summed = SaeManifoldLoss {
        data_fit: 0.0,
        assignment_sparsity: 0.0,
        smoothness: 0.0,
        ard: 0.0,
    };
    for (start, end) in chunk_ranges.iter().copied() {
        let penalty_scale = (end - start) as f64 / n_total as f64;
        let (logits, coords, z_chunk) = chunk_init(start, end)?;
        let mut chunk = self.materialize_chunk(logits, coords)?;
        if let Some(w) = self.row_loss_weights.as_deref() {
            chunk.row_loss_weights = Some(w[start..end].to_vec());
        }
        let part = chunk.loss_scaled(z_chunk.view(), rho, penalty_scale)?;
        summed.data_fit += part.data_fit;
        summed.assignment_sparsity += part.assignment_sparsity;
        summed.smoothness += part.smoothness;
        summed.ard += part.ard;
    }
    Ok(summed)
}
/// Streams the row windows once and returns both the summed loss components and
/// the summed penalized objective.
///
/// Each window is re-materialised through `chunk_init` / `materialize_chunk`,
/// given the matching slice of the per-row loss weights (when present), and
/// evaluated with `loss_scaled` and `penalized_objective_total`, both using the
/// window's share of `n_total` as the penalty scale.
///
/// # Errors
/// Propagates errors from `chunk_init`, `materialize_chunk`, `loss_scaled` and
/// `penalized_objective_total`.
fn streaming_loss_and_penalized_objective_total<F>(
    &self,
    chunk_ranges: &[(usize, usize)],
    rho: &SaeManifoldRho,
    analytic_penalties: Option<&AnalyticPenaltyRegistry>,
    n_total: usize,
    chunk_init: &mut F,
) -> Result<(SaeManifoldLoss, f64), String>
where
    F: FnMut(usize, usize) -> Result<(Array2<f64>, Vec<Array2<f64>>, Array2<f64>), String>,
{
    let mut summed = SaeManifoldLoss {
        data_fit: 0.0,
        assignment_sparsity: 0.0,
        smoothness: 0.0,
        ard: 0.0,
    };
    let mut total = 0.0_f64;
    for (start, end) in chunk_ranges.iter().copied() {
        let penalty_scale = (end - start) as f64 / n_total as f64;
        let (logits, coords, z_chunk) = chunk_init(start, end)?;
        let mut chunk = self.materialize_chunk(logits, coords)?;
        if let Some(w) = self.row_loss_weights.as_deref() {
            chunk.row_loss_weights = Some(w[start..end].to_vec());
        }
        let part = chunk.loss_scaled(z_chunk.view(), rho, penalty_scale)?;
        summed.data_fit += part.data_fit;
        summed.assignment_sparsity += part.assignment_sparsity;
        summed.smoothness += part.smoothness;
        summed.ard += part.ard;
        total += chunk.penalized_objective_total(
            z_chunk.view(),
            rho,
            analytic_penalties,
            penalty_scale,
        )?;
    }
    Ok((summed, total))
}
}
/// Outer (hyper-parameter) objective wrapping a `SaeManifoldTerm` and the data
/// it is fitted to. Implements `OuterObjective`.
pub struct SaeManifoldOuterObjective {
/// Working term; mutated by every inner solve.
term: SaeManifoldTerm,
/// Clone of the term taken at construction; restored by `reset` and used as
/// the starting point for probe and seed solves.
baseline_term: SaeManifoldTerm,
/// Target matrix passed to every criterion evaluation.
target: Array2<f64>,
/// Optional analytic penalty registry forwarded to the inner solves.
registry: Option<AnalyticPenaltyRegistry>,
/// Rho of the most recent successful evaluation.
current_rho: SaeManifoldRho,
/// Rho supplied at construction; also the template for `from_flat`.
baseline_rho: SaeManifoldRho,
/// Iteration budget forwarded to the inner solver.
inner_max_iter: usize,
/// Step size forwarded to the inner solver.
learning_rate: f64,
/// Ridge forwarded to the inner solver for the extended-coordinate blocks.
ridge_ext_coord: f64,
/// Ridge forwarded to the inner solver for the beta block.
ridge_beta: f64,
/// Loss of the most recent successful inner solve, if any.
last_loss: Option<SaeManifoldLoss>,
/// Beta installed by `seed_inner_state`; consumed by the next evaluation.
seeded_beta: Option<Array1<f64>>,
}
impl SaeManifoldOuterObjective {
/// Creates the outer objective.
///
/// `term` and `init_rho` are cloned into the baseline slots before being stored
/// as the working state; no loss or seeded beta is recorded yet.
pub fn new(
    term: SaeManifoldTerm,
    target: Array2<f64>,
    registry: Option<AnalyticPenaltyRegistry>,
    init_rho: SaeManifoldRho,
    inner_max_iter: usize,
    learning_rate: f64,
    ridge_ext_coord: f64,
    ridge_beta: f64,
) -> Self {
    Self {
        // The baseline clones are listed first so they are taken before the
        // originals are moved into the working slots.
        baseline_term: term.clone(),
        baseline_rho: init_rho.clone(),
        term,
        current_rho: init_rho,
        target,
        registry,
        inner_max_iter,
        learning_rate,
        ridge_ext_coord,
        ridge_beta,
        last_loss: None,
        seeded_beta: None,
    }
}
/// Consumes the objective and returns the fitted term, its rho and its loss.
///
/// The settled working term is compared against a fresh solve started from the
/// baseline term at the same rho. The baseline solve wins only if both
/// objectives could be evaluated and its penalized objective is finite and
/// strictly smaller than the settled one. In every other case the settled term
/// is returned, with the last recorded loss (all zeros if none was recorded).
pub fn into_fitted(self) -> (SaeManifoldTerm, SaeManifoldRho, SaeManifoldLoss) {
    let Self {
        term,
        mut baseline_term,
        target,
        registry,
        current_rho,
        inner_max_iter,
        learning_rate,
        ridge_ext_coord,
        ridge_beta,
        last_loss,
        ..
    } = self;
    let loss = last_loss.unwrap_or_else(|| SaeManifoldLoss {
        data_fit: 0.0,
        assignment_sparsity: 0.0,
        smoothness: 0.0,
        ard: 0.0,
    });
    // Objective of the settled state at the final rho.
    let settled_objective =
        term.penalized_objective_total(target.view(), &current_rho, registry.as_ref(), 1.0);
    // The seed solve gets its own rho copy so `current_rho` stays untouched.
    let mut rho_seed = current_rho.clone();
    let seed_solve = baseline_term.run_joint_fit_arrow_schur(
        target.view(),
        &mut rho_seed,
        registry.as_ref(),
        inner_max_iter,
        learning_rate,
        ridge_ext_coord,
        ridge_beta,
    );
    if let (Ok(settled_total), Ok(seed_loss)) = (settled_objective, seed_solve) {
        // Both candidates are scored at `current_rho` so the totals are comparable.
        let seed_total = baseline_term.penalized_objective_total(
            target.view(),
            &current_rho,
            registry.as_ref(),
            1.0,
        );
        if let Ok(seed_total) = seed_total {
            if seed_total.is_finite() && seed_total < settled_total {
                return (baseline_term, current_rho, seed_loss);
            }
        }
    }
    (term, current_rho, loss)
}
/// Builds a directional optimality certificate at the current rho.
///
/// The analytic outer gradient at the current rho is projected onto a
/// deterministic probe direction. The criterion is then re-evaluated at offsets
/// of `+h`, `-h`, `+2h` and `-2h` along that direction, in that order, all on
/// one clone of the baseline term. The samples are handed to
/// `certificate_from_samples`.
///
/// # Errors
/// Propagates failures from the criterion solves and from the analytic gradient.
pub fn optimality_certificate(&mut self) -> Result<CriterionCertificate, String> {
    let rho_hat_flat = self.current_rho.to_flat();
    let dir = deterministic_probe_direction(rho_hat_flat.view());
    let h = probe_step(rho_hat_flat.view());

    let rho_hat = self.current_rho.clone();
    let (_v_hat, loss_hat, cache) = self.term.reml_criterion_with_cache(
        self.target.view(),
        &rho_hat,
        self.registry.as_ref(),
        self.inner_max_iter,
        self.learning_rate,
        self.ridge_ext_coord,
        self.ridge_beta,
    )?;
    let components = self
        .term
        .analytic_outer_rho_gradient_components(&rho_hat, &loss_hat, &cache)?;
    let grad = components.gradient_with_available_correction();
    let grad_norm = grad.iter().map(|g| g * g).sum::<f64>().sqrt();
    let analytic_directional: f64 = grad.iter().zip(dir.iter()).map(|(g, d)| g * d).sum();

    // All four probes share one term, so each solve starts from the state the
    // previous probe left behind.
    let mut probe_term = self.baseline_term.clone();
    let mut probe = |mult: f64| -> Result<f64, String> {
        let flat: Array1<f64> =
            Array1::from_shape_fn(rho_hat_flat.len(), |i| rho_hat_flat[i] + mult * h * dir[i]);
        let rho = self.baseline_rho.from_flat(flat.view());
        let (cost, _loss) = probe_term.reml_criterion(
            self.target.view(),
            &rho,
            self.registry.as_ref(),
            self.inner_max_iter,
            self.learning_rate,
            self.ridge_ext_coord,
            self.ridge_beta,
        )?;
        Ok(cost)
    };
    let plus_h = probe(1.0)?;
    let minus_h = probe(-1.0)?;
    let plus_2h = probe(2.0)?;
    let minus_2h = probe(-2.0)?;

    let well_posed = [plus_h, minus_h, plus_2h, minus_2h]
        .iter()
        .all(|value| value.is_finite());
    let samples = DirectionalSamples {
        plus_h,
        minus_h,
        plus_2h,
        minus_2h,
        step: h,
        grad_norm,
        analytic_directional,
        well_posed,
    };
    Ok(certificate_from_samples(&samples))
}
/// Returns the curvature-walk report held by the working term, if one is set.
pub fn curvature_walk_report(&self) -> Option<&CurvatureWalkReport> {
self.term.curvature_walk_report()
}
/// Computes decoder shape uncertainty at the current rho.
///
/// The criterion is re-solved with its factor cache, the reconstruction
/// dispersion is derived from the resulting loss and cache, and both are handed
/// to `assemble_shape_uncertainty`.
///
/// # Errors
/// Propagates failures from the criterion solve, the dispersion estimate and
/// the final assembly.
pub fn decoder_shape_uncertainty(&mut self) -> Result<SaeShapeUncertainty, String> {
    let rho_hat = self.current_rho.clone();
    let solved = self.term.reml_criterion_with_cache(
        self.target.view(),
        &rho_hat,
        self.registry.as_ref(),
        self.inner_max_iter,
        self.learning_rate,
        self.ridge_ext_coord,
        self.ridge_beta,
    );
    let (_cost, loss_hat, factor_cache) = solved?;
    let dispersion = self
        .term
        .reconstruction_dispersion(&loss_hat, &factor_cache, &rho_hat)?;
    self.term.assemble_shape_uncertainty(&factor_cache, dispersion)
}
/// Walks the homotopy parameter η from 0 to 1 at the baseline rho.
///
/// The isometry weights are first set to zero and a linear-span anchor is
/// computed. After a solve at η = 0 the walk repeats three steps: an optional
/// beta predictor built from the previous factor cache, a corrector solve at
/// the next η, and a minimum-pivot check. On a failed corrector or a weak pivot
/// the η step is halved until `CURVATURE_WALK_MIN_ETA_STEP`, at which point a
/// bifurcation is recorded and the walk stops. The walk also stops when the
/// corrector budget is reached before η = 1.
///
/// Returns `Ok(true)` only when η reached 1 without a recorded bifurcation.
/// `Ok(false)` is returned when the anchor or the η = 0 solve fails, and also
/// when the walk stops early. A `CurvatureWalkReport` is stored on the term
/// whenever the walk itself was started.
pub fn run_curvature_homotopy_entry(&mut self) -> Result<bool, String> {
let rho = self.baseline_rho.clone();
// Full-strength isometry weights; scaled by η during the walk and restored on
// every exit path.
let isometry_targets = self
.registry
.as_ref()
.map(AnalyticPenaltyRegistry::isometry_scalar_weights)
.unwrap_or_default();
self.set_isometry_homotopy_weight(0.0, &isometry_targets);
let anchor = match linear_span_anchor(&self.term, self.target.view()) {
Ok(anchor) => anchor,
Err(err) => {
log::info!(
"[#1007] curvature anchor degenerate ({err}); deferring to seed cascade"
);
self.set_isometry_homotopy_weight(1.0, &isometry_targets);
return Ok(false);
}
};
let anchor_residual_norm_sq = anchor.residual_norm_sq;
let (_loss0, mut last_cache) = match self.solve_at_eta(&rho, 0.0, &isometry_targets) {
Ok(pair) => pair,
Err(err) => {
log::info!(
"[#1007] curvature anchor solve failed at η=0 ({err}); deferring to cascade"
);
// Best effort: put η and the isometry weights back at full strength.
self.term.set_homotopy_eta(1.0).ok();
self.set_isometry_homotopy_weight(1.0, &isometry_targets);
return Ok(false);
}
};
let mut eta = 0.0_f64;
let mut eta_step = CURVATURE_WALK_INITIAL_ETA_STEP;
let mut eta_steps = 0usize;
let mut step_halvings = 0usize;
let mut total_correctors = 0usize;
let mut bifurcation: Option<CurvatureBifurcation> = None;
'walk: while eta < 1.0 {
let eta_next = (eta + eta_step).min(1.0);
let d_eta = eta_next - eta;
// Predictor (best effort): shift beta by -u_beta * d_eta, where u_beta comes
// from applying the previous factor cache's inverse to the η-derivative of
// the beta gradient. It is skipped silently on any failure or size mismatch.
if let Ok(dg_beta) = self
.term
.curvature_beta_gradient_eta_derivative(self.target.view())
&& dg_beta.len() == last_cache.k
{
let w_t = Array1::<f64>::zeros(last_cache.delta_t_len());
if let Ok((_u_t, u_beta)) =
last_cache.full_inverse_apply(w_t.view(), dg_beta.view())
{
let mut beta = self.term.flatten_beta();
if beta.len() == u_beta.len() {
for (b, u) in beta.iter_mut().zip(u_beta.iter()) {
*b -= u * d_eta;
}
if beta.iter().all(|v| v.is_finite()) {
self.term.set_flat_beta(beta.view()).ok();
}
}
}
}
// Corrector: full inner solve at the proposed η.
let cache = match self.solve_at_eta(&rho, eta_next, &isometry_targets) {
Ok((_loss, cache)) => cache,
Err(err) => {
if eta_step <= CURVATURE_WALK_MIN_ETA_STEP {
log::info!(
"[#1007] curvature corrector failed at η={eta_next:.4} at the minimum \
η-step ({err}); recording branch bifurcation"
);
bifurcation = Some(CurvatureBifurcation {
eta: eta_next,
min_pivot: 0.0,
});
break 'walk;
}
// Retry from the last accepted η with half the step.
eta_step *= 0.5;
step_halvings += 1;
self.term.set_homotopy_eta(eta).ok();
self.set_isometry_homotopy_weight(eta, &isometry_targets);
continue 'walk;
}
};
total_correctors += 1;
// Accept the step only if the smallest pivot clears a floor of
// sqrt(f64::EPSILON) times the larger of 1 and the largest pivot.
let pivot = arrow_factor_min_pivot(&cache).min_pivot.unwrap_or(0.0);
let diag_scale = arrow_factor_max_pivot(&cache).unwrap_or(1.0);
let floor = f64::EPSILON.sqrt() * diag_scale.max(1.0);
if !(pivot.is_finite() && pivot >= floor) {
if eta_step > CURVATURE_WALK_MIN_ETA_STEP {
eta_step *= 0.5;
step_halvings += 1;
self.term.set_homotopy_eta(eta).ok();
self.set_isometry_homotopy_weight(eta, &isometry_targets);
continue 'walk;
}
log::info!(
"[#1007] curvature branch bifurcation at η={eta_next:.4}: min pivot \
{pivot:.3e} < floor {floor:.3e}; deferring to seed cascade"
);
bifurcation = Some(CurvatureBifurcation {
eta: eta_next,
min_pivot: pivot,
});
break 'walk;
}
// Step accepted: advance η and let the step grow back, capped at the
// initial step size.
eta = eta_next;
last_cache = cache;
eta_steps += 1;
eta_step = (eta_step * 2.0).min(CURVATURE_WALK_INITIAL_ETA_STEP);
if total_correctors >= CURVATURE_WALK_MAX_CORRECTORS && eta < 1.0 {
log::info!(
"[#1007] curvature walk hit its corrector budget at η={eta:.4}; deferring to \
seed cascade"
);
bifurcation = Some(CurvatureBifurcation {
eta,
min_pivot: pivot,
});
break 'walk;
}
}
let arrived = bifurcation.is_none() && eta >= 1.0;
// If the walk stopped early, set η on the term to 1 (best effort).
if !arrived {
self.term.set_homotopy_eta(1.0).ok();
}
self.set_isometry_homotopy_weight(1.0, &isometry_targets);
let collapse_events = self.term.collapse_events().len();
self.term.set_curvature_walk_report(CurvatureWalkReport {
arrived,
anchor_residual_norm_sq,
bifurcation,
eta_steps,
step_halvings,
collapse_events,
reseeds: 0,
});
Ok(arrived)
}
/// Sets the homotopy parameter and the η-scaled isometry weights, runs the
/// cached criterion solve at `rho`, and records the resulting loss.
///
/// # Errors
/// Propagates failures from `set_homotopy_eta` and from the criterion solve.
fn solve_at_eta(
    &mut self,
    rho: &SaeManifoldRho,
    eta: f64,
    isometry_targets: &[f64],
) -> Result<(SaeManifoldLoss, ArrowFactorCache), String> {
    self.term.set_homotopy_eta(eta)?;
    self.set_isometry_homotopy_weight(eta, isometry_targets);

    let solved = self.term.reml_criterion_with_cache(
        self.target.view(),
        rho,
        self.registry.as_ref(),
        self.inner_max_iter,
        self.learning_rate,
        self.ridge_ext_coord,
        self.ridge_beta,
    )?;
    let (_cost, loss, cache) = solved;
    self.last_loss = Some(loss.clone());
    Ok((loss, cache))
}
/// Scales the isometry scalar weights in the registry to `clamp(eta, 0, 1)`
/// times their targets. Does nothing when there are no targets or no registry.
fn set_isometry_homotopy_weight(&mut self, eta: f64, targets: &[f64]) {
    if targets.is_empty() {
        return;
    }
    let Some(registry) = self.registry.as_mut() else {
        return;
    };
    let scale = eta.clamp(0.0, 1.0);
    let mut weights: Vec<f64> = Vec::with_capacity(targets.len());
    for target in targets {
        weights.push(scale * target);
    }
    registry.set_isometry_scalar_weights(&weights);
}
/// Evaluates the criterion at `rho_flat` and returns the cost together with the
/// fitted flat beta.
///
/// A pending seeded beta is consumed first and installed only when its length
/// matches the term's beta dimension. On success the current rho and last loss
/// are updated.
///
/// # Errors
/// Propagates failures from installing the seed and from the criterion solve.
fn evaluate(&mut self, rho_flat: ArrayView1<'_, f64>) -> Result<(f64, Array1<f64>), String> {
    let rho = self.baseline_rho.from_flat(rho_flat);
    // `take()` clears the seed even when it turns out to be unusable.
    match self.seeded_beta.take() {
        Some(beta) if beta.len() == self.term.beta_dim() => {
            self.term.set_flat_beta(beta.view())?;
        }
        _ => {}
    }
    let (cost, loss) = self.term.reml_criterion(
        self.target.view(),
        &rho,
        self.registry.as_ref(),
        self.inner_max_iter,
        self.learning_rate,
        self.ridge_ext_coord,
        self.ridge_beta,
    )?;
    self.last_loss = Some(loss);
    self.current_rho = rho;
    Ok((cost, self.term.flatten_beta()))
}
/// Solves the inner problem at `rho_flat` and proposes log-scale update steps
/// for the smoothing parameter and the ARD entries.
///
/// `steps` has one entry per flat rho component and starts at zero. Slot 1
/// receives `ln(lambda_new) - rho.log_lambda_smooth` when the smoothness update
/// is well defined. The ARD slots start at index 2 and receive
/// `ln(alpha_new) - log_ard`, where `alpha_new = n_obs / (sumsq + trace)`. Slot
/// 0 is left at zero.
///
/// # Errors
/// Propagates failures from installing a seeded beta, the criterion solve, the
/// ARD inverse traces, the penalty rank computation and the effective-dof
/// computation.
fn efs_step(&mut self, rho_flat: ArrayView1<'_, f64>) -> Result<EfsEval, String> {
let rho = self.baseline_rho.from_flat(rho_flat);
// Consume a pending seed; install it only when its length matches.
if let Some(beta) = self.seeded_beta.take()
&& beta.len() == self.term.beta_dim()
{
self.term.set_flat_beta(beta.view())?;
}
let (cost, loss, cache) = self.term.reml_criterion_with_cache(
self.target.view(),
&rho,
self.registry.as_ref(),
self.inner_max_iter,
self.learning_rate,
self.ridge_ext_coord,
self.ridge_beta,
)?;
// State is recorded before the fallible trace / dof computations below.
self.current_rho = rho.clone();
self.last_loss = Some(loss);
let n_obs = self.term.n_obs() as f64;
let sumsq = self.term.ard_coord_sumsq();
let traces = self
.term
.ard_inverse_traces(&cache)
.map_err(|e| format!("SaeManifoldOuterObjective::efs_step: ARD traces: {e}"))?;
let n_params = rho.to_flat().len();
let mut steps = vec![0.0_f64; n_params];
steps[0] = 0.0;
let lambda_smooth = rho.lambda_smooth();
let p_out = self.term.output_dim() as f64;
// Total penalty rank: the sum of per-atom smooth-penalty ranks, times the
// output dimension.
let mut smooth_rank_total = 0usize;
for atom in &self.term.atoms {
smooth_rank_total += SaeManifoldTerm::symmetric_rank(&atom.smooth_penalty)?;
}
let rank_total = p_out * (smooth_rank_total as f64);
let quad = self.term.decoder_smoothness_quadratic_form();
let eff_dof = self
.term
.decoder_smoothness_effective_dof(&cache, lambda_smooth)
.map_err(|e| format!("SaeManifoldOuterObjective::efs_step: smooth dof: {e}"))?;
// Smoothness update: lambda_new = (rank_total - eff_dof) / quad, taken only
// when every ingredient is strictly positive and the result is finite.
if quad > 0.0 && rank_total - eff_dof > 0.0 && lambda_smooth > 0.0 {
let lambda_new = (rank_total - eff_dof) / quad;
if lambda_new.is_finite() && lambda_new > 0.0 {
steps[1] = lambda_new.ln() - rho.log_lambda_smooth;
}
}
// ARD updates occupy the flat slots from index 2 on, one run per entry of
// `rho.log_ard`.
let mut cursor = 2usize;
for (k, axis_logard) in rho.log_ard.iter().enumerate() {
let d = axis_logard.len();
for j in 0..d {
let denom = sumsq[k][j] + traces[k][j];
if denom > 0.0 {
let alpha_new = n_obs / denom;
if alpha_new.is_finite() && alpha_new > 0.0 {
steps[cursor + j] = alpha_new.ln() - axis_logard[j];
}
}
}
cursor += d;
}
let beta_hat = self.term.flatten_beta();
Ok(EfsEval {
cost,
steps,
beta: Some(beta_hat),
psi_gradient: None,
psi_indices: None,
inner_hessian_scale: None,
logdet_enclosure_gap: None,
})
}
/// Checks that the cached factor is well enough conditioned for the analytic
/// outer gradient.
///
/// The check passes when both pivots are finite, the maximum pivot is positive,
/// and the min/max pivot ratio is finite and at least
/// `SAE_OUTER_GRADIENT_PIVOT_RATIO_FLOOR`.
///
/// # Errors
/// Returns a message when either pivot is missing from the cache or when the
/// conditioning test fails.
fn ensure_outer_gradient_factor_well_conditioned(
    cache: &ArrowFactorCache,
) -> Result<(), String> {
    let min_pivot = arrow_factor_min_pivot(cache).min_pivot.ok_or_else(|| {
        "analytic outer gradient undefined at this rho: joint Hessian numerically \
         singular (no cached Cholesky pivots)"
            .to_string()
    })?;
    let max_pivot = arrow_factor_max_pivot(cache).ok_or_else(|| {
        "analytic outer gradient undefined at this rho: joint Hessian numerically \
         singular (no cached Cholesky pivot scale)"
            .to_string()
    })?;

    let ratio = min_pivot / max_pivot;
    let pivots_usable = min_pivot.is_finite() && max_pivot.is_finite() && max_pivot > 0.0;
    let ratio_ok = ratio.is_finite() && ratio >= SAE_OUTER_GRADIENT_PIVOT_RATIO_FLOOR;
    if !(pivots_usable && ratio_ok) {
        return Err(format!(
            "analytic outer gradient undefined at this rho: joint Hessian numerically singular \
             (min/max pivot ratio {ratio:.3e} < floor {floor:.3e}; min pivot {min_pivot:.3e}, \
             max pivot {max_pivot:.3e})",
            floor = SAE_OUTER_GRADIENT_PIVOT_RATIO_FLOOR,
        ));
    }
    Ok(())
}
}
impl OuterObjective for SaeManifoldOuterObjective {
fn capability(&self) -> OuterCapability {
OuterCapability {
gradient: Derivative::Analytic,
hessian: DeclaredHessianForm::Unavailable,
n_params: self.baseline_rho.to_flat().len(),
psi_dim: 0,
fixed_point_available: true,
barrier_config: None,
prefer_gradient_only: false,
disable_fixed_point: false,
}
}
fn eval_cost(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError> {
self.evaluate(rho.view())
.map(|(cost, _beta)| cost)
.map_err(EstimationError::RemlOptimizationFailed)
}
fn eval(&mut self, rho: &Array1<f64>) -> Result<OuterEval, EstimationError> {
let rho_state = self.baseline_rho.from_flat(rho.view());
if let Some(beta) = self.seeded_beta.take()
&& beta.len() == self.term.beta_dim()
{
self.term
.set_flat_beta(beta.view())
.map_err(EstimationError::RemlOptimizationFailed)?;
}
let (cost, loss, cache) = self
.term
.reml_criterion_with_cache(
self.target.view(),
&rho_state,
self.registry.as_ref(),
self.inner_max_iter,
self.learning_rate,
self.ridge_ext_coord,
self.ridge_beta,
)
.map_err(EstimationError::RemlOptimizationFailed)?;
Self::ensure_outer_gradient_factor_well_conditioned(&cache)
.map_err(EstimationError::RemlOptimizationFailed)?;
let components = self
.term
.analytic_outer_rho_gradient_components(&rho_state, &loss, &cache)
.map_err(EstimationError::RemlOptimizationFailed)?;
let gradient = components.gradient_with_available_correction();
self.current_rho = rho_state;
self.last_loss = Some(loss);
let beta_hat = self.term.flatten_beta();
Ok(OuterEval {
cost,
gradient,
hessian: HessianResult::Unavailable,
inner_beta_hint: Some(beta_hat),
})
}
fn eval_efs(&mut self, rho: &Array1<f64>) -> Result<EfsEval, EstimationError> {
self.efs_step(rho.view())
.map_err(EstimationError::RemlOptimizationFailed)
}
fn reset(&mut self) {
self.term = self.baseline_term.clone();
self.current_rho = self.baseline_rho.clone();
self.last_loss = None;
self.seeded_beta = None;
}
fn seed_inner_state(&mut self, beta: &Array1<f64>) -> Result<SeedOutcome, EstimationError> {
if beta.is_empty() {
return Ok(SeedOutcome::NoSlot);
}
if beta.len() != self.term.beta_dim() {
return Err(EstimationError::RemlOptimizationFailed(format!(
"SaeManifoldOuterObjective::seed_inner_state: β length {} != decoder dim {}",
beta.len(),
self.term.beta_dim()
)));
}
self.seeded_beta = Some(beta.clone());
Ok(SeedOutcome::Installed)
}
fn requires_continuation_path_entry(&self) -> bool {
self.term.k_atoms() >= 2
}
fn curvature_homotopy_entry(&mut self) -> Option<Result<bool, EstimationError>> {
Some(
self.run_curvature_homotopy_entry()
.map_err(EstimationError::RemlOptimizationFailed),
)
}
}
/// Returns the negated inner product of the system gradient with a step.
///
/// The row-block gradients `gt` are paired with `delta_ext_coord` at each row's
/// offset, and the border gradient `gb` is paired with `delta_beta`. Terms are
/// accumulated rows first, then the border.
///
/// # Panics
/// Panics when `delta_ext_coord` does not have the total row-block length or
/// `delta_beta` does not have length `sys.k`.
fn sae_manifold_newton_directional_decrease(
    sys: &ArrowSchurSystem,
    delta_ext_coord: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
) -> f64 {
    assert_eq!(delta_ext_coord.len(), sys.row_offsets[sys.rows.len()]);
    assert_eq!(delta_beta.len(), sys.k);

    let row_part = sys
        .rows
        .iter()
        .enumerate()
        .fold(0.0_f64, |acc, (row_idx, row)| {
            let base = sys.row_offsets[row_idx];
            (0..sys.row_dims[row_idx]).fold(acc, |inner, axis| {
                inner + row.gt[axis] * delta_ext_coord[base + axis]
            })
        });
    let gradient_dot_step =
        (0..sys.k).fold(row_part, |acc, idx| acc + sys.gb[idx] * delta_beta[idx]);
    -gradient_dot_step
}
/// Computes the product `S · B` for every `(S, B)` pair in `sb_inputs`.
///
/// With `symmetrize` set, each `S` is first replaced by `0.5 * (S + Sᵀ)`.
/// Without a global GPU runtime every product is a plain CPU `dot`. With a
/// runtime, pairs are grouped by `(rows of S, columns of B)`. Groups with fewer
/// than two members, or with a zero dimension, are computed on the CPU. Other
/// groups are packed into batched arrays and offered to
/// `try_fast_abt_strided_batched`. Any item the device path did not produce is
/// computed on the CPU.
///
/// The result is returned in input order, one matrix per pair.
fn batched_smooth_sb(
sb_inputs: &[(ArrayView2<'_, f64>, ArrayView2<'_, f64>)],
symmetrize: bool,
) -> Vec<Array2<f64>> {
let n_atoms = sb_inputs.len();
// Owned (optionally symmetrized) copies of every S.
let s_mats: Vec<Array2<f64>> = sb_inputs
.iter()
.map(|(s, _)| {
if symmetrize {
let m = s.nrows();
let mut sym = Array2::<f64>::zeros((m, m));
for i in 0..m {
for j in 0..m {
sym[[i, j]] = 0.5 * (s[[i, j]] + s[[j, i]]);
}
}
sym
} else {
s.to_owned()
}
})
.collect();
// CPU reference path for a single pair.
let cpu_one = |idx: usize| -> Array2<f64> { s_mats[idx].dot(&sb_inputs[idx].1) };
let rt = match crate::gpu::runtime::GpuRuntime::global() {
Some(rt) => rt,
None => return (0..n_atoms).map(cpu_one).collect(),
};
// Group pair indices by (m, p) so each batch has a uniform shape.
let mut groups: std::collections::BTreeMap<(usize, usize), Vec<usize>> =
std::collections::BTreeMap::new();
for (idx, (_, b)) in sb_inputs.iter().enumerate() {
let m = s_mats[idx].nrows();
let p = b.ncols();
groups.entry((m, p)).or_default().push(idx);
}
let mut out: Vec<Option<Array2<f64>>> = (0..n_atoms).map(|_| None).collect();
for ((m, p), members) in groups {
// Singleton or empty-shaped groups are not batched.
if members.len() < 2 || m == 0 || p == 0 {
for &idx in &members {
out[idx] = Some(cpu_one(idx));
}
continue;
}
let mut items: Vec<usize> = members.clone();
let s_ref = &s_mats;
let tile_results: std::sync::Mutex<Vec<(usize, Array2<f64>)>> =
std::sync::Mutex::new(Vec::with_capacity(members.len()));
let ok = crate::gpu::pool::scatter_batched(rt, &mut items, |_ordinal, slice| {
if slice.is_empty() {
return Some(());
}
let batch = slice.len();
// Pack S into `a` and Bᵀ into `bt`; the batched kernel is handed (A, Bᵀ).
let mut a = Array3::<f64>::zeros((batch, m, m));
let mut bt = Array3::<f64>::zeros((batch, p, m));
for (t, &idx) in slice.iter().enumerate() {
let s = &s_ref[idx];
let b = &sb_inputs[idx].1;
for i in 0..m {
for j in 0..m {
a[[t, i, j]] = s[[i, j]];
}
}
for i in 0..p {
for j in 0..m {
bt[[t, i, j]] = b[[j, i]];
}
}
}
// A `None` from the device kernel makes this closure return `None`.
let prod = crate::gpu::try_fast_abt_strided_batched(a.view(), bt.view())?;
let mut sink = tile_results.lock().expect("tile_results mutex poisoned");
for (t, &idx) in slice.iter().enumerate() {
sink.push((idx, prod.slice(s![t, .., ..]).to_owned()));
}
Some(())
});
match ok {
Some(()) => {
let sink = tile_results
.into_inner()
.expect("tile_results mutex poisoned");
for (idx, mat) in sink {
out[idx] = Some(mat);
}
// Fill in any member the device path did not report.
for &idx in &members {
if out[idx].is_none() {
out[idx] = Some(cpu_one(idx));
}
}
}
None => {
// The scatter reported failure: recompute the whole group on the CPU.
for &idx in &members {
out[idx] = Some(cpu_one(idx));
}
}
}
}
out.into_iter()
.enumerate()
.map(|(idx, slot)| slot.unwrap_or_else(|| cpu_one(idx)))
.collect()
}
/// Point at which the curvature walk stopped short of η = 1.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CurvatureBifurcation {
    /// Homotopy parameter at which the walk stopped.
    pub eta: f64,
    /// Minimum factor pivot recorded with the stop; the walk stores `0.0` here
    /// when the corrector solve itself failed.
    pub min_pivot: f64,
}
/// Summary of one curvature homotopy walk.
#[derive(Debug, Clone)]
pub struct CurvatureWalkReport {
/// True when η reached 1 and no bifurcation was recorded.
pub arrived: bool,
/// Squared residual norm of the linear-span anchor computed before the walk.
pub anchor_residual_norm_sq: f64,
/// Where and why the walk stopped early, if it did.
pub bifurcation: Option<CurvatureBifurcation>,
/// Number of accepted η steps.
pub eta_steps: usize,
/// Number of times the η step was halved.
pub step_halvings: usize,
/// Number of collapse events on the term when the report was written.
pub collapse_events: usize,
/// Reseed count; the walk itself records 0.
pub reseeds: usize,
}
/// Per-atom result of `linear_span_anchor`.
#[derive(Debug, Clone)]
pub struct LinearSpanAtomAnchor {
/// Neutral gate weight used for this atom.
pub gate_weight: f64,
/// Frame built from the leading right singular vectors of the gate-weighted
/// residual.
pub frame: GrassmannFrame,
/// Residual projected onto the frame, divided by the gate weight.
pub decoder_coordinates: Array2<f64>,
/// Leading singular values kept for this atom.
pub singular_values: Array1<f64>,
}
/// Result of `linear_span_anchor` across all atoms.
#[derive(Debug, Clone)]
pub struct LinearSpanAnchor {
/// One anchor per atom, in the term's atom order.
pub atoms: Vec<LinearSpanAtomAnchor>,
/// Sum of the per-atom contributions.
pub reconstruction: Array2<f64>,
/// Sum of squares of the residual left after all atoms were peeled off.
pub residual_norm_sq: f64,
}
/// Gate weights used when no atom is preferred.
///
/// Softmax yields the uniform value `1 / max(k_atoms, 1)` per atom, IBP-MAP
/// evaluates `ibp_map_row` on all-zero logits, and JumpReLU yields `0.5` per atom.
fn neutral_gate_weights(mode: AssignmentMode, k_atoms: usize) -> Array1<f64> {
    match mode {
        AssignmentMode::Softmax { .. } => {
            let uniform = 1.0 / (k_atoms.max(1) as f64);
            Array1::from_elem(k_atoms, uniform)
        }
        AssignmentMode::IBPMap {
            temperature, alpha, ..
        } => {
            let zero_logits = Array1::<f64>::zeros(k_atoms);
            ibp_map_row(zero_logits.view(), temperature, alpha)
        }
        AssignmentMode::JumpReLU { .. } => Array1::from_elem(k_atoms, 0.5),
    }
}
/// Greedy linear-span anchor: peels one low-rank linear component per atom off
/// the targets.
///
/// For each atom, in order, the current residual is scaled by the atom's
/// neutral gate and decomposed by SVD. The leading right singular vectors (at
/// most `min(basis_size, n, p)`) form the atom's frame. The residual's
/// projection onto that frame is added to the reconstruction and removed from
/// the residual.
///
/// # Errors
/// Fails on a target shape mismatch, a term without atoms, non-finite targets,
/// a non-positive or non-finite gate, an atom with zero recoverable rank, or an
/// SVD failure.
pub fn linear_span_anchor(
    term: &SaeManifoldTerm,
    targets: ArrayView2<'_, f64>,
) -> Result<LinearSpanAnchor, String> {
    let n = term.n_obs();
    let p = term.output_dim();
    if targets.dim() != (n, p) {
        return Err(format!(
            "linear_span_anchor: targets shape {:?} != ({n}, {p})",
            targets.dim()
        ));
    }
    let k_atoms = term.k_atoms();
    if k_atoms == 0 {
        return Err("linear_span_anchor: term must contain at least one atom".into());
    }
    if targets.iter().any(|v| !v.is_finite()) {
        return Err("linear_span_anchor: targets must be finite".into());
    }

    let gates = neutral_gate_weights(term.assignment.mode, k_atoms);
    let mut residual = targets.to_owned();
    let mut reconstruction = Array2::<f64>::zeros((n, p));
    let mut atoms = Vec::with_capacity(k_atoms);

    for (atom_idx, atom) in term.atoms.iter().enumerate() {
        let gate = gates[atom_idx];
        let gate_usable = gate.is_finite() && gate > 0.0;
        if !gate_usable {
            return Err(format!(
                "linear_span_anchor: neutral gate for atom {atom_idx} must be positive finite; got {gate}"
            ));
        }
        let requested_rank = atom.basis_size().min(n).min(p);
        if requested_rank == 0 {
            return Err(format!(
                "linear_span_anchor: atom {atom_idx} has no recoverable linear span rank"
            ));
        }

        // Only the right factor of the SVD is requested.
        let weighted = residual.mapv(|v| gate * v);
        let (_u_opt, singular_values_full, vt_opt) = match weighted.svd(false, true) {
            Ok(factors) => factors,
            Err(err) => {
                return Err(format!(
                    "linear_span_anchor: SVD failed for atom {atom_idx}: {err}"
                ));
            }
        };
        let Some(vt) = vt_opt else {
            return Err(format!(
                "linear_span_anchor: SVD returned no right factor for atom {atom_idx}"
            ));
        };
        let rank = requested_rank
            .min(vt.nrows())
            .min(singular_values_full.len());
        if rank == 0 {
            return Err(format!(
                "linear_span_anchor: atom {atom_idx} SVD returned rank zero"
            ));
        }

        // Columns of the frame are the leading rows of Vᵀ.
        let frame_columns = Array2::<f64>::from_shape_fn((p, rank), |(row, col)| vt[[col, row]]);
        let singular_values = singular_values_full.slice(s![..rank]).to_owned();
        let frame = GrassmannFrame::from_oriented(frame_columns, singular_values.clone());
        let frame_matrix = frame.frame().to_owned();

        let mut coordinates = residual.dot(&frame_matrix);
        coordinates.mapv_inplace(|v| v / gate);
        let contribution = fast_abt(&coordinates, &frame_matrix).mapv(|v| gate * v);
        reconstruction += &contribution;
        residual -= &contribution;

        atoms.push(LinearSpanAtomAnchor {
            gate_weight: gate,
            frame,
            decoder_coordinates: coordinates,
            singular_values,
        });
    }

    let residual_norm_sq = residual.iter().map(|v| v * v).sum();
    Ok(LinearSpanAnchor {
        atoms,
        reconstruction,
        residual_norm_sq,
    })
}
/// Solves `H x = -g` with an in-place dense Cholesky factorization `H = L Lᵀ`.
///
/// Only the lower triangle of `h` is read.
///
/// # Errors
///
/// Returns `Err` when `h` is not square or does not match `g` in size, when a
/// diagonal pivot is non-finite or not strictly positive, or when the computed
/// solution contains a non-finite entry.
fn sae_cholesky_solve_neg_gradient(
    h: ArrayView2<'_, f64>,
    g: ArrayView1<'_, f64>,
) -> Result<Array1<f64>, String> {
    let dim = h.nrows();
    if h.ncols() != dim || g.len() != dim {
        return Err(format!(
            "sae_cholesky_solve_neg_gradient: shape mismatch H={:?}, g={}",
            h.dim(),
            g.len()
        ));
    }

    // Factorization: build L row by row.
    let mut chol = Array2::<f64>::zeros((dim, dim));
    for row in 0..dim {
        for col in 0..=row {
            let reduced = (0..col).fold(h[[row, col]], |acc, k| {
                acc - chol[[row, k]] * chol[[col, k]]
            });
            if row != col {
                chol[[row, col]] = reduced / chol[[col, col]];
                continue;
            }
            if !reduced.is_finite() || reduced <= 0.0 {
                return Err(format!(
                    "non-positive Cholesky pivot at {i}: {sum}",
                    i = row,
                    sum = reduced
                ));
            }
            chol[[row, col]] = reduced.sqrt();
        }
    }

    // Forward substitution: L y = -g.
    let mut forward = Array1::<f64>::zeros(dim);
    for row in 0..dim {
        let rhs = (0..row).fold(-g[row], |acc, k| acc - chol[[row, k]] * forward[k]);
        forward[row] = rhs / chol[[row, row]];
    }

    // Back substitution: Lᵀ x = y.
    let mut solution = Array1::<f64>::zeros(dim);
    for row in (0..dim).rev() {
        let rhs = (row + 1..dim).fold(forward[row], |acc, k| acc - chol[[k, row]] * solution[k]);
        solution[row] = rhs / chol[[row, row]];
    }

    if solution.iter().any(|v| !v.is_finite()) {
        return Err("sae_cholesky_solve_neg_gradient: non-finite solution".into());
    }
    Ok(solution)
}
fn softmax_row(logits: ArrayView1<'_, f64>, temperature: f64) -> Array1<f64> {
let k = logits.len();
let inv_tau = 1.0 / temperature;
let mut max_logit = f64::NEG_INFINITY;
for &v in logits.iter() {
max_logit = max_logit.max(v);
}
let mut out = Array1::<f64>::zeros(k);
let mut sum = 0.0;
for i in 0..k {
let v = ((logits[i] - max_logit) * inv_tau).exp();
out[i] = v;
sum += v;
}
assert!(sum.is_finite() && sum > 0.0);
for v in out.iter_mut() {
*v /= sum;
}
out
}
/// Checks that every logit in `logits` is finite.
///
/// `row` is only used to label the error message.
///
/// # Errors
///
/// Returns `Err` naming the first non-finite entry (row, atom index, value).
fn validate_finite_logits(logits: ArrayView1<'_, f64>, row: usize) -> Result<(), String> {
    match logits.iter().enumerate().find(|(_, v)| !v.is_finite()) {
        Some((col, &v)) => Err(format!(
            "SaeAssignment: non-finite assignment logit at row {row}, atom {col}: {v}"
        )),
        None => Ok(()),
    }
}
/// Least-squares solve with `new_phi` as the design matrix and `old_phi` as
/// the right-hand side; delegates directly to `solve_design_least_squares`.
///
/// # Errors
///
/// Propagates any error from `solve_design_least_squares`.
fn solve_basis_transport(
new_phi: ArrayView2<'_, f64>,
old_phi: ArrayView2<'_, f64>,
) -> Result<Array2<f64>, String> {
solve_design_least_squares(new_phi, old_phi)
}
/// Minimum-norm least-squares solution of `design · X ≈ rhs` via a truncated
/// SVD pseudo-inverse.
///
/// Singular values at or below `σ_max · ε · max(rows, cols)` are treated as
/// zero and their directions are dropped.
///
/// # Errors
///
/// Returns `Err` when the row counts disagree, the SVD fails or omits a
/// factor, or the largest singular value is not finite and positive.
fn solve_design_least_squares(
    design: ArrayView2<'_, f64>,
    rhs: ArrayView2<'_, f64>,
) -> Result<Array2<f64>, String> {
    if design.nrows() != rhs.nrows() {
        return Err(format!(
            "solve_design_least_squares: row mismatch design={} rhs={}",
            design.nrows(),
            rhs.nrows()
        ));
    }
    let (u_opt, sigma, vt_opt) = design
        .to_owned()
        .svd(true, true)
        .map_err(|err| format!("solve_design_least_squares: SVD failed: {err}"))?;
    let Some(u) = u_opt else {
        return Err("solve_design_least_squares: SVD omitted U".to_string());
    };
    let Some(vt) = vt_opt else {
        return Err("solve_design_least_squares: SVD omitted Vt".to_string());
    };

    let smax = sigma.iter().copied().fold(0.0_f64, f64::max);
    if !smax.is_finite() || smax <= 0.0 {
        return Err("solve_design_least_squares: design has zero numerical rank".to_string());
    }
    let cutoff = smax * f64::EPSILON * (design.nrows().max(design.ncols()) as f64);

    // Σ⁺ Uᵀ rhs, leaving rows for truncated singular values at zero.
    let coeffs = u.t().dot(&rhs);
    let mut scaled = Array2::<f64>::zeros(coeffs.dim());
    for (row, &singular) in sigma.iter().enumerate() {
        // Negated `>` keeps NaN singular values on the "dropped" side.
        if !(singular > cutoff) {
            continue;
        }
        let reciprocal = 1.0 / singular;
        for col in 0..rhs.ncols() {
            scaled[[row, col]] = reciprocal * coeffs[[row, col]];
        }
    }
    Ok(vt.t().dot(&scaled))
}
/// Shifts every row so that its last logit is exactly zero.
///
/// With a single column the whole matrix is zeroed; with no columns nothing
/// happens.
fn canonicalize_softmax_logits(logits: &mut Array2<f64>) {
    match logits.ncols() {
        0 => {}
        1 => logits.fill(0.0),
        k => {
            let last = k - 1;
            for mut row in logits.outer_iter_mut() {
                let reference = row[last];
                for col in 0..last {
                    row[col] -= reference;
                }
                // Written directly so the reference entry is exactly zero even
                // if `reference` itself is not finite.
                row[last] = 0.0;
            }
        }
    }
}
fn ibp_stick_breaking_prior(k_atoms: usize, alpha: f64) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(k_atoms);
let log_ratio = (alpha / (alpha + 1.0)).ln();
for k in 0..k_atoms {
let log_pi = (k as f64) * log_ratio;
out[k] = log_pi.exp().max(f64::MIN_POSITIVE);
}
out
}
/// IBP-MAP gate values for one row: `logistic(logit / temperature)` scaled by
/// the matching `ibp_stick_breaking_prior` weight.
fn ibp_map_row(logits: ArrayView1<'_, f64>, temperature: f64, alpha: f64) -> Array1<f64> {
    let prior = ibp_stick_breaking_prior(logits.len(), alpha);
    logits
        .iter()
        .zip(prior.iter())
        .map(|(&logit, &pi)| crate::linalg::utils::stable_logistic(logit / temperature) * pi)
        .collect()
}
/// IBP-MAP gate values for one row together with their derivatives with
/// respect to the corresponding logit.
///
/// Returns `(value, grad)` where `value[i] = σ_i · π_i` and
/// `grad[i] = σ_i (1 − σ_i) · π_i / temperature`, with
/// `σ_i = logistic(logits[i] / temperature)` and `π` from
/// `ibp_stick_breaking_prior`.
#[must_use]
pub fn ibp_map_row_value_grad(
    logits: ArrayView1<'_, f64>,
    temperature: f64,
    alpha: f64,
) -> (Array1<f64>, Array1<f64>) {
    let count = logits.len();
    let prior = ibp_stick_breaking_prior(count, alpha);
    let inv_tau = 1.0 / temperature;
    let mut value = Array1::<f64>::zeros(count);
    let mut grad = Array1::<f64>::zeros(count);
    let inputs = logits.iter().zip(prior.iter());
    let outputs = value.iter_mut().zip(grad.iter_mut());
    for ((&logit, &pi), (value_out, grad_out)) in inputs.zip(outputs) {
        let sig = crate::linalg::utils::stable_logistic(logit * inv_tau);
        *value_out = sig * pi;
        *grad_out = sig * (1.0 - sig) * inv_tau * pi;
    }
    (value, grad)
}
/// JumpReLU gate values for one row: zero at or below `threshold`, and
/// `logistic((logit − threshold) / temperature)` strictly above it.
fn jumprelu_row(logits: ArrayView1<'_, f64>, temperature: f64, threshold: f64) -> Array1<f64> {
    logits.mapv(|logit| {
        if logit > threshold {
            crate::linalg::utils::stable_logistic((logit - threshold) / temperature)
        } else {
            0.0
        }
    })
}
/// Argument bundle for `fill_active_atom_logit_jvp`: everything needed to
/// write one atom's logit-derivative row into a compact Jacobian.
struct ActiveAtomLogitJvp<'a> {
/// Assignment mode; selects which derivative formula is used.
mode: AssignmentMode,
/// Atom index; used to look up the atom's entry in `ibp_prior`.
k: usize,
/// The atom's logit; only read in JumpReLU mode (compared against the threshold).
logit_k: f64,
/// The atom's current assignment (gate) value.
a_k: f64,
/// The atom's decoded output for this row.
decoded_k: ArrayView1<'a, f64>,
/// Fitted output for this row; its length sets how many columns are written.
fitted: ArrayView1<'a, f64>,
/// Precomputed IBP prior weights; must be `Some` in IBP-MAP mode.
ibp_prior: Option<&'a [f64]>,
/// Row of the compact Jacobian that receives the result.
compact_index: usize,
}
fn fill_active_atom_logit_jvp(input: ActiveAtomLogitJvp<'_>, jac_compact: &mut Array2<f64>) {
let ActiveAtomLogitJvp {
mode,
k,
logit_k,
a_k,
decoded_k,
fitted,
ibp_prior,
compact_index,
} = input;
let p = fitted.len();
match mode {
AssignmentMode::Softmax { temperature, .. } => {
let inv_tau = 1.0 / temperature;
for out_col in 0..p {
jac_compact[[compact_index, out_col]] =
a_k * (decoded_k[out_col] - fitted[out_col]) * inv_tau;
}
}
AssignmentMode::IBPMap { temperature, .. } => {
let inv_tau = 1.0 / temperature;
let prior =
ibp_prior.expect("fill_active_atom_logit_jvp: IBPMap requires precomputed prior");
let pi_k = prior[k];
let sig = if pi_k > 0.0 { a_k / pi_k } else { 0.0 };
let dz = sig * (1.0 - sig) * inv_tau * pi_k;
for out_col in 0..p {
jac_compact[[compact_index, out_col]] = dz * decoded_k[out_col];
}
}
AssignmentMode::JumpReLU {
temperature,
threshold,
} => {
if logit_k <= threshold {
return;
}
let inv_tau = 1.0 / temperature;
let activation = crate::linalg::utils::stable_logistic((logit_k - threshold) * inv_tau);
let da = activation * (1.0 - activation) * inv_tau;
for out_col in 0..p {
jac_compact[[compact_index, out_col]] = da * decoded_k[out_col];
}
}
}
}
/// Fills `local_jac[logit_col, ..]` with the derivative of the row's fitted
/// output with respect to each assignment logit.
///
/// * Softmax: rows `0..K-1` receive `a_k · (decoded_k − fitted) / temperature`;
///   the last logit's row is not written. With zero or one atom nothing is
///   written at all.
/// * IBP-MAP: every row receives `σ(1 − σ) · π_k / temperature · decoded_k`,
///   with `σ = a_k / π_k` (or `0` when `π_k` is not positive).
/// * JumpReLU: rows whose logit exceeds the threshold receive
///   `σ(1 − σ) / temperature · decoded_k`; other rows are left untouched.
///
/// Rows that are skipped keep whatever `local_jac` already held.
///
/// # Panics
///
/// Panics in IBP-MAP mode when `ibp_prior` is `None`.
fn fill_assignment_logit_jvp_rows(
    mode: AssignmentMode,
    logits: ArrayView1<'_, f64>,
    assignments: ArrayView1<'_, f64>,
    decoded: ArrayView2<'_, f64>,
    fitted: ArrayView1<'_, f64>,
    ibp_prior: Option<&[f64]>,
    local_jac: &mut Array2<f64>,
) {
    let n_out = fitted.len();
    match mode {
        AssignmentMode::Softmax { temperature, .. } => {
            // `saturating_sub` covers both K == 1 (nothing to write) and
            // K == 0, where a plain `len() - 1` would underflow `usize`.
            let free_logits = assignments.len().saturating_sub(1);
            if free_logits == 0 {
                return;
            }
            let inv_tau = 1.0 / temperature;
            for logit_col in 0..free_logits {
                for out_col in 0..n_out {
                    local_jac[[logit_col, out_col]] = assignments[logit_col]
                        * (decoded[[logit_col, out_col]] - fitted[out_col])
                        * inv_tau;
                }
            }
        }
        AssignmentMode::IBPMap { temperature, .. } => {
            let inv_tau = 1.0 / temperature;
            let prior = ibp_prior
                .expect("fill_assignment_logit_jvp_rows: IBPMap requires precomputed prior");
            for logit_col in 0..assignments.len() {
                let pi_k = prior[logit_col];
                let a_k = assignments[logit_col];
                // Recover the logistic value from the stored assignment.
                let sig = if pi_k > 0.0 { a_k / pi_k } else { 0.0 };
                let dz = sig * (1.0 - sig) * inv_tau * pi_k;
                for out_col in 0..n_out {
                    local_jac[[logit_col, out_col]] = dz * decoded[[logit_col, out_col]];
                }
            }
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let inv_tau = 1.0 / temperature;
            for logit_col in 0..assignments.len() {
                if logits[logit_col] <= threshold {
                    continue;
                }
                let activation = crate::linalg::utils::stable_logistic(
                    (logits[logit_col] - threshold) * inv_tau,
                );
                let da = activation * (1.0 - activation) * inv_tau;
                for out_col in 0..n_out {
                    local_jac[[logit_col, out_col]] = da * decoded[[logit_col, out_col]];
                }
            }
        }
    }
}
fn flat_logits(logits: ArrayView2<'_, f64>) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(logits.len());
for row in 0..logits.nrows() {
let start = row * logits.ncols();
for col in 0..logits.ncols() {
out[start + col] = logits[[row, col]];
}
}
out
}
/// Resolves the coordinate offset of `atom_idx` within `row`.
///
/// Without a row layout the dense offset `dense_off` is returned unchanged.
/// With a layout, the row's active-atom list is searched and the matching
/// entry of `coord_starts` is returned; `None` means the atom is not active in
/// that row.
fn sae_coord_penalty_offset(
    row_layout: Option<&SaeRowLayout>,
    dense_off: usize,
    row: usize,
    atom_idx: usize,
) -> Option<usize> {
    let Some(layout) = row_layout else {
        return Some(dense_off);
    };
    layout.active_atoms[row]
        .iter()
        .zip(layout.coord_starts[row].iter())
        .find(|&(&active_atom, _)| active_atom == atom_idx)
        .map(|(_, &coord_start)| coord_start)
}
/// Value of the assignment sparsity prior at the current logits.
///
/// * Softmax with a single atom: `0`.
/// * Softmax: `SoftmaxAssignmentSparsityPenalty` evaluated with log-strength
///   `rho.log_lambda_sparse + ln(sparsity)`.
/// * IBP-MAP: `IBPAssignmentPenalty`. With `learnable_alpha` the penalty
///   receives `rho.log_lambda_sparse` through its rho slot; otherwise its
///   `weight` is set to `rho.lambda_sparse()`.
/// * JumpReLU: `rho.lambda_sparse()` times the sum of
///   `logistic((logit − threshold) / temperature)` over logits inside the
///   optimization band.
///
/// # Panics
///
/// Panics when any assignment logit is non-finite.
fn assignment_prior_value(assignment: &SaeAssignment, rho: &SaeManifoldRho) -> f64 {
    for row in 0..assignment.n_obs() {
        validate_finite_logits(assignment.logits.row(row), row)
            .expect("assignment logits must be finite");
    }
    if matches!(assignment.mode, AssignmentMode::Softmax { .. }) && assignment.k_atoms() == 1 {
        return 0.0;
    }
    let target = flat_logits(assignment.logits.view());
    match assignment.mode {
        AssignmentMode::Softmax {
            temperature,
            sparsity,
        } => {
            let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
            penalty.value(target.view(), rho_view.view())
        }
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => {
            let mut penalty = IBPAssignmentPenalty::new(
                assignment.k_atoms(),
                alpha,
                temperature,
                learnable_alpha,
            );
            let rho_view = if learnable_alpha {
                Array1::from_vec(vec![rho.log_lambda_sparse])
            } else {
                // The gradient, Hessian-diagonal, third-channel and
                // log-strength-derivative paths all evaluate the fixed-alpha
                // penalty with `weight = rho.lambda_sparse()`; the value must
                // use the same weight or the objective and its derivatives
                // describe different functions.
                penalty.weight = rho.lambda_sparse();
                Array1::zeros(0)
            };
            penalty.value(target.view(), rho_view.view())
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let sparsity_strength = rho.lambda_sparse();
            let mut acc = 0.0;
            for &logit in target.iter() {
                if jumprelu_in_optimization_band(logit, threshold, temperature) {
                    acc += crate::linalg::utils::stable_logistic((logit - threshold) / temperature);
                }
            }
            sparsity_strength * acc
        }
    }
}
/// Derivative of the assignment prior with respect to the log sparsity
/// strength.
///
/// * Softmax with a single atom: `0`.
/// * Softmax / JumpReLU: returns `assignment_prior_value(assignment, rho)`.
/// * IBP-MAP with `learnable_alpha`: first component of the penalty's
///   `grad_rho`.
/// * IBP-MAP with fixed alpha: the penalty value evaluated with
///   `weight = rho.lambda_sparse()`.
///
/// # Panics
///
/// Panics when any assignment logit is non-finite.
fn assignment_prior_log_strength_derivative(
    assignment: &SaeAssignment,
    rho: &SaeManifoldRho,
) -> f64 {
    (0..assignment.n_obs()).for_each(|row| {
        validate_finite_logits(assignment.logits.row(row), row)
            .expect("assignment logits must be finite");
    });
    let single_softmax_atom =
        assignment.k_atoms() == 1 && matches!(assignment.mode, AssignmentMode::Softmax { .. });
    if single_softmax_atom {
        return 0.0;
    }
    match assignment.mode {
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => {
            let target = flat_logits(assignment.logits.view());
            let mut penalty = IBPAssignmentPenalty::new(
                assignment.k_atoms(),
                alpha,
                temperature,
                learnable_alpha,
            );
            if learnable_alpha {
                let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse]);
                return penalty.grad_rho(target.view(), rho_view.view())[0];
            }
            penalty.weight = rho.lambda_sparse();
            let empty_rho = Array1::<f64>::zeros(0);
            penalty.value(target.view(), empty_rho.view())
        }
        AssignmentMode::Softmax { .. } | AssignmentMode::JumpReLU { .. } => {
            assignment_prior_value(assignment, rho)
        }
    }
}
/// Per-logit Hessian-diagonal term used for the log-strength direction of the
/// assignment prior, flattened row-major like `flat_logits`.
///
/// * Softmax with a single atom, or IBP-MAP with `learnable_alpha`: all zeros.
/// * Softmax: the penalty's `hessian_diag` at log-strength
///   `rho.log_lambda_sparse + ln(sparsity)`.
/// * JumpReLU: `λ · (σ(1 − σ))² / temperature²` inside the optimization band,
///   zero outside.
/// * IBP-MAP with fixed alpha: the penalty's `hessian_diag` with
///   `weight = rho.lambda_sparse()`.
///
/// # Errors
///
/// Returns `Err` when a logit is non-finite or the penalty reports no Hessian
/// diagonal.
fn assignment_prior_log_strength_hdiag(
    assignment: &SaeAssignment,
    rho: &SaeManifoldRho,
) -> Result<Array1<f64>, String> {
    (0..assignment.n_obs())
        .try_for_each(|row| validate_finite_logits(assignment.logits.row(row), row))?;
    let target = flat_logits(assignment.logits.view());
    let single_softmax_atom =
        assignment.k_atoms() == 1 && matches!(assignment.mode, AssignmentMode::Softmax { .. });
    if single_softmax_atom {
        return Ok(Array1::<f64>::zeros(target.len()));
    }
    match assignment.mode {
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => {
            if learnable_alpha {
                return Ok(Array1::<f64>::zeros(target.len()));
            }
            let mut penalty = IBPAssignmentPenalty::new(
                assignment.k_atoms(),
                alpha,
                temperature,
                learnable_alpha,
            );
            penalty.weight = rho.lambda_sparse();
            let empty_rho = Array1::<f64>::zeros(0);
            penalty
                .hessian_diag(target.view(), empty_rho.view())
                .ok_or_else(|| "IBP assignment log-strength hessian diag unavailable".to_string())
        }
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let sparsity_strength = rho.lambda_sparse();
            let inv_tau = 1.0 / temperature;
            let inv_tau2 = inv_tau * inv_tau;
            Ok(target.mapv(|logit| {
                if !jumprelu_in_optimization_band(logit, threshold, temperature) {
                    return 0.0;
                }
                let activation =
                    crate::linalg::utils::stable_logistic((logit - threshold) * inv_tau);
                let slope = activation * (1.0 - activation);
                sparsity_strength * slope * slope * inv_tau2
            }))
        }
        AssignmentMode::Softmax {
            temperature,
            sparsity,
        } => {
            let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
            penalty
                .hessian_diag(target.view(), rho_view.view())
                .ok_or_else(|| {
                    "softmax assignment log-strength hessian diag unavailable".to_string()
                })
        }
    }
}
/// Gradient and Hessian diagonal of the assignment prior with respect to the
/// flattened (row-major) logits.
///
/// * Softmax with a single atom: both are zero.
/// * Softmax / IBP-MAP: delegated to the penalty's `grad_target` and
///   `hessian_diag` (fixed-alpha IBP uses `weight = rho.lambda_sparse()`).
/// * JumpReLU: inside the optimization band the gradient is
///   `λ · σ(1 − σ) / temperature` and the diagonal is
///   `λ · (σ(1 − σ))² / temperature²`; both are zero outside.
///
/// # Errors
///
/// Returns `Err` when a logit is non-finite or the penalty reports no Hessian
/// diagonal.
fn assignment_prior_grad_hdiag(
    assignment: &SaeAssignment,
    rho: &SaeManifoldRho,
) -> Result<(Array1<f64>, Array1<f64>), String> {
    (0..assignment.n_obs())
        .try_for_each(|row| validate_finite_logits(assignment.logits.row(row), row))?;
    let target = flat_logits(assignment.logits.view());
    let mut grad = Array1::<f64>::zeros(target.len());
    let mut diag = Array1::<f64>::zeros(target.len());
    let single_softmax_atom =
        assignment.k_atoms() == 1 && matches!(assignment.mode, AssignmentMode::Softmax { .. });
    if single_softmax_atom {
        return Ok((grad, diag));
    }
    let (sparsity_grad, sparsity_diag) = match assignment.mode {
        AssignmentMode::JumpReLU {
            temperature,
            threshold,
        } => {
            let sparsity_strength = rho.lambda_sparse();
            let inv_tau = 1.0 / temperature;
            let inv_tau2 = inv_tau * inv_tau;
            let mut g = Array1::<f64>::zeros(target.len());
            let mut d = Array1::<f64>::zeros(target.len());
            let outputs = g.iter_mut().zip(d.iter_mut());
            for (&logit, (g_out, d_out)) in target.iter().zip(outputs) {
                if !jumprelu_in_optimization_band(logit, threshold, temperature) {
                    continue;
                }
                let activation =
                    crate::linalg::utils::stable_logistic((logit - threshold) * inv_tau);
                let slope = activation * (1.0 - activation);
                *g_out = sparsity_strength * slope * inv_tau;
                *d_out = sparsity_strength * slope * slope * inv_tau2;
            }
            (g, d)
        }
        AssignmentMode::Softmax {
            temperature,
            sparsity,
        } => {
            let penalty = SoftmaxAssignmentSparsityPenalty::new(assignment.k_atoms(), temperature);
            let rho_view = Array1::from_vec(vec![rho.log_lambda_sparse + sparsity.ln()]);
            let g = penalty.grad_target(target.view(), rho_view.view());
            let d = penalty
                .hessian_diag(target.view(), rho_view.view())
                .ok_or_else(|| "softmax assignment hessian diag unavailable".to_string())?;
            (g, d)
        }
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => {
            let mut penalty = IBPAssignmentPenalty::new(
                assignment.k_atoms(),
                alpha,
                temperature,
                learnable_alpha,
            );
            let rho_view = if learnable_alpha {
                Array1::from_vec(vec![rho.log_lambda_sparse])
            } else {
                penalty.weight = rho.lambda_sparse();
                Array1::zeros(0)
            };
            let g = penalty.grad_target(target.view(), rho_view.view());
            let d = penalty
                .hessian_diag(target.view(), rho_view.view())
                .ok_or_else(|| "IBP assignment hessian diag unavailable".to_string())?;
            (g, d)
        }
    };
    // Accumulated into the zero buffers rather than returned directly.
    grad += &sparsity_grad;
    diag += &sparsity_diag;
    Ok((grad, diag))
}
/// Third-derivative channels of the IBP assignment penalty's Hessian diagonal
/// with respect to the logits.
///
/// Returns `Ok(None)` for every mode other than IBP-MAP. In IBP-MAP mode the
/// penalty is configured exactly as in the gradient path: `learnable_alpha`
/// passes `rho.log_lambda_sparse` through the rho slot, fixed alpha sets
/// `weight = rho.lambda_sparse()`.
///
/// # Errors
///
/// Returns `Err` when a logit is non-finite.
fn ibp_assignment_third_channels(
    assignment: &SaeAssignment,
    rho: &SaeManifoldRho,
) -> Result<Option<IbpHessianDiagThirdChannels>, String> {
    match assignment.mode {
        AssignmentMode::IBPMap {
            temperature,
            alpha,
            learnable_alpha,
        } => {
            (0..assignment.n_obs())
                .try_for_each(|row| validate_finite_logits(assignment.logits.row(row), row))?;
            let target = flat_logits(assignment.logits.view());
            let mut penalty = IBPAssignmentPenalty::new(
                assignment.k_atoms(),
                alpha,
                temperature,
                learnable_alpha,
            );
            let rho_view = if learnable_alpha {
                Array1::from_vec(vec![rho.log_lambda_sparse])
            } else {
                penalty.weight = rho.lambda_sparse();
                Array1::zeros(0)
            };
            let channels =
                penalty.hessian_diag_logit_third_channels(target.view(), rho_view.view());
            Ok(Some(channels))
        }
        AssignmentMode::Softmax { .. } | AssignmentMode::JumpReLU { .. } => Ok(None),
    }
}
/// Returns `true` when `penalty` is one of the analytic penalty kinds listed
/// below, i.e. the kinds treated as row-block supported; every other variant
/// yields `false`.
///
/// The variant list has the same eleven entries, in the same order, as the
/// string names returned by `sae_row_block_penalty_kinds`.
fn sae_penalty_is_row_block_supported(penalty: &AnalyticPenaltyKind) -> bool {
matches!(
penalty,
AnalyticPenaltyKind::Ard(_)
| AnalyticPenaltyKind::TopKActivation(_)
| AnalyticPenaltyKind::JumpReLU(_)
| AnalyticPenaltyKind::Sparsity(_)
| AnalyticPenaltyKind::SoftmaxAssignmentSparsity(_)
| AnalyticPenaltyKind::IBPAssignment(_)
| AnalyticPenaltyKind::RowPrecisionPrior(_)
| AnalyticPenaltyKind::ParametricRowPrecisionPrior(_)
| AnalyticPenaltyKind::ScadMcp(_)
| AnalyticPenaltyKind::BlockOrthogonality(_)
| AnalyticPenaltyKind::Isometry(_)
)
}
/// Returns `true` only for the `ScadMcp` penalty kind.
///
/// NOTE(review): the name suggests this marks penalties that act on coordinate
/// magnitude measured from the origin — confirm at the call sites.
fn sae_coord_penalty_is_origin_anchored_magnitude(penalty: &AnalyticPenaltyKind) -> bool {
matches!(penalty, AnalyticPenaltyKind::ScadMcp(_))
}
/// Restricts a coordinate block to its Euclidean axes, i.e. the axes whose
/// effective period is `None`.
///
/// Returns `None` when every axis is Euclidean (no restriction needed).
/// Otherwise returns the Euclidean axis indices together with the coordinates
/// compacted to those axes, flattened row-major with stride equal to the
/// number of Euclidean axes.
fn sae_coord_penalty_euclidean_restriction(
    coord: &LatentCoordValues,
) -> Option<(Vec<usize>, Array1<f64>)> {
    let periods = coord.effective_axis_periods();
    let d = periods.len();
    let euclidean_axes: Vec<usize> = (0..d).filter(|&axis| periods[axis].is_none()).collect();
    if euclidean_axes.len() == d {
        return None;
    }
    let n = coord.n_obs();
    let flat = coord.as_flat();
    let mut values = Vec::with_capacity(n * euclidean_axes.len());
    for row in 0..n {
        let base = row * d;
        values.extend(euclidean_axes.iter().map(|&axis| flat[base + axis]));
    }
    let compacted = Array1::from_vec(values);
    Some((euclidean_axes, compacted))
}
/// String names of the penalty kinds that are row-block supported, in a fixed
/// order.
pub fn sae_row_block_penalty_kinds() -> &'static [&'static str] {
    static KINDS: [&str; 11] = [
        "ard",
        "top_k_activation",
        "jumprelu",
        "sparsity",
        "softmax_assignment_sparsity",
        "ibp_assignment",
        "row_precision_prior",
        "parametric_row_precision_prior",
        "scad_mcp",
        "block_orthogonality",
        "isometry",
    ];
    &KINDS
}
/// Builds an [`SaeManifoldTerm`] from padded per-atom blocks.
///
/// Atom `k` uses the leading `(n_obs, basis_sizes[k])` slice of
/// `basis_values[k]`, the leading `(n_obs, basis_sizes[k], latent_dims[k])`
/// slice of `basis_jacobian[k]`, the leading `(basis_sizes[k], p_out)` slice
/// of `decoder_coefficients[k]`, and the leading square
/// `(basis_sizes[k], basis_sizes[k])` slice of `smooth_penalties[k]`.
/// Anything beyond those extents is treated as padding and ignored.
///
/// `evaluators` is either empty or holds one optional basis evaluator per
/// atom; a `Some` entry is attached with `with_basis_evaluator`.
///
/// # Errors
///
/// Returns `Err` when the per-atom metadata lengths disagree, `evaluators` has
/// a length other than `0` or `K`, `logits` is not `(n_obs, K)`, a padded
/// array is too small for the extents requested for some atom, or atom /
/// assignment / term construction fails.
#[must_use = "build error must be handled"]
pub fn term_from_padded_blocks_with_mode(
    n_obs: usize,
    p_out: usize,
    basis_kinds: &[SaeAtomBasisKind],
    basis_values: ArrayView3<'_, f64>,
    basis_jacobian: ArrayView4<'_, f64>,
    basis_sizes: &[usize],
    latent_dims: &[usize],
    decoder_coefficients: ArrayView3<'_, f64>,
    smooth_penalties: ArrayView3<'_, f64>,
    logits: ArrayView2<'_, f64>,
    coords: &[Array2<f64>],
    mode: AssignmentMode,
    evaluators: &[Option<Arc<dyn SaeBasisEvaluator>>],
) -> Result<SaeManifoldTerm, String> {
    let k_atoms = basis_sizes.len();
    if latent_dims.len() != k_atoms || basis_kinds.len() != k_atoms || coords.len() != k_atoms {
        return Err("term_from_padded_blocks: K-length metadata mismatch".into());
    }
    if !evaluators.is_empty() && evaluators.len() != k_atoms {
        return Err(format!(
            "term_from_padded_blocks: evaluators length {} must equal K={k_atoms} or be empty",
            evaluators.len()
        ));
    }
    if logits.dim() != (n_obs, k_atoms) {
        return Err(format!(
            "term_from_padded_blocks: logits must be ({n_obs}, {k_atoms}); got {:?}",
            logits.dim()
        ));
    }

    let (bv_k, bv_n, bv_m) = basis_values.dim();
    let (bj_k, bj_n, bj_m, bj_d) = basis_jacobian.dim();
    let (dc_k, dc_m, dc_p) = decoder_coefficients.dim();
    let (sp_k, sp_rows, sp_cols) = smooth_penalties.dim();

    let mut atoms = Vec::with_capacity(k_atoms);
    for k in 0..k_atoms {
        let m = basis_sizes[k];
        let d = latent_dims[k];

        // Slicing past an array's extent panics; report it as a build error
        // instead, since this function already returns `Result`.
        let fits = k < bv_k
            && n_obs <= bv_n
            && m <= bv_m
            && k < bj_k
            && n_obs <= bj_n
            && m <= bj_m
            && d <= bj_d
            && k < dc_k
            && m <= dc_m
            && p_out <= dc_p
            && k < sp_k
            && m <= sp_rows
            && m <= sp_cols;
        if !fits {
            return Err(format!(
                "term_from_padded_blocks: atom {k} (basis size {m}, latent dim {d}) with \
                 n_obs={n_obs}, p_out={p_out} does not fit the padded blocks: \
                 basis_values={:?}, basis_jacobian={:?}, decoder_coefficients={:?}, \
                 smooth_penalties={:?}",
                basis_values.dim(),
                basis_jacobian.dim(),
                decoder_coefficients.dim(),
                smooth_penalties.dim()
            ));
        }

        let phi = basis_values.slice(s![k, 0..n_obs, 0..m]).to_owned();
        let jet = basis_jacobian.slice(s![k, 0..n_obs, 0..m, 0..d]).to_owned();
        let b = decoder_coefficients.slice(s![k, 0..m, 0..p_out]).to_owned();
        let smooth = smooth_penalties.slice(s![k, 0..m, 0..m]).to_owned();
        let atom = SaeManifoldAtom::new(
            format!("atom_{k}"),
            basis_kinds[k].clone(),
            d,
            phi,
            jet,
            b,
            smooth,
        )?;
        // An empty `evaluators` slice makes `get` return `None` for every atom.
        let atom = match evaluators.get(k).and_then(|slot| slot.clone()) {
            Some(evaluator) => atom.with_basis_evaluator(evaluator),
            None => atom,
        };
        atoms.push(atom);
    }

    let manifolds = basis_kinds
        .iter()
        .zip(latent_dims.iter().copied())
        .map(|(kind, d)| kind.latent_manifold(d))
        .collect();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits.to_owned(),
        coords.to_vec(),
        manifolds,
        mode,
    )?;
    SaeManifoldTerm::new(atoms, assignment)
}
/// Recomputes the decoder-derivative caches of `penalty` from `atom` evaluated
/// at `coords`.
///
/// The atom's basis evaluator supplies the first jet; contracting it with the
/// decoder coefficients gives the first-derivative cache. When the atom also
/// carries a second-jet evaluator the second-derivative cache is built the
/// same way. A third-derivative cache is built only when the penalty has no
/// `duchon_radial_source` and the evaluator offers a third jet.
///
/// Returns `Ok(true)` when a second-derivative cache was installed and
/// `Ok(false)` otherwise.
///
/// # Errors
///
/// Returns `Err` when the atom has no basis evaluator, `penalty.p_out` differs
/// from the decoder's column count, any jet has an unexpected shape, or an
/// evaluator call fails.
pub fn refresh_isometry_caches_from_atom(
penalty: &IsometryPenalty,
atom: &SaeManifoldAtom,
coords: ArrayView2<'_, f64>,
) -> Result<bool, String> {
let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
format!(
"refresh_isometry_caches_from_atom: atom {} has no basis evaluator",
atom.name
)
})?;
// Only the first jet is needed here; the basis values are discarded.
let (_phi, jet) = evaluator.evaluate(coords)?;
let n_obs = coords.nrows();
let d = atom.latent_dim;
let m = atom.basis_size();
let p = atom.decoder_coefficients.ncols();
if penalty.p_out != p {
return Err(format!(
"refresh_isometry_caches_from_atom: penalty.p_out={} but atom.decoder.cols={p}",
penalty.p_out
));
}
if jet.dim() != (n_obs, m, d) {
return Err(format!(
"refresh_isometry_caches_from_atom: evaluator first jet has shape {:?}, expected ({n_obs}, {m}, {d})",
jet.dim()
));
}
let b = &atom.decoder_coefficients;
// First-derivative cache: jac[n, i * d + a] = Σ_mm jet[n, mm, a] · b[mm, i].
let mut jac = Array2::<f64>::zeros((n_obs, p * d));
for n in 0..n_obs {
for i in 0..p {
for a in 0..d {
let mut acc = 0.0;
for mm in 0..m {
acc += jet[[n, mm, a]] * b[[mm, i]];
}
jac[[n, i * d + a]] = acc;
}
}
}
// Second-derivative cache (optional):
// jac2[n, (i * d + a) * d + c] = Σ_mm hess[n, mm, a, c] · b[mm, i].
let jac2_opt = if let Some(second_eval) = atom.basis_second_jet.as_ref() {
let hess = second_eval.second_jet(coords)?;
if hess.dim() != (n_obs, m, d, d) {
return Err(format!(
"refresh_isometry_caches_from_atom: evaluator second jet has shape {:?}, expected ({n_obs}, {m}, {d}, {d})",
hess.dim()
));
}
let mut jac2 = Array2::<f64>::zeros((n_obs, p * d * d));
for n in 0..n_obs {
for i in 0..p {
for a in 0..d {
for c in 0..d {
let mut acc = 0.0;
for mm in 0..m {
acc += hess[[n, mm, a, c]] * b[[mm, i]];
}
jac2[[n, (i * d + a) * d + c]] = acc;
}
}
}
}
Some(Arc::new(jac2))
} else {
None
};
// Third-derivative cache (optional), skipped when the penalty has a Duchon
// radial source: jac3[n, i, (a * d + c) * d + e] = Σ_mm t3[n, mm, a, c, e] · b[mm, i].
let jac3_opt = if penalty.duchon_radial_source.is_none() {
match evaluator.third_jet_dyn(coords) {
Some(third) => {
let t3 = third?;
if t3.dim() != (n_obs, m, d, d, d) {
return Err(format!(
"refresh_isometry_caches_from_atom: evaluator third jet has shape {:?}, expected ({n_obs}, {m}, {d}, {d}, {d})",
t3.dim()
));
}
let mut jac3 = Array3::<f64>::zeros((n_obs, p, d * d * d));
for n in 0..n_obs {
for i in 0..p {
for a in 0..d {
for c in 0..d {
for e in 0..d {
let mut acc = 0.0;
for mm in 0..m {
acc += t3[[n, mm, a, c, e]] * b[[mm, i]];
}
jac3[[n, i, ((a * d) + c) * d + e]] = acc;
}
}
}
}
}
Some(Arc::new(jac3))
}
None => None,
}
} else {
None
};
// All caches are installed only after every computation above has succeeded.
let installed = jac2_opt.is_some();
penalty.refresh_caches(Some(Arc::new(jac)), jac2_opt);
penalty.set_third_decoder_derivative(jac3_opt);
Ok(installed)
}
/// Refreshes the caches of every isometry penalty in `registry` that can be
/// paired with an atom of `term`.
///
/// Penalties are visited in registry order. A penalty with a known target
/// latent dimension is paired with the next not-yet-consumed atom that has the
/// same `(latent_dim, p_out)` signature and a basis evaluator; penalties with
/// no such atom (or no target latent dimension) are skipped.
///
/// Returns the number of penalties for which a second-derivative cache was
/// installed.
///
/// # Errors
///
/// Returns `Err` when `coords_per_atom` does not have one entry per atom, or
/// when refreshing a paired penalty fails.
pub fn refresh_isometry_caches_from_term(
    registry: &AnalyticPenaltyRegistry,
    term: &SaeManifoldTerm,
    coords_per_atom: &[Array2<f64>],
) -> Result<usize, String> {
    if coords_per_atom.len() != term.atoms.len() {
        return Err(format!(
            "refresh_isometry_caches_from_term: coords_per_atom length {} != number of atoms {}",
            coords_per_atom.len(),
            term.atoms.len()
        ));
    }
    let mut refreshed_with_second = 0usize;
    // How many atoms of each (latent_dim, p_out) signature are already paired.
    let mut consumed_per_signature: std::collections::HashMap<(usize, usize), usize> =
        std::collections::HashMap::new();
    for entry in registry.penalties.iter() {
        let (p, p_latent_dim) = match entry {
            AnalyticPenaltyKind::Isometry(p) => match p.target.latent_dim {
                Some(latent_dim) => (p, latent_dim),
                None => continue,
            },
            _ => continue,
        };
        let already_consumed = consumed_per_signature
            .entry((p_latent_dim, p.p_out))
            .or_insert(0);
        // The atom to pair is the `already_consumed`-th matching atom (0-based).
        let paired = term
            .atoms
            .iter()
            .enumerate()
            .filter(|(_, atom)| {
                atom.latent_dim == p_latent_dim
                    && atom.decoder_coefficients.ncols() == p.p_out
                    && atom.basis_evaluator.is_some()
            })
            .map(|(atom_idx, _)| atom_idx)
            .nth(*already_consumed);
        let Some(atom_idx) = paired else {
            continue;
        };
        *already_consumed += 1;
        let installed_second = refresh_isometry_caches_from_atom(
            p,
            &term.atoms[atom_idx],
            coords_per_atom[atom_idx].view(),
        )?;
        if installed_second {
            refreshed_with_second += 1;
        }
    }
    Ok(refreshed_with_second)
}
/// Recovers a frame from the cross moment of `targets` and `coords` and
/// returns its largest principal angle to the `planted` frame.
///
/// # Errors
///
/// Returns `Err` when `planted` is not `(targets.ncols(), coords.ncols())`, or
/// when accumulating the cross moment, extracting its polar frame, or
/// computing the angle fails.
pub fn grassmann_recover_planted_span_angle(
    targets: ArrayView2<'_, f64>,
    coords: ArrayView2<'_, f64>,
    planted: ArrayView2<'_, f64>,
) -> Result<f64, String> {
    let (p, r) = (targets.ncols(), coords.ncols());
    if planted.dim() == (p, r) {
        let mut cross = GrassmannCrossMoment::new(p, r);
        cross.accumulate(targets, coords)?;
        let frame = cross.polar_frame()?;
        frame.max_principal_angle(planted)
    } else {
        Err(format!(
            "grassmann_recover_planted_span_angle: planted frame must be ({p}, {r}); got {:?}",
            planted.dim()
        ))
    }
}
/// Checks that `term.factored_border_dim()` equals the sum over atoms of
/// `basis_size() * border_frame_rank()`.
///
/// # Errors
///
/// Returns `Err` describing both values when they differ.
pub fn grassmann_assert_border_dim_invariant(term: &SaeManifoldTerm) -> Result<(), String> {
    let mut expected = 0usize;
    for atom in term.atoms.iter() {
        expected += atom.basis_size() * atom.border_frame_rank();
    }
    let got = term.factored_border_dim();
    if got == expected {
        Ok(())
    } else {
        Err(format!(
            "grassmann border-dim invariant violated: factored_border_dim() = {got}, \
             expected Σ M_k·r_k = {expected}"
        ))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::solver::arrow_schur::{
ArrowFactorSlab, ArrowHtbetaCache, ArrowSolverMode, ArrowUndampedFactors, PcgDiagnostics,
};
use crate::terms::analytic_penalties::ARDPenalty;
use crate::terms::analytic_penalties::IsometryReference;
use approx::assert_abs_diff_eq;
use ndarray::array;
/// Asserts that two matrices have the same shape and bit-identical entries.
fn assert_matrix_same_bits(left: &Array2<f64>, right: &Array2<f64>) {
    assert_eq!(left.dim(), right.dim());
    for (index, lhs) in left.indexed_iter() {
        let (row, col) = index;
        let rhs = right[index];
        assert_eq!(
            lhs.to_bits(),
            rhs.to_bits(),
            "matrix bits differ at ({row}, {col})"
        );
    }
}
/// Asserts that two rank-3 tensors have the same shape and bit-identical entries.
fn assert_tensor3_same_bits(left: &Array3<f64>, right: &Array3<f64>) {
    assert_eq!(left.dim(), right.dim());
    for (index, lhs) in left.indexed_iter() {
        let (row, col, axis) = index;
        let rhs = right[index];
        assert_eq!(
            lhs.to_bits(),
            rhs.to_bits(),
            "tensor bits differ at ({row}, {col}, {axis})"
        );
    }
}
/// Asserts that evaluating `evaluator` through `evaluate_phi_eta` at
/// `eta = 1.0` reproduces `evaluate` bit for bit.
///
/// Also checks that the number of curved columns equals `expected_curved`,
/// that the eta-derivatives of every linear column are exactly zero, and that
/// the eta-derivatives of every curved column equal the base values/jet bit
/// for bit.
fn assert_eta_one_parity(
    evaluator: &dyn SaeBasisEvaluator,
    coords: ArrayView2<'_, f64>,
    expected_curved: usize,
) {
    let (phi, jet) = evaluator.evaluate(coords).expect("base evaluate");
    let eta = evaluator
        .evaluate_phi_eta(coords, 1.0)
        .expect("eta evaluate");
    assert_matrix_same_bits(&eta.phi, &phi);
    assert_tensor3_same_bits(&eta.jet, &jet);
    assert_eq!(eta.split.curved_cols.len(), expected_curved);

    let n_rows = phi.nrows();
    let n_axes = jet.shape()[2];

    // Linear columns do not depend on eta.
    for &col in &eta.split.linear_cols {
        for row in 0..n_rows {
            assert_eq!(eta.dphi_deta[[row, col]], 0.0);
            for axis in 0..n_axes {
                assert_eq!(eta.djet_deta[[row, col, axis]], 0.0);
            }
        }
    }

    // Curved columns: the eta-derivative equals the base quantity at eta = 1.
    for &col in &eta.split.curved_cols {
        for row in 0..n_rows {
            let dphi_bits = eta.dphi_deta[[row, col]].to_bits();
            assert_eq!(dphi_bits, phi[[row, col]].to_bits());
            for axis in 0..n_axes {
                let djet_bits = eta.djet_deta[[row, col, axis]].to_bits();
                assert_eq!(djet_bits, jet[[row, col, axis]].to_bits());
            }
        }
    }
}
/// Runs `assert_eta_one_parity` against each basis evaluator with the number
/// of curved columns expected for it: periodic harmonic (4), raw periodic
/// circle (0), torus harmonic (20), sphere chart (3), Duchon (all columns but
/// 3), and Euclidean patch (all monomials of total degree above 1).
#[test]
fn phi_eta_one_reproduces_current_atom_bases_bit_for_bit() {
let periodic_coords = array![[0.0_f64], [0.125], [0.4]];
let periodic = PeriodicHarmonicEvaluator::new(7).unwrap();
assert_eta_one_parity(&periodic, periodic_coords.view(), 4);
let raw_circle_coords = array![[0.0_f64], [0.3], [1.1]];
let raw_circle = RawPeriodicCircleEvaluator::new(1).unwrap();
assert_eta_one_parity(&raw_circle, raw_circle_coords.view(), 0);
let torus_coords = array![[0.0_f64, 0.2], [0.25, 0.5], [0.7, 0.9]];
let torus = TorusHarmonicEvaluator::new(2, 2).unwrap();
assert_eta_one_parity(&torus, torus_coords.view(), 20);
let sphere_coords = array![[0.0_f64, 0.0], [0.3, 0.4], [-0.2, 1.1]];
let sphere = SphereChartEvaluator;
assert_eta_one_parity(&sphere, sphere_coords.view(), 3);
let centers = array![
[-1.0_f64, -1.0],
[1.0, -1.0],
[-1.0, 1.0],
[1.0, 1.0],
[0.0, 0.0],
[0.5, -0.25]
];
let duchon_coords = array![[0.1_f64, 0.2], [0.4, -0.3], [-0.2, 0.7]];
let duchon = DuchonCoordinateEvaluator::new(centers, 2).unwrap();
let (duchon_phi, _) = duchon.evaluate(duchon_coords.view()).unwrap();
// NOTE(review): 3 is presumably the count of polynomial (non-curved) Duchon
// columns for a 2-D input — confirm against the evaluator.
let duchon_poly = 3usize;
assert_eta_one_parity(
&duchon,
duchon_coords.view(),
duchon_phi.ncols() - duchon_poly,
);
let euclidean = EuclideanPatchEvaluator::new(2, 3).unwrap();
let total_cols = crate::basis::monomial_exponents(2, 3).len();
// Monomials of total degree <= 1 are counted as the linear columns.
let linear_cols = crate::basis::monomial_exponents(2, 3)
.iter()
.filter(|alpha| alpha.iter().sum::<usize>() <= 1)
.count();
assert_eta_one_parity(&euclidean, duchon_coords.view(), total_cols - linear_cols);
}
/// Two atoms of basis size 2 are anchored against a 4x4 diagonal target with
/// entries 3, 2, 1.5, 1. The test expects zero residual, the first atom's
/// frame to span output axes 0-1, and the second atom's frame to span output
/// axes 2-3.
#[test]
fn linear_span_anchor_recovers_planted_two_plane_configuration() {
let n = 4usize;
let p = 4usize;
let phi = Array2::<f64>::ones((n, 2));
let jet = Array3::<f64>::zeros((n, 2, 1));
let decoder = Array2::<f64>::zeros((2, p));
let smooth = Array2::<f64>::eye(2);
let atoms = vec![
SaeManifoldAtom::new(
"plane0",
SaeAtomBasisKind::EuclideanPatch,
1,
phi.clone(),
jet.clone(),
decoder.clone(),
smooth.clone(),
)
.unwrap(),
SaeManifoldAtom::new(
"plane1",
SaeAtomBasisKind::EuclideanPatch,
1,
phi,
jet,
decoder,
smooth,
)
.unwrap(),
];
let coords = vec![Array2::<f64>::zeros((n, 1)), Array2::<f64>::zeros((n, 1))];
let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
Array2::<f64>::zeros((n, 2)),
coords,
vec![LatentManifold::Euclidean, LatentManifold::Euclidean],
AssignmentMode::softmax(1.0),
)
.unwrap();
let term = SaeManifoldTerm::new(atoms, assignment).unwrap();
// Diagonal target: the two largest singular directions are output axes 0 and 1.
let target = array![
[3.0_f64, 0.0, 0.0, 0.0],
[0.0, 2.0, 0.0, 0.0],
[0.0, 0.0, 1.5, 0.0],
[0.0, 0.0, 0.0, 1.0]
];
let anchor = linear_span_anchor(&term, target.view()).unwrap();
assert_eq!(anchor.atoms.len(), 2);
assert_abs_diff_eq!(anchor.residual_norm_sq, 0.0, epsilon = 1.0e-18);
let plane0 = array![[1.0_f64, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]];
let plane1 = array![[0.0_f64, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
let angle0 = anchor.atoms[0]
.frame
.max_principal_angle(plane0.view())
.unwrap();
let angle1 = anchor.atoms[1]
.frame
.max_principal_angle(plane1.view())
.unwrap();
assert_abs_diff_eq!(angle0, 0.0, epsilon = 1.0e-12);
assert_abs_diff_eq!(angle1, 0.0, epsilon = 1.0e-12);
}
/// Test fixture: builds a term with one periodic atom per entry of `planes`,
/// using 16 observations and a 4-dimensional output.
///
/// For each `(axis_sin, axis_cos)` pair the atom's decoder puts `radius` at
/// `[1, axis_sin]` and `[2, axis_cos]` and zeros elsewhere. All atoms share the
/// same basis values, jet and coordinates (`row / n`). The assignment is
/// softmax with zero logits, and the certificate dispersion is set to `1.0`.
fn circle_certificate_fixture(radius: f64, planes: &[(usize, usize)]) -> SaeManifoldTerm {
let n = 16usize;
let p = 4usize;
let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
let coords = Array2::<f64>::from_shape_fn((n, 1), |(row, _)| row as f64 / n as f64);
let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
let mut atoms = Vec::with_capacity(planes.len());
let mut coord_blocks = Vec::with_capacity(planes.len());
for (atom_idx, &(axis_sin, axis_cos)) in planes.iter().enumerate() {
let mut decoder = Array2::<f64>::zeros((3, p));
decoder[[1, axis_sin]] = radius;
decoder[[2, axis_cos]] = radius;
let atom = SaeManifoldAtom::new(
format!("circle_{atom_idx}"),
SaeAtomBasisKind::Periodic,
1,
phi.clone(),
jet.clone(),
decoder,
Array2::<f64>::eye(3),
)
.unwrap()
.with_basis_second_jet(evaluator.clone());
atoms.push(atom);
coord_blocks.push(coords.clone());
}
let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
Array2::<f64>::zeros((n, planes.len())),
coord_blocks,
vec![LatentManifold::Circle { period: 1.0 }; planes.len()],
AssignmentMode::softmax(1.0),
)
.unwrap();
let mut term = SaeManifoldTerm::new(atoms, assignment).unwrap();
term.set_certificate_dispersion(1.0).unwrap();
term
}
/// Two circle atoms decoding into disjoint output planes must report
/// `mu_hat == 0`. The report's verdict must also match one recomputed from its
/// own fields, and must be certified whenever the SNR proxy exceeds 1.
#[test]
fn dictionary_incoherence_report_orthogonal_frames_has_zero_mu_hat() {
let term = circle_certificate_fixture(2.0, &[(0, 1), (2, 3)]);
let report = dictionary_incoherence_report(&term).unwrap();
assert_abs_diff_eq!(report.mu_hat, 0.0, epsilon = 1.0e-12);
assert_eq!(report.per_atom_kappa_hat.len(), 2);
let kappa_max = report
.per_atom_kappa_hat
.iter()
.copied()
.fold(0.0_f64, f64::max);
// Recompute the verdict from the reported quantities; it must agree with
// the one stored in the report.
let recomputed = curved_dictionary_global_optimality_verdict(
report.mu_hat,
kappa_max,
report.peak_activity_floor,
report.snr_proxy,
report.per_atom_kappa_hat.len(),
);
assert_eq!(report.global_optimality, recomputed);
if report.snr_proxy > 1.0 {
assert!(
report.global_optimality.is_certified(),
"μ̂=0, κ̂=0.5<1, SNR>1 ⇒ must certify; got {}",
report.note
);
}
}
/// Two circle atoms decoding into the same output plane must report
/// `mu_hat == 1`.
#[test]
fn dictionary_incoherence_report_coherent_frames_has_unit_mu_hat() {
let term = circle_certificate_fixture(2.0, &[(0, 1), (0, 1)]);
let report = dictionary_incoherence_report(&term).unwrap();
assert_abs_diff_eq!(report.mu_hat, 1.0, epsilon = 1.0e-12);
}
/// A single circle atom of radius 2.5 must report `kappa_hat == 1 / radius`, a finite
/// positive SNR proxy, and unit mean/peak activity floors (dispersion overridden to 0.25).
#[test]
fn dictionary_incoherence_report_circle_kappa_matches_inverse_radius() {
    let radius = 2.5_f64;
    let expected_kappa = 1.0 / radius;

    let mut term = circle_certificate_fixture(radius, &[(0, 1)]);
    term.set_certificate_dispersion(0.25).unwrap();
    let report = dictionary_incoherence_report(&term).unwrap();

    assert_abs_diff_eq!(
        report.per_atom_kappa_hat[0],
        expected_kappa,
        epsilon = 1.0e-10
    );
    let snr = report.snr_proxy;
    assert!(snr.is_finite() && snr > 0.0);
    assert_abs_diff_eq!(report.mean_activity_floor, 1.0, epsilon = 1.0e-12);
    assert_abs_diff_eq!(report.peak_activity_floor, 1.0, epsilon = 1.0e-12);
}
/// `SearchStrategy::Fixed` reports itself as fixed; an `ExponentialSweep` does not and
/// exposes its values through `sweep_values()`.
#[test]
fn search_strategy_exposes_fixed_and_sweep_values() {
    assert!(SearchStrategy::Fixed.is_fixed());

    let sweep = SearchStrategy::ExponentialSweep {
        values: vec![0.1, 1.0, 10.0],
    };
    assert!(!sweep.is_fixed());

    let expected = [0.1, 1.0, 10.0];
    assert_eq!(sweep.sweep_values(), Some(expected.as_slice()));
}
/// With a single atom, the IBP-MAP and JumpReLU modes yield gate values of 0.5 and 0.0
/// for the logits used here, whereas softmax yields exactly 1.0.
#[test]
fn k1_gate_modes_do_not_pin_assignment_to_one() {
    // IBP-MAP at logit 0.
    let ibp_assignment = SaeAssignment::from_blocks_with_mode(
        array![[0.0]],
        vec![array![[0.0]]],
        AssignmentMode::ibp_map(1.0, 1.0, false),
    )
    .unwrap();
    let ibp_gate = ibp_assignment.try_assignments_row(0).unwrap()[0];
    assert_abs_diff_eq!(ibp_gate, 0.5, epsilon = 1e-9);

    // JumpReLU at logit -1.
    let jumprelu_assignment = SaeAssignment::from_blocks_with_mode(
        array![[-1.0]],
        vec![array![[0.0]]],
        AssignmentMode::jumprelu(1.0, 0.0),
    )
    .unwrap();
    let jumprelu_gate = jumprelu_assignment.try_assignments_row(0).unwrap()[0];
    assert_abs_diff_eq!(jumprelu_gate, 0.0, epsilon = 1e-12);

    // Softmax over a single atom.
    let softmax_assignment = SaeAssignment::from_blocks_with_mode(
        Array2::<f64>::zeros((1, 1)),
        vec![array![[0.0]]],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let softmax_gate = softmax_assignment.try_assignments_row(0).unwrap()[0];
    assert_abs_diff_eq!(softmax_gate, 1.0, epsilon = 1e-12);
}
/// A logit just above the JumpReLU threshold must give a gate near 0.5, and a logit
/// below the threshold must give a gate of 0.
#[test]
fn jumprelu_surrogate_is_centered_at_threshold() {
    let (threshold, temperature) = (2.0, 1.0);
    let logits = array![2.0 + 1e-6, 1.0];

    let gates = jumprelu_row(logits.view(), temperature, threshold);
    let just_above = gates[0];
    let below = gates[1];

    assert_abs_diff_eq!(just_above, 0.5, epsilon = 1e-3);
    assert!(
        just_above < 0.6,
        "surrogate not centered at threshold: {}",
        just_above
    );
    assert_abs_diff_eq!(below, 0.0, epsilon = 1e-12);
}
/// Evaluates the three-function basis `[1, sin(2πx), cos(2πx)]` and its first
/// derivative at column 0 of `coords`, after reducing each coordinate modulo 1.
///
/// Returns `(phi, jet)` with shapes `(n, 3)` and `(n, 3, 1)`; the derivative of the
/// constant basis function is left at zero.
fn periodic_basis(coords: &Array2<f64>) -> (Array2<f64>, Array3<f64>) {
    let two_pi = 2.0 * std::f64::consts::PI;
    let n_rows = coords.nrows();
    let mut values = Array2::<f64>::zeros((n_rows, 3));
    let mut derivs = Array3::<f64>::zeros((n_rows, 3, 1));
    for i in 0..n_rows {
        let wrapped = coords[[i, 0]].rem_euclid(1.0);
        let theta = two_pi * wrapped;
        let (sin_theta, cos_theta) = (theta.sin(), theta.cos());

        values[[i, 0]] = 1.0;
        values[[i, 1]] = sin_theta;
        values[[i, 2]] = cos_theta;

        derivs[[i, 1, 0]] = two_pi * cos_theta;
        derivs[[i, 2, 0]] = -two_pi * sin_theta;
    }
    (values, derivs)
}
/// Checks `ArdAxisPrior::eval` on a periodic axis: value, gradient and Hessian agree on
/// both sides of `t = period`, the gradient vanishes there, the evaluation at 0 is
/// `(0, 0, alpha)`, `sq_equiv` does not depend on `alpha`, and
/// `value == 0.5 * alpha * sq_equiv`.
#[test]
fn ard_axis_prior_periodic_is_continuous_across_cut() {
    let alpha = 2.3_f64;
    let period = 1.0_f64;
    let eps = 1.0e-6;

    let just_below = ArdAxisPrior::eval(alpha, period - eps, Some(period));
    let just_above = ArdAxisPrior::eval(alpha, period + eps, Some(period));
    let origin = ArdAxisPrior::eval(alpha, 0.0, Some(period));

    // Continuity across the cut, to a tolerance proportional to the offset.
    let cont_tol = 10.0 * alpha * eps;
    assert!((just_below.value - just_above.value).abs() < cont_tol);
    assert!((just_below.grad - just_above.grad).abs() < cont_tol);
    assert!((just_below.hess - just_above.hess).abs() < cont_tol);
    assert!(just_below.grad.abs() < cont_tol);
    assert!(just_above.grad.abs() < cont_tol);

    // Behaviour at the origin.
    assert_abs_diff_eq!(just_below.value, origin.value, epsilon = 1.0e-9);
    assert_abs_diff_eq!(origin.value, 0.0, epsilon = 1.0e-12);
    assert_abs_diff_eq!(origin.grad, 0.0, epsilon = 1.0e-12);
    assert_abs_diff_eq!(origin.hess, alpha, epsilon = 1.0e-12);

    // `sq_equiv` is independent of alpha.
    let sq_equiv_alpha_one = ArdAxisPrior::eval(1.0, 0.3, Some(period)).sq_equiv;
    let sq_equiv_alpha_five = ArdAxisPrior::eval(5.0, 0.3, Some(period)).sq_equiv;
    assert_abs_diff_eq!(sq_equiv_alpha_one, sq_equiv_alpha_five, epsilon = 1.0e-12);

    // The value is the quadratic form in `sq_equiv`.
    let interior = ArdAxisPrior::eval(alpha, 0.3, Some(period));
    assert_abs_diff_eq!(
        0.5 * alpha * interior.sq_equiv,
        interior.value,
        epsilon = 1.0e-12
    );
}
/// Compares the analytic `grad` and `hess` of `ArdAxisPrior::eval` against central
/// finite differences of `value` and `grad`, for a non-periodic axis and two periods.
#[test]
fn ard_axis_prior_value_grad_fd_consistent() {
    let alpha = 1.7_f64;
    let h = 1.0e-6;
    let periods = [None, Some(1.0_f64), Some(std::f64::consts::TAU)];
    let sample_points = [-0.37_f64, 0.02, 0.49, 0.83, 0.999, 1.4];

    for &period in &periods {
        let eval_at = |x: f64| ArdAxisPrior::eval(alpha, x, period);
        for &t in &sample_points {
            let center = eval_at(t);

            // First derivative from the value.
            let value_plus = eval_at(t + h).value;
            let value_minus = eval_at(t - h).value;
            let fd_grad = (value_plus - value_minus) / (2.0 * h);
            assert_abs_diff_eq!(center.grad, fd_grad, epsilon = 1.0e-5);

            // Second derivative from the gradient.
            let grad_plus = eval_at(t + h).grad;
            let grad_minus = eval_at(t - h).grad;
            let fd_hess = (grad_plus - grad_minus) / (2.0 * h);
            assert_abs_diff_eq!(center.hess, fd_hess, epsilon = 1.0e-5);
        }
    }
}
/// Pins `LatentManifold::axis_periods()` for Euclidean, circle, torus (product of
/// circles), an interval-times-circle product, and a 3-dimensional sphere.
#[test]
fn axis_periods_map_each_topology() {
    let tau = std::f64::consts::TAU;

    assert_eq!(LatentManifold::Euclidean.axis_periods(), vec![None]);

    let unit_circle = LatentManifold::Circle { period: 1.0 };
    assert_eq!(unit_circle.axis_periods(), vec![Some(1.0)]);

    let torus = LatentManifold::Product(vec![
        LatentManifold::Circle { period: 1.0 },
        LatentManifold::Circle { period: 1.0 },
    ]);
    assert_eq!(torus.axis_periods(), vec![Some(1.0), Some(1.0)]);

    let sphere_chart = LatentManifold::Product(vec![
        LatentManifold::Interval { lo: -1.0, hi: 1.0 },
        LatentManifold::Circle { period: tau },
    ]);
    assert_eq!(sphere_chart.axis_periods(), vec![None, Some(tau)]);

    let sphere = LatentManifold::Sphere { dim: 3 };
    assert_eq!(sphere.axis_periods(), vec![None, None, None]);
}
/// Checks that the `ard` component of the loss barely changes when a circle coordinate
/// wraps past the cut.
///
/// One observation starts at 0.999 on a `Circle { period: 1.0 }`. A step of +0.002 in
/// the last entry of the row block is applied, the coordinate is asserted to land below
/// 0.01, and the ARD value before/after must differ by less than 1e-2.
#[test]
fn ard_value_continuous_across_periodic_cut_d1() {
    let start_coords = array![[0.999_f64]];
    let (basis, basis_jet) = periodic_basis(&start_coords);
    let periodic_atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        basis,
        basis_jet,
        array![[0.2], [-0.3], [0.4]],
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((1, 1)),
        vec![start_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![periodic_atom], assignment).unwrap();

    let target = array![[0.1_f64]];
    let alpha = 50.0_f64;
    let rho = SaeManifoldRho::new(0.0, 0.0, vec![array![alpha.ln()]]);
    let ard_before = term.loss(target.view(), &rho).unwrap().ard;

    // Push only the last row-block entry, leaving the decoder untouched.
    let row_dim = term.assignment.row_block_dim();
    let mut step_ext = Array1::<f64>::zeros(row_dim);
    step_ext[row_dim - 1] = 0.002;
    let step_beta = Array1::<f64>::zeros(term.beta_dim());
    term.apply_newton_step(step_ext.view(), step_beta.view(), 1.0)
        .unwrap();

    let wrapped = term.assignment.coords[0].row(0)[0];
    assert!(
        wrapped < 0.01,
        "coordinate should have wrapped across the cut, got {wrapped}"
    );

    let ard_after = term.loss(target.view(), &rho).unwrap().ard;
    assert!(
        (ard_after - ard_before).abs() < 1.0e-2,
        "periodic ARD jumped across the cut: before={ard_before}, after={ard_after}"
    );
}
/// Checks that `penalized_objective_total`, evaluated with an `ARDPenalty` registered in
/// an `AnalyticPenaltyRegistry`, barely changes when a circle coordinate wraps past the cut.
///
/// One observation starts at 0.999 on a `Circle { period: 1.0 }`. A step of +0.002 in
/// the last entry of the row block is applied, the coordinate is asserted to land below
/// 0.01, and the objective before/after must differ by less than 1e-2.
#[test]
fn penalized_objective_continuous_across_periodic_cut_with_registry_ard() {
    let coords0 = array![[0.999_f64]];
    let (phi0, jet0) = periodic_basis(&coords0);
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi0,
        jet0,
        array![[0.2], [-0.3], [0.4]],
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((1, 1)),
        vec![coords0],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let target = array![[0.1_f64]];
    let alpha = 50.0_f64;
    let rho = SaeManifoldRho::new(0.0, 0.0, vec![array![alpha.ln()]]);

    // Register an ARD penalty spanning the whole coordinate block of atom 0.
    let coord = &term.assignment.coords[0];
    let mut registry = AnalyticPenaltyRegistry::new();
    let ard_pen = ARDPenalty::new(
        PsiSlice::full(coord.len(), Some(coord.latent_dim())),
        coord.latent_dim(),
    );
    registry.push(AnalyticPenaltyKind::Ard(Arc::new(ard_pen)));

    // The registry is passed by shared reference.
    let obj_before = term
        .penalized_objective_total(target.view(), &rho, Some(&registry), 1.0)
        .unwrap();

    // Push only the last row-block entry, leaving the decoder untouched.
    let q = term.assignment.row_block_dim();
    let beta_dim = term.beta_dim();
    let mut delta_ext = Array1::<f64>::zeros(q);
    delta_ext[q - 1] = 0.002;
    let delta_beta = Array1::<f64>::zeros(beta_dim);
    term.apply_newton_step(delta_ext.view(), delta_beta.view(), 1.0)
        .unwrap();

    let wrapped = term.assignment.coords[0].row(0)[0];
    assert!(
        wrapped < 0.01,
        "coordinate should have wrapped across the cut, got {wrapped}"
    );

    let obj_after = term
        .penalized_objective_total(target.view(), &rho, Some(&registry), 1.0)
        .unwrap();
    assert!(
        (obj_after - obj_before).abs() < 1.0e-2,
        "line-search objective jumped across the cut: before={obj_before}, after={obj_after}"
    );
}
/// Checks two properties of a `ScadMcpPenalty` registered over a purely periodic
/// coordinate block:
///
/// 1. The penalised objective with the registry equals the one without it (to 1e-12).
/// 2. The objective changes by less than 1e-2 when the coordinate, starting at 0.999 on
///    a `Circle { period: 1.0 }`, is stepped by +0.002 and wraps past the cut.
#[test]
fn scad_coord_penalty_inert_and_continuous_on_periodic_axis() {
    use crate::terms::analytic_penalties::{PenaltyConcavity, ScadMcpPenalty};
    let coords0 = array![[0.999_f64]];
    let (phi0, jet0) = periodic_basis(&coords0);
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi0,
        jet0,
        array![[0.2], [-0.3], [0.4]],
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((1, 1)),
        vec![coords0],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let target = array![[0.1_f64]];
    let rho = SaeManifoldRho::new(0.0, 0.0, vec![array![0.0_f64]]);

    // Register a SCAD penalty spanning the whole coordinate block of atom 0.
    let coord = &term.assignment.coords[0];
    let mut registry = AnalyticPenaltyRegistry::new();
    let scad = ScadMcpPenalty::new(
        PsiSlice::full(coord.len(), Some(coord.latent_dim())),
        5.0,
        coord.n_obs(),
        3.7,
        1.0e-3,
        PenaltyConcavity::Scad,
        false,
    )
    .unwrap();
    registry.push(AnalyticPenaltyKind::ScadMcp(Arc::new(scad)));

    // Part 1: the penalty must not contribute on a purely periodic axis.
    let with_scad = term
        .penalized_objective_total(target.view(), &rho, Some(&registry), 1.0)
        .unwrap();
    let without = term
        .penalized_objective_total(target.view(), &rho, None, 1.0)
        .unwrap();
    assert!(
        (with_scad - without).abs() < 1.0e-12,
        "SCAD coord penalty must be inert on a pure periodic axis: \
         with={with_scad}, without={without}"
    );

    // Part 2: step the coordinate across the cut and compare objectives.
    let obj_before = with_scad;
    let q = term.assignment.row_block_dim();
    let beta_dim = term.beta_dim();
    let mut delta_ext = Array1::<f64>::zeros(q);
    delta_ext[q - 1] = 0.002;
    let delta_beta = Array1::<f64>::zeros(beta_dim);
    term.apply_newton_step(delta_ext.view(), delta_beta.view(), 1.0)
        .unwrap();
    let wrapped = term.assignment.coords[0].row(0)[0];
    assert!(
        wrapped < 0.01,
        "coordinate should have wrapped across the cut, got {wrapped}"
    );
    let obj_after = term
        .penalized_objective_total(target.view(), &rho, Some(&registry), 1.0)
        .unwrap();
    assert!(
        (obj_after - obj_before).abs() < 1.0e-2,
        "SCAD line-search objective jumped across the periodic cut: \
         before={obj_before}, after={obj_after}"
    );
}
/// `sae_coord_penalty_euclidean_restriction` must return `None` for a Euclidean
/// coordinate block and, for a circle block, `Some` with no axes and an empty
/// compacted target.
#[test]
fn scad_coord_penalty_active_on_euclidean_axis() {
    let euclidean_assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((3, 1)),
        vec![array![[0.5_f64], [-0.7], [1.3]]],
        vec![LatentManifold::Euclidean],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let euclidean_restriction =
        sae_coord_penalty_euclidean_restriction(&euclidean_assignment.coords[0]);
    assert!(
        euclidean_restriction.is_none(),
        "Euclidean coord must not be restricted"
    );

    let circle_assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((3, 1)),
        vec![array![[0.1_f64], [0.4], [0.9]]],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let (axes, compacted) =
        sae_coord_penalty_euclidean_restriction(&circle_assignment.coords[0])
            .expect("periodic coord must be restricted");
    assert!(
        axes.is_empty(),
        "circle has no Euclidean axes, got {axes:?}"
    );
    assert_eq!(compacted.len(), 0, "compacted target must be empty");
}
/// Assembles the arrow-Schur system for a periodic atom with a large ARD precision
/// (`alpha = 100`) and asserts every diagonal entry of every row's `htt` block is
/// non-negative.
#[test]
fn periodic_ard_curvature_is_psd_in_assembled_htt() {
    let start_coords = array![[0.40_f64], [0.60_f64]];
    let (basis, basis_jet) = periodic_basis(&start_coords);
    let periodic_atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        basis,
        basis_jet,
        array![[0.2], [-0.3], [0.4]],
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((2, 1)),
        vec![start_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![periodic_atom], assignment).unwrap();

    let target = array![[0.1_f64], [0.2_f64]];
    let alpha = 100.0_f64;
    let rho = SaeManifoldRho::new(0.0, 0.0, vec![array![alpha.ln()]]);
    let sys = term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();

    for (row_idx, row) in sys.rows.iter().enumerate() {
        for a in 0..row.htt.nrows() {
            let diagonal = row.htt[[a, a]];
            assert!(
                diagonal >= 0.0,
                "row {row_idx} htt diagonal[{a}]={} must be PSD (von-Mises \
                 curvature clamped to its positive part)",
                diagonal
            );
        }
    }
}
/// Takes a snapshot of a term, perturbs it with `apply_newton_step`, restores the
/// snapshot, and asserts basis values, basis Jacobian, decoder coefficients, logits and
/// coordinates all compare equal to their pre-step copies.
#[test]
fn snapshot_restore_round_trips_mutated_state() {
    let start_coords = array![[0.05], [0.20], [0.55], [0.80]];
    let (basis, basis_jet) = periodic_basis(&start_coords);
    let periodic_atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        basis,
        basis_jet,
        array![[0.2], [-0.3], [0.4]],
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((4, 1)),
        vec![start_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![periodic_atom], assignment).unwrap();

    // Record the snapshot and independent copies of everything it should restore.
    let snapshot = term.snapshot_mutable_state();
    let basis_before = term.atoms[0].basis_values.clone();
    let jet_before = term.atoms[0].basis_jacobian.clone();
    let decoder_before = term.atoms[0].decoder_coefficients.clone();
    let logits_before = term.assignment.logits.clone();
    let coords_before = term.assignment.coords[0].as_matrix();

    // Perturb every row-block entry and every decoder entry.
    let row_dim = term.assignment.row_block_dim();
    let step_ext = Array1::<f64>::from_elem(4 * row_dim, 0.3);
    let step_beta = Array1::<f64>::from_elem(term.beta_dim(), -0.4);
    term.apply_newton_step(step_ext.view(), step_beta.view(), 1.0)
        .unwrap();

    let basis_moved = (&term.atoms[0].basis_values - &basis_before)
        .mapv(f64::abs)
        .sum()
        > 1e-9;
    assert!(
        basis_moved
            || (&term.atoms[0].decoder_coefficients - &decoder_before)
                .mapv(f64::abs)
                .sum()
                > 1e-9,
        "apply_newton_step did not perturb the snapshotted state"
    );

    term.restore_mutable_state(&snapshot);
    let restored_atom = &term.atoms[0];
    assert_eq!(restored_atom.basis_values, basis_before);
    assert_eq!(restored_atom.basis_jacobian, jet_before);
    assert_eq!(restored_atom.decoder_coefficients, decoder_before);
    assert_eq!(term.assignment.logits, logits_before);
    assert_eq!(term.assignment.coords[0].as_matrix(), coords_before);
}
/// Runs `run_joint_fit_arrow_schur` for two iterations in IBP-MAP mode and asserts the
/// total loss is finite and not above its starting value, coordinates and assignments
/// stay finite, and the atom's stored basis values have changed.
#[test]
fn ibp_path_refreshes_periodic_basis_for_two_newton_iterations() {
    let start_coords = array![[0.05], [0.20], [0.55], [0.80]];
    let (basis, basis_jet) = periodic_basis(&start_coords);
    let periodic_atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        basis,
        basis_jet,
        array![[0.2], [-0.3], [0.4]],
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((4, 1)),
        vec![start_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![periodic_atom], assignment).unwrap();

    let target = array![[0.10], [0.05], [-0.15], [0.20]];
    let mut rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1)]);
    let initial_total = term.loss(target.view(), &rho).unwrap().total();
    let initial_basis = term.atoms[0].basis_values.clone();

    let fitted = term
        .run_joint_fit_arrow_schur(target.view(), &mut rho, None, 2, 0.05, 1.0e-3, 1.0e-3)
        .unwrap();
    let fitted_total = fitted.total();
    assert!(fitted_total.is_finite());
    assert!(fitted_total <= initial_total + 1.0e-8);

    let coords_finite = term.assignment.coords[0]
        .as_flat()
        .iter()
        .all(|v| v.is_finite());
    assert!(coords_finite);
    assert!(term.assignment.assignments().iter().all(|v| v.is_finite()));

    let basis_delta = (&term.atoms[0].basis_values - &initial_basis)
        .mapv(f64::abs)
        .sum();
    assert!(basis_delta > 1.0e-10);
}
fn small_two_atom_periodic_term() -> (SaeManifoldTerm, Array2<f64>, SaeManifoldRho) {
let coords0 = array![[0.05], [0.20], [0.55], [0.80], [0.35]];
let coords1 = array![[0.15], [0.30], [0.65], [0.90], [0.45]];
let (phi0, jet0) = periodic_basis(&coords0);
let (phi1, jet1) = periodic_basis(&coords1);
let atom0 = SaeManifoldAtom::new(
"periodic0",
SaeAtomBasisKind::Periodic,
1,
phi0,
jet0,
array![[0.25], [-0.35], [0.15]],
Array2::<f64>::eye(3),
)
.unwrap()
.with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
let atom1 = SaeManifoldAtom::new(
"periodic1",
SaeAtomBasisKind::Periodic,
1,
phi1,
jet1,
array![[-0.10], [0.20], [0.30]],
Array2::<f64>::eye(3),
)
.unwrap()
.with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
let logits = array![
[0.7, -0.2],
[0.1, 0.4],
[-0.3, 0.5],
[0.6, -0.1],
[0.2, 0.3]
];
let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
logits,
vec![coords0, coords1],
vec![
LatentManifold::Circle { period: 1.0 },
LatentManifold::Circle { period: 1.0 },
],
AssignmentMode::softmax(0.8),
)
.unwrap();
let term = SaeManifoldTerm::new(vec![atom0, atom1], assignment).unwrap();
let target = array![[0.12], [-0.03], [0.08], [0.20], [-0.11]];
let rho = SaeManifoldRho::new(
(-0.3_f64).exp().ln(),
0.7_f64.ln(),
vec![array![0.9_f64.ln()], array![1.1_f64.ln()]],
);
(term, target, rho)
}
/// Applies a raw step of 1e6 to the first row-block entry and asserts that the logit
/// difference between atoms 0 and 1 in row 0 moves by exactly
/// `SAE_ASSIGNMENT_LOGIT_STEP_CAP_TAUS * temperature`.
#[test]
fn assignment_logit_step_cap_bounds_single_iteration_gate_motion() {
    let (mut term, _target, _rho) = small_two_atom_periodic_term();
    let logit_gap = |t: &SaeManifoldTerm| t.assignment.logits[[0, 0]] - t.assignment.logits[[0, 1]];

    let gap_before = logit_gap(&term);

    let total_len = term.assignment.n_obs() * term.assignment.row_block_dim();
    let mut huge_step = Array1::<f64>::zeros(total_len);
    huge_step[0] = 1.0e6;
    let zero_beta = Array1::<f64>::zeros(term.beta_dim());
    term.apply_newton_step(huge_step.view(), zero_beta.view(), 1.0)
        .expect("step applies");

    let cap = SAE_ASSIGNMENT_LOGIT_STEP_CAP_TAUS * term.assignment.mode.temperature();
    let gap_after = logit_gap(&term);
    assert!(
        ((gap_after - gap_before) - cap).abs() < 1.0e-9,
        "a 1e6 raw logit delta must realise exactly the {cap}-cap, moved {}",
        gap_after - gap_before
    );
}
/// Drives `enforce_active_mass_guard` through a collapse of atom 1.
///
/// The first collapse must record a single `Reseeded` event for atom 1 and lift its
/// peak mass above `SAE_ATOM_ACTIVE_MASS_FLOOR`; a guard call on the recovered state
/// must add nothing; after a second collapse, two further guard calls must leave exactly
/// one `Terminal` event, and no event may reference atom 0.
#[test]
fn active_mass_guard_reseeds_once_then_records_terminal_collapse() {
    let (mut term, _target, _rho) = small_two_atom_periodic_term();
    let n = term.assignment.n_obs();
    // Forces atom 1's logit far below atom 0's in every row.
    let collapse_atom_one = |t: &mut SaeManifoldTerm| {
        for row in 0..n {
            t.assignment.logits[[row, 0]] = 0.0;
            t.assignment.logits[[row, 1]] = -1.0e3;
        }
    };

    // First collapse: expect a reseed of atom 1.
    collapse_atom_one(&mut term);
    term.enforce_active_mass_guard(0).expect("guard runs");
    assert_eq!(term.collapse_events().len(), 1);
    let first_event = term.collapse_events()[0];
    assert_eq!(first_event.atom, 1);
    assert_eq!(first_event.action, CollapseAction::Reseeded);
    assert!(first_event.max_active_mass < first_event.floor);

    let masses = term.assignment.assignments();
    let peak_mass_atom_one = (0..n)
        .map(|row| masses[[row, 1]])
        .fold(0.0_f64, f64::max);
    assert!(peak_mass_atom_one > SAE_ATOM_ACTIVE_MASS_FLOOR);

    // Recovered state: no new event.
    term.enforce_active_mass_guard(1).expect("guard runs");
    assert_eq!(term.collapse_events().len(), 1);

    // Second collapse: expect exactly one terminal event across two guard calls.
    collapse_atom_one(&mut term);
    term.enforce_active_mass_guard(2).expect("guard runs");
    term.enforce_active_mass_guard(3).expect("guard runs");
    let terminals: Vec<_> = term
        .collapse_events()
        .iter()
        .filter(|event| event.action == CollapseAction::Terminal)
        .collect();
    assert_eq!(terminals.len(), 1);
    assert_eq!(terminals[0].atom, 1);
    assert!(
        term.collapse_events().iter().all(|event| event.atom == 1),
        "the healthy atom must never be flagged"
    );
}
/// `reml_criterion_streaming_exact` must return the same cost and total loss as
/// `reml_criterion_with_cache` (to 1e-8) when both start from the same fixture.
#[test]
fn streaming_exact_reml_matches_full_batch_reml_small_sae() {
    let (fixture, target, rho) = small_two_atom_periodic_term();
    let mut dense_term = fixture.clone();
    let mut streaming_term = fixture;

    let (dense_cost, dense_loss, _cache) = dense_term
        .reml_criterion_with_cache(target.view(), &rho, None, 2, 0.25, 1.0e-4, 1.0e-4)
        .unwrap();
    let (streaming_cost, streaming_loss) = streaming_term
        .reml_criterion_streaming_exact(target.view(), &rho, None, 2, 0.25, 1.0e-4, 1.0e-4)
        .unwrap();

    assert_abs_diff_eq!(streaming_cost, dense_cost, epsilon = 1.0e-8);
    assert_abs_diff_eq!(
        streaming_loss.total(),
        dense_loss.total(),
        epsilon = 1.0e-8
    );
}
/// Starts from a fixture whose zero-ridge arrow solve fails with an error that
/// `is_undamped_evidence_row_non_pd` recognises, then asserts both the dense and the
/// streaming REML criteria still succeed with finite, mutually matching results.
#[test]
fn reml_retries_refinement_after_non_pd_undamped_evidence_factor() {
    let (mut fixture, target, rho) = small_two_atom_periodic_term();
    let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();

    // Precondition: the cold, undamped factorisation must fail.
    let cold_sys = fixture
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let cold_err = solve_arrow_newton_step_with_options(&cold_sys, 0.0, 0.0, &options)
        .err()
        .unwrap_or_else(|| {
            panic!("fixture must start with a non-PD undamped evidence row factor")
        });
    assert!(
        SaeManifoldTerm::is_undamped_evidence_row_non_pd(&cold_err),
        "fixture must start with a genuine evidence-mode non-PD row factor; got {cold_err}",
    );

    let mut dense_term = fixture.clone();
    let mut streaming_term = fixture;

    let (dense_cost, dense_loss, cache) = dense_term
        .reml_criterion_with_cache(target.view(), &rho, None, 1, 0.25, 1.0e-4, 1.0e-4)
        .expect("dense REML must refine through the cold non-PD evidence factor");
    let log_det = arrow_log_det_from_cache(&cache).expect("refined cache must carry log-det");
    assert!(dense_cost.is_finite());
    assert!(dense_loss.total().is_finite());
    assert!(log_det.is_finite());

    let (streaming_cost, streaming_loss) = streaming_term
        .reml_criterion_streaming_exact(target.view(), &rho, None, 1, 0.25, 1.0e-4, 1.0e-4)
        .expect("streaming REML must share the dense refinement retry");
    assert_abs_diff_eq!(streaming_cost, dense_cost, epsilon = 1.0e-8);
    assert_abs_diff_eq!(
        streaming_loss.total(),
        dense_loss.total(),
        epsilon = 1.0e-8
    );
}
/// Recomputes the reconstruction dispersion by hand as
/// `rss / max(n * p - beta_edf - coord_edf, 1)`, where `coord_edf` is
/// `n - alpha * trace` clamped to `[0, n]`, and compares it with
/// `reconstruction_dispersion`. It also requires the result to be well below what the
/// same formula gives when the coordinate edf is taken as the full `n`.
#[test]
fn reconstruction_dispersion_uses_ard_shrunk_coordinate_edf() {
    let n = 24usize;
    let p = 2usize;
    let tau = std::f64::consts::TAU;

    let coords = Array2::from_shape_fn((n, 1), |(row, _)| (row as f64 + 0.25) / n as f64);
    let (basis, basis_jet) = periodic_basis(&coords);
    let periodic_atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        basis,
        basis_jet,
        array![[0.30, -0.10], [0.20, 0.40], [-0.35, 0.15]],
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TestPeriodicEvaluator));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![periodic_atom], assignment).unwrap();

    // Column 0 is a shifted sine; column 1 is a cosine plus a linear ramp in the row.
    let target = Array2::from_shape_fn((n, p), |(row, col)| {
        let x = (row as f64 + 0.5) / n as f64;
        match col {
            0 => 0.45 * (tau * x).sin() + 0.07,
            _ => -0.20 * (tau * x).cos() + 0.03 * row as f64,
        }
    });

    let alpha = 250.0_f64;
    let rho = SaeManifoldRho::new(0.0, 0.8_f64.ln(), vec![array![alpha.ln()]]);
    let loss = term.loss(target.view(), &rho).unwrap();
    let sys = term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
    let (_step_t, _step_beta, cache) =
        solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, &options).unwrap();
    let dispersion = term.reconstruction_dispersion(&loss, &cache, &rho).unwrap();

    // Hand-built reference value.
    let smooth_edf = term
        .decoder_smoothness_effective_dof(&cache, rho.lambda_smooth())
        .unwrap();
    let beta_edf = (term.beta_dim() as f64 - smooth_edf).max(0.0);
    let traces = term.ard_inverse_traces(&cache).unwrap();
    let coord_edf = (n as f64 - alpha * traces[0][0]).clamp(0.0, n as f64);
    let rss = 2.0 * loss.data_fit;
    let total_cells = (n * p) as f64;
    let expected = rss / (total_cells - beta_edf - coord_edf).max(1.0);
    assert_abs_diff_eq!(dispersion, expected, epsilon = 1.0e-10);

    // Contrast with counting every coordinate as a full degree of freedom.
    let old_full_coordinate_edf = n as f64;
    let old_full_coordinate_dispersion =
        rss / (total_cells - beta_edf - old_full_coordinate_edf).max(1.0);
    assert!(
        coord_edf < 0.25 * old_full_coordinate_edf,
        "test setup must put the coordinate axis in an ARD-shrunk regime; \
         coord_edf={coord_edf}, old_full_coordinate_edf={old_full_coordinate_edf}"
    );
    assert!(
        dispersion < 0.75 * old_full_coordinate_dispersion,
        "φ̂ must use the ARD-shrunk coordinate edf, not the old full \
         coordinate count: got {dispersion}, old formula {old_full_coordinate_dispersion}"
    );
}
/// `sae_streaming_plan_from_budget` must choose the dense route for the large budget
/// arguments and the streaming route for the tiny ones, and the streaming log-det must
/// match the dense cached log-det (to 1e-8) on the same assembled system.
#[test]
fn streaming_plan_routes_by_memory_budget_with_identical_logdet() {
    let (fixture, target, rho) = small_two_atom_periodic_term();
    let total_basis = fixture
        .atoms
        .iter()
        .map(|atom| atom.basis_size())
        .sum::<usize>();
    let d_max = fixture
        .atoms
        .iter()
        .map(|atom| atom.latent_dim)
        .max()
        .unwrap();
    let n_obs = fixture.n_obs();
    let k_atoms = fixture.k_atoms();

    let dense_plan = sae_streaming_plan_from_budget(
        n_obs,
        total_basis,
        k_atoms,
        d_max,
        usize::MAX / 4,
        1024 * 1024,
    );
    assert!(!dense_plan.streaming);

    let streaming_plan =
        sae_streaming_plan_from_budget(n_obs, total_basis, k_atoms, d_max, 1, 512);
    assert!(streaming_plan.streaming);

    // Refine a copy, then factor the system assembled at the refined state both ways.
    let mut refined = fixture.clone();
    refined
        .reml_criterion_with_cache(target.view(), &rho, None, 2, 0.25, 1.0e-4, 1.0e-4)
        .unwrap();
    let sys = refined
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();

    let dense_solve = solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, &options).unwrap();
    let full_logdet = arrow_log_det_from_cache(&dense_solve.2).unwrap();

    let mut streaming_solver = StreamingArrowSchur::from_system(&sys, streaming_plan.chunk_size);
    let streaming_logdet = streaming_solver
        .exact_arrow_log_det(0.0, 0.0, &options)
        .unwrap();

    assert_abs_diff_eq!(streaming_logdet, full_logdet, epsilon = 1.0e-8);
}
/// With 100 000 atoms but only three active per row, every row's active block width
/// must be 6, so the summed `q * q` work is `n * 36` and far below the dense
/// `n * (2 * k_atoms)^2` figure.
#[test]
fn sparse_active_layout_work_scales_with_active_atoms_not_total_k() {
    let n = 3;
    let k_atoms = 100_000;

    // Three widely separated active atoms per row.
    let active_rows: Vec<Vec<_>> = (0..n)
        .map(|row| vec![row, 10_000 + row, 90_000 + row])
        .collect();
    let coord_dims = vec![1usize; k_atoms];
    let coord_offsets_full: Vec<usize> = (0..k_atoms).map(|k| k_atoms + k).collect();
    let layout = SaeRowLayout::from_active_atoms(active_rows, coord_dims, coord_offsets_full);

    for row in 0..n {
        assert_eq!(layout.active_atoms[row].len(), 3);
        assert_eq!(layout.row_q_active(row), 6);
    }

    let mut compact_work = 0usize;
    for row in 0..n {
        let q_active = layout.row_q_active(row);
        compact_work += q_active * q_active;
    }
    let dense_q = 2 * k_atoms;
    let dense_work = n * dense_q * dense_q;

    assert!(compact_work < dense_work / 1_000_000_000);
    assert_eq!(compact_work, n * 36);
}
/// `run_joint_fit_arrow_schur` must return `Ok` for a fixture built with an all-zero
/// 3 x 3 final atom argument and a smoothness log-parameter of -20.
#[test]
fn run_joint_fit_arrow_schur_escalates_ridge_on_non_pd_row_block() {
    let start_coords = array![[0.1], [0.4], [0.7]];
    let (basis, basis_jet) = periodic_basis(&start_coords);
    let zero_matrix = Array2::<f64>::zeros((3, 3));
    let periodic_atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        basis,
        basis_jet,
        array![[0.05], [-0.05], [0.05]],
        zero_matrix,
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((3, 1)),
        vec![start_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![periodic_atom], assignment).unwrap();

    let target = array![[0.20], [-0.10], [0.45]];
    let mut rho = SaeManifoldRho::new(0.0, -20.0, vec![Array1::<f64>::zeros(1)]);
    let result =
        term.run_joint_fit_arrow_schur(target.view(), &mut rho, None, 1, 1.0, 1.0e-6, 1.0e-6);
    assert!(
        result.is_ok(),
        "run_joint_fit_arrow_schur should recover from degenerate H_tt via LM ridge escalation; got: {result:?}",
    );
}
/// `solve_newton_step` must return `Ok` for a fixture built with an all-zero 3 x 3
/// final atom argument and a smoothness log-parameter of -20.
#[test]
fn solve_newton_step_escalates_ridge_on_non_pd_row_block() {
    let start_coords = array![[0.1], [0.4], [0.7]];
    let (basis, basis_jet) = periodic_basis(&start_coords);
    let zero_matrix = Array2::<f64>::zeros((3, 3));
    let periodic_atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        basis,
        basis_jet,
        array![[0.05], [-0.05], [0.05]],
        zero_matrix,
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((3, 1)),
        vec![start_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![periodic_atom], assignment).unwrap();

    let target = array![[0.20], [-0.10], [0.45]];
    let rho = SaeManifoldRho::new(0.0, -20.0, vec![Array1::<f64>::zeros(1)]);
    let result = term.solve_newton_step(target.view(), &rho, None, 1.0e-6, 1.0e-6);
    assert!(
        result.is_ok(),
        "solve_newton_step should recover from degenerate H_tt via LM ridge escalation; got: {result:?}",
    );
}
/// Checks the β-block quadratic model of the assembled arrow-Schur system against the
/// actual change in total loss.
///
/// A step of length 1e-3 along the normalised negative `gb` direction is applied to the
/// flattened decoder. The realised loss change must agree with
/// `gb · δ + 0.5 · δᵀ (H δ)` to within 1e-4, where `H δ` comes from
/// `sys.effective_penalty_op().matvec`.
#[test]
fn sae_arrow_schur_beta_quadratic_model_matches_penalized_loss_change() {
    let coords = array![[0.10], [0.35], [0.80]];
    let (phi, jet) = periodic_basis(&coords);
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        array![[0.65], [-0.45], [0.25]],
        array![[3.0, 0.4, -0.2], [0.1, 2.5, 0.3], [-0.5, 0.2, 1.8]],
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode(
        Array2::<f64>::zeros((3, 1)),
        vec![coords],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let target = array![[0.20], [-0.10], [0.45]];
    let rho = SaeManifoldRho::new(0.0, 1.3_f64.ln(), vec![array![0.9_f64.ln()]]);
    let sys = term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let beta0 = term.flatten_beta();
    let loss0 = term.loss(target.view(), &rho).unwrap().total();

    // Unit-length steepest-descent direction in β.
    let mut direction = sys.gb.mapv(|v| -v);
    let direction_norm = direction.iter().map(|v| v * v).sum::<f64>().sqrt();
    assert!(direction_norm > 1.0e-12);
    for value in direction.iter_mut() {
        *value /= direction_norm;
    }
    let epsilon = 1.0e-3;
    let delta = direction.mapv(|v| epsilon * v);

    // `delta` is borrowed here because the model terms below still need it.
    let beta_trial = beta0 + &delta;
    term.set_flat_beta(beta_trial.view()).unwrap();
    let actual = term.loss(target.view(), &rho).unwrap().total() - loss0;

    // Quadratic model: gb · δ + 0.5 · δᵀ (H δ).
    let linear = sys.gb.dot(&delta);
    let mut hbb_delta = Array1::<f64>::zeros(delta.len());
    {
        let op = sys.effective_penalty_op();
        let d_slice = delta.as_slice().expect("delta is contiguous");
        let hd_slice = hbb_delta.as_slice_mut().expect("hbb_delta is contiguous");
        op.matvec(d_slice, hd_slice);
    }
    let quadratic = 0.5 * delta.dot(&hbb_delta);
    let predicted = linear + quadratic;
    let error = (actual - predicted).abs();
    assert!(
        error <= 1.0e-4,
        "actual={actual:.12e}, predicted={predicted:.12e}, error={error:.12e}"
    );
}
/// Pins `SaeRowLayout::from_dense_weights` called with arguments `2` and `0.05`: row 0
/// keeps atoms `[0, 2]`, row 1 keeps atom `[1]`, the active widths are 6 and 2, and
/// `expand_row` scatters a compact row-0 vector into the expected full-width slots.
#[test]
fn sae_row_layout_from_dense_weights_top_k_and_cutoff() {
    let coord_dims = vec![2usize, 1, 2];
    let coord_offsets_full = vec![3usize, 5, 6];
    let row_weights = vec![
        Array1::from_vec(vec![0.7, 0.01, 0.29]),
        Array1::from_vec(vec![0.001, 0.002, 0.0005]),
    ];
    let layout =
        SaeRowLayout::from_dense_weights(&row_weights, 2, 0.05, coord_dims, coord_offsets_full);

    assert_eq!(layout.active_atoms[0], vec![0, 2]);
    assert_eq!(layout.active_atoms[1], vec![1]);
    assert_eq!(layout.row_q_active(0), 6);
    assert_eq!(layout.row_q_active(1), 2);

    // Slots 1 and 5 are not written for row 0 and must stay zero.
    let compact_row = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
    let mut expanded = vec![0.0_f64; 8];
    layout.expand_row(0, &compact_row, &mut expanded);
    assert_eq!(expanded, vec![1.0, 0.0, 2.0, 3.0, 4.0, 0.0, 5.0, 6.0]);
}
/// Checks that a `MechanismSparsityPenalty` registered in the analytic-penalty registry
/// contributes to the `gb` vector of the assembled arrow-Schur system.
///
/// `gb` must have length `m * p`, be finite, and have a non-trivial maximum magnitude.
/// The difference between the penalised and unpenalised `gb` at flat index `1 * p + 0`
/// must match a hand-computed value built from decoder entries -0.5 and 0.6.
#[test]
fn sae_mechsparsity_beta_block_routes_through_arrow_schur_gb() {
    let coords = array![[0.10], [0.35], [0.80]];
    let (phi, jet) = periodic_basis(&coords);
    let decoder = array![
        [0.7, -0.2, 0.05, 0.4],
        [-0.5, 0.6, -0.1, 0.3],
        [0.2, 0.0, -0.4, -0.6],
    ];
    // The decoder is not needed again, so it is moved into the atom.
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder,
        Array2::<f64>::eye(3),
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((3, 1)),
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let m = 3usize;
    let p = 4usize;
    let slice = PsiSlice::full(m * p, Some(m));
    let penalty = MechanismSparsityPenalty::new(
        slice,
        vec![vec![0, 1], vec![2, 3]],
        1.0,
        1.0e-6,
        (term.n_obs()) as f64,
        false,
    )
    .unwrap();
    let mut registry = AnalyticPenaltyRegistry::new();
    registry.push(AnalyticPenaltyKind::MechanismSparsity(Arc::new(penalty)));
    let target = array![
        [0.20, 0.10, -0.05, 0.25],
        [-0.10, 0.30, 0.15, -0.20],
        [0.45, -0.05, 0.10, 0.30],
    ];
    let rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1)]);

    // The registry is passed by shared reference.
    let sys = term
        .assemble_arrow_schur(target.view(), &rho, Some(&registry))
        .unwrap();
    assert_eq!(sys.gb.len(), m * p, "gb should match flatten_beta length");
    let mut absmax = 0.0_f64;
    for v in sys.gb.iter().copied() {
        assert!(v.is_finite());
        if v.abs() > absmax {
            absmax = v.abs();
        }
    }
    assert!(
        absmax > 1.0e-6,
        "MechSparsity must inject a non-trivial gradient into the SAE arrow-Schur gb; absmax={absmax:.3e}"
    );

    // Isolate the penalty's contribution by subtracting the unpenalised gradient.
    let sys_no_penalty = term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let beta = term.flatten_beta();
    let expected = {
        // 1e-12 is the square of the 1e-6 passed to the penalty constructor above.
        let s = (0.5_f64.powi(2) + 0.6_f64.powi(2) + 1.0e-12).sqrt();
        (2.0_f64).sqrt() * (-0.5_f64) / s
    };
    // Flat index of (basis = 1, feature = 0), i.e. 1 * p + 0.
    let idx = p;
    let delta = sys.gb[idx] - sys_no_penalty.gb[idx];
    assert!(
        (delta - expected).abs() <= 1.0e-6,
        "expected MechSparsity gb contribution at (basis=1, feat=0) ≈ {expected:.6e}, \
         got Δgb={delta:.6e} (gb_with={:.6e}, gb_without={:.6e}, beta entry = {})",
        sys.gb[idx],
        sys_no_penalty.gb[idx],
        beta[idx]
    );
}
/// Returns the sum, over the singular values `σ` of `decoder`, of
/// `sqrt(σ² + eps²) - eps`.
///
/// Panics if the SVD call returns an error.
fn smoothed_nuclear_norm(decoder: &Array2<f64>, eps: f64) -> f64 {
    // Only the singular values are used; both factor flags are `false`.
    let (_, singular_values, _) = decoder.clone().svd(false, false).unwrap();
    let smooth = |sigma: &f64| -> f64 { (sigma * sigma + eps * eps).sqrt() - eps };
    singular_values.iter().map(smooth).sum()
}
/// Verifies that a `NuclearNormPenalty` registered on the decoder block is
/// accepted by the term and contributes to the assembled `gb`.
///
/// The penalty-only gradient is isolated as `gb(with registry) - gb(without)`
/// and then checked four ways:
/// * it is finite and has an entry above `1e-6` in magnitude;
/// * it matches `grad_target` of a separately built `NuclearNormPenalty`
///   evaluated at `flatten_beta()` to within `1e-9`;
/// * a step of size `1e-2` against it lowers `smoothed_nuclear_norm` of the
///   decoder;
/// * `sys.hbb` is empty and the diagonal reported by
///   `effective_penalty_op()` is not below `-1e-9`.
#[test]
fn sae_nuclear_norm_beta_block_routes_through_gb_and_shrinks_spectrum() {
    let coords = array![[0.10], [0.35], [0.80]];
    let (phi, jet) = periodic_basis(&coords);
    let decoder = array![
        [0.9, -0.2, 0.05, 0.4],
        [-0.5, 0.7, -0.1, 0.3],
        [0.2, 0.1, -0.8, -0.6],
    ];
    let m = 3usize;
    let p = 4usize;
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder.clone(),
        Array2::<f64>::eye(3),
    )
    .unwrap();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((3, 1)),
        vec![coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let eps = 1.0e-6;
    let slice = PsiSlice::full(m * p, Some(m));
    let penalty = NuclearNormPenalty::new(slice, 1.0, p, eps, None, false).unwrap();
    let mut registry = AnalyticPenaltyRegistry::new();
    registry.push(AnalyticPenaltyKind::NuclearNorm(Arc::new(penalty)));
    term.validate_analytic_penalty_registry(&registry)
        .expect("NuclearNorm must be accepted (redirected to the β block)");
    let target = array![
        [0.20, 0.10, -0.05, 0.25],
        [-0.10, 0.30, 0.15, -0.20],
        [0.45, -0.05, 0.10, 0.30],
    ];
    let rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1)]);
    // Penalty-free assembly, used as the subtraction baseline.
    let baseline = term
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let sys = term
        .assemble_arrow_schur(target.view(), &rho, Some(&registry))
        .unwrap();
    assert_eq!(sys.gb.len(), m * p, "gb should match flatten_beta length");
    assert_eq!(
        baseline.gb.len(),
        m * p,
        "baseline gb should match flatten_beta length"
    );
    let mut absmax = 0.0_f64;
    let mut penalty_grad = Array1::<f64>::zeros(m * p);
    for ((dst, sys_g), baseline_g) in penalty_grad
        .iter_mut()
        .zip(sys.gb.iter())
        .zip(baseline.gb.iter())
    {
        let v = *sys_g - *baseline_g;
        assert!(v.is_finite());
        *dst = v;
        absmax = absmax.max(v.abs());
    }
    assert!(
        absmax > 1.0e-6,
        "NuclearNorm must inject a non-trivial gradient into the SAE \
         arrow-Schur gb; absmax={absmax:.3e}"
    );
    // Independent reference penalty whose `grad_target` is compared against
    // the assembled contribution.
    let per_atom = NuclearNormPenalty::new(
        PsiSlice {
            range: 0..m * p,
            latent_dim: Some(p),
        },
        1.0,
        m,
        eps,
        None,
        false,
    )
    .unwrap();
    let beta = term.flatten_beta();
    let ref_grad = per_atom.grad_target(beta.view(), Array1::<f64>::zeros(0).view());
    for j in 0..m * p {
        assert!(
            (penalty_grad[j] - ref_grad[j]).abs() <= 1.0e-9,
            "penalty gb[{j}]={:.12e} must equal analytic spectral grad {:.12e}",
            penalty_grad[j],
            ref_grad[j]
        );
    }
    let base_norm = smoothed_nuclear_norm(&decoder, eps);
    let step = 1.0e-2;
    let mut shrunk = decoder.clone();
    for ((row, feat), value) in shrunk.indexed_iter_mut() {
        *value -= step * penalty_grad[row * p + feat];
    }
    let shrunk_norm = smoothed_nuclear_norm(&shrunk, eps);
    assert!(
        shrunk_norm < base_norm,
        "a step along gb must shrink the decoder spectrum: \
         before={base_norm:.9e}, after={shrunk_norm:.9e}"
    );
    assert!(sys.hbb.is_empty());
    let mut hbb_diag = vec![0.0_f64; m * p];
    sys.effective_penalty_op().diagonal(&mut hbb_diag);
    for i in 0..m * p {
        assert!(
            hbb_diag[i] >= -1.0e-9,
            "hbb diagonal must be non-negative (PSD majorizer); hbb[{i},{i}]={:.3e}",
            hbb_diag[i]
        );
    }
}
/// Test-only evaluator that forwards `evaluate` to `periodic_basis` and
/// supplies no second or third jets for one-column coordinates.
#[derive(Debug)]
struct TestPeriodicEvaluator;

impl SaeBasisEvaluator for TestPeriodicEvaluator {
    /// Returns `None` for one-column `coords`; any other column count yields
    /// `Some(Err(..))` describing the mismatch.
    fn second_jet_dyn(
        &self,
        coords: ArrayView2<'_, f64>,
    ) -> Option<Result<Array4<f64>, String>> {
        let latent_dim = coords.ncols();
        match latent_dim {
            1 => None,
            _ => Some(Err(format!(
                "TestPeriodicEvaluator::second_jet_dyn: expected latent_dim 1, got {}",
                latent_dim
            ))),
        }
    }

    /// Returns `None` for one-column `coords`; any other column count yields
    /// `Some(Err(..))` describing the mismatch.
    fn third_jet_dyn(
        &self,
        coords: ArrayView2<'_, f64>,
    ) -> Option<Result<Array5<f64>, String>> {
        let latent_dim = coords.ncols();
        match latent_dim {
            1 => None,
            _ => Some(Err(format!(
                "TestPeriodicEvaluator::third_jet_dyn: expected latent_dim 1, got {}",
                latent_dim
            ))),
        }
    }

    /// Copies `coords` into an owned array and returns `periodic_basis` of it.
    fn evaluate(
        &self,
        coords: ArrayView2<'_, f64>,
    ) -> Result<(Array2<f64>, Array3<f64>), String> {
        let owned_coords = coords.to_owned();
        let basis_and_jet = periodic_basis(&owned_coords);
        Ok(basis_and_jet)
    }
}
/// Running record of the worst analytic-vs-finite-difference disagreement
/// seen so far, ranked by relative error.
#[derive(Debug, Clone)]
struct SaeFdWorst {
    /// Index passed to `observe` for the worst comparison recorded.
    index: usize,
    /// Analytic derivative of the worst comparison.
    analytic: f64,
    /// Finite-difference derivative of the worst comparison.
    finite_difference: f64,
    /// `|analytic - finite_difference|` of the worst comparison, or
    /// `f64::INFINITY` if that comparison was non-finite.
    absolute_error: f64,
    /// Absolute error divided by `max(|analytic|, |finite_difference|, 1e-9)`,
    /// or `f64::INFINITY` if that comparison was non-finite.
    relative_error: f64,
}

impl SaeFdWorst {
    /// Creates an empty record with every field zero.
    fn new() -> Self {
        Self {
            index: 0,
            analytic: 0.0,
            finite_difference: 0.0,
            absolute_error: 0.0,
            relative_error: 0.0,
        }
    }

    /// Compares one analytic derivative against its finite-difference estimate
    /// and keeps it if its relative error exceeds the current worst.
    ///
    /// A comparison whose error is NaN or infinite (for example a NaN analytic
    /// gradient) is recorded with both errors set to `f64::INFINITY`. Without
    /// this, `NaN > worst` is always false, the observation is dropped, and a
    /// later `relative_error <= tolerance` check passes on the untouched 0.0.
    /// Once a non-finite comparison is stored it is never replaced.
    fn observe(&mut self, index: usize, analytic: f64, finite_difference: f64) {
        if !self.relative_error.is_finite() {
            // Nothing can rank above an already-recorded non-finite comparison.
            return;
        }
        let raw_absolute = (analytic - finite_difference).abs();
        let scale = analytic.abs().max(finite_difference.abs()).max(1.0e-9);
        let raw_relative = raw_absolute / scale;
        let (absolute_error, relative_error) =
            if raw_absolute.is_finite() && raw_relative.is_finite() {
                (raw_absolute, raw_relative)
            } else {
                (f64::INFINITY, f64::INFINITY)
            };
        if relative_error > self.relative_error {
            self.index = index;
            self.analytic = analytic;
            self.finite_difference = finite_difference;
            self.absolute_error = absolute_error;
            self.relative_error = relative_error;
        }
    }
}
/// Result of one finite-difference gradient check, as produced by
/// `sae_fd_check_case` and `sae_pen_fd_check`.
#[derive(Debug, Clone)]
struct SaeFdBlockReport {
// Case label supplied by the caller; printed in the test diagnostics.
label: String,
// Objective value at the unperturbed point (total loss or penalized objective,
// depending on which checker built the report).
base_loss: f64,
// Worst comparison over the latent-coordinate gradient entries (`rows[..].gt`).
coord: SaeFdWorst,
// Worst comparison over the decoder gradient entries (`gb`).
decoder: SaeFdWorst,
}
/// Builds a deterministic `(n_basis, p_out)` decoder matrix whose entry at
/// `(basis, out_col)` is `0.16 * sin(phase) + 0.05 * cos(1.9 * phase)` with
/// `phase = 0.73 * (basis + 1) + 1.17 * (out_col + 1)`.
fn sae_fd_decoder(n_basis: usize, p_out: usize) -> Array2<f64> {
    Array2::<f64>::from_shape_fn((n_basis, p_out), |(basis, out_col)| {
        let phase = 0.73 * ((basis + 1) as f64) + 1.17 * ((out_col + 1) as f64);
        0.16 * phase.sin() + 0.05 * (1.9 * phase).cos()
    })
}
/// Builds a deterministic `(n_obs, p_out)` target matrix. With `x = row + 1`
/// and `y = out_col + 1`, each entry is
/// `0.21 * sin(0.31 x + 0.47 y) - 0.13 * cos(0.19 x y)`.
fn sae_fd_target(n_obs: usize, p_out: usize) -> Array2<f64> {
    Array2::<f64>::from_shape_fn((n_obs, p_out), |(row, out_col)| {
        let x = (row as f64) + 1.0;
        let y = (out_col as f64) + 1.0;
        0.21 * (0.31 * x + 0.47 * y).sin() - 0.13 * (0.19 * x * y).cos()
    })
}
/// Builds the `(n_obs, 1)` latent-coordinate column for the named FD case.
///
/// Panics on the first row if `label` is neither `"periodic_d1"` nor
/// `"euclidean_d1"`.
fn sae_fd_coords(label: &str, n_obs: usize) -> Array2<f64> {
    Array2::<f64>::from_shape_fn((n_obs, 1), |(row, _)| {
        let x = row as f64;
        match label {
            "periodic_d1" => 0.07 + 0.043 * x + 0.004 * (1.3 * x).sin(),
            "euclidean_d1" => -0.46 + 0.048 * x + 0.006 * (1.7 * x).cos(),
            other => panic!("unknown SAE FD case label {other}"),
        }
    })
}
/// Builds the single-atom fixture for the named FD case (`"periodic_d1"` or
/// `"euclidean_d1"`): a 20-row, 3-output term, its target matrix, and the
/// `SaeManifoldRho` used by the loss.
///
/// Panics for any other label, and if the basis or jet shapes disagree with
/// `(n_obs, n_basis)` / `(n_obs, n_basis, 1)`.
fn sae_fd_term(label: &str) -> (SaeManifoldTerm, Array2<f64>, SaeManifoldRho) {
    let n_obs = 20usize;
    let p_out = 3usize;
    let coords = sae_fd_coords(label, n_obs);
    // Each arm reports the basis/jet shapes it built so they can be checked
    // after the basis arrays have been moved into the atom.
    let (expected_kind, phi_shape, jet_shape, n_basis, atom) = match label {
        "periodic_d1" => {
            let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
            let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
            let n_basis = phi.ncols();
            let phi_shape = phi.dim();
            let jet_shape = jet.dim();
            let atom = SaeManifoldAtom::new(
                "periodic_d1",
                SaeAtomBasisKind::Periodic,
                1,
                phi,
                jet,
                sae_fd_decoder(n_basis, p_out),
                Array2::<f64>::eye(n_basis),
            )
            .unwrap()
            .with_basis_second_jet(evaluator);
            (
                SaeAtomBasisKind::Periodic,
                phi_shape,
                jet_shape,
                n_basis,
                atom,
            )
        }
        "euclidean_d1" => {
            let evaluator = Arc::new(EuclideanPatchEvaluator::new(1, 2).unwrap());
            let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
            let n_basis = phi.ncols();
            let phi_shape = phi.dim();
            let jet_shape = jet.dim();
            let atom = SaeManifoldAtom::new(
                "euclidean_d1",
                SaeAtomBasisKind::EuclideanPatch,
                1,
                phi,
                jet,
                sae_fd_decoder(n_basis, p_out),
                Array2::<f64>::eye(n_basis),
            )
            .unwrap()
            .with_basis_second_jet(evaluator);
            (
                SaeAtomBasisKind::EuclideanPatch,
                phi_shape,
                jet_shape,
                n_basis,
                atom,
            )
        }
        other => panic!("unknown SAE FD case label {other}"),
    };
    assert_eq!(
        expected_kind.latent_manifold(1),
        atom.basis_kind.latent_manifold(1)
    );
    assert_eq!(phi_shape, (n_obs, n_basis));
    assert_eq!(jet_shape, (n_obs, n_basis, 1));
    let manifold = atom.basis_kind.latent_manifold(1);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n_obs, 1)),
        vec![coords],
        vec![manifold],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let rho = SaeManifoldRho::new(0.0, 1.0e-4_f64.ln(), vec![array![-30.0]]);
    (term, sae_fd_target(n_obs, p_out), rho)
}
/// Re-evaluates atom 0's basis at the current coordinates of coordinate
/// block 0. Panics if `refresh_basis` fails.
fn sae_fd_refresh(term: &mut SaeManifoldTerm) {
    let coord_matrix = term.assignment.coords[0].as_matrix();
    let first_atom = &mut term.atoms[0];
    first_atom.refresh_basis(coord_matrix.view()).unwrap();
}
/// Overwrites flat entry `row` of coordinate block 0 with `value`, then
/// refreshes atom 0's basis via `sae_fd_refresh`.
fn sae_fd_set_coord(term: &mut SaeManifoldTerm, row: usize, value: f64) {
    let block = &mut term.assignment.coords[0];
    let mut updated = block.as_flat().clone();
    updated[row] = value;
    block.set_flat(updated.view());
    sae_fd_refresh(term);
}
/// Evaluates `term.loss` at `target` / `rho` and returns its `total()`.
/// Panics if the loss evaluation returns an error.
fn sae_fd_total_loss(
    term: &SaeManifoldTerm,
    target: &Array2<f64>,
    rho: &SaeManifoldRho,
) -> f64 {
    let breakdown = term.loss(target.view(), rho).unwrap();
    breakdown.total()
}
/// Runs the penalty-free gradient check for one named FD case.
///
/// The analytic gradients come from `assemble_arrow_schur` (no registry). They
/// are compared against central differences of `sae_fd_total_loss` with step
/// `1e-6`: once per latent coordinate row, once per flattened decoder entry.
/// The worst comparison of each kind is returned in the report.
fn sae_fd_check_case(label: &str) -> SaeFdBlockReport {
    let epsilon = 1.0e-6;
    let (term, target, rho) = sae_fd_term(label);
    let base_loss = sae_fd_total_loss(&term, &target, &rho);
    assert!(base_loss.is_finite(), "{label}: base loss is not finite");

    // Assemble on a refreshed copy so `term` itself stays the FD base point.
    let mut assembled = term.clone();
    sae_fd_refresh(&mut assembled);
    let sys = assembled
        .assemble_arrow_schur(target.view(), &rho, None)
        .unwrap();
    let n_rows = term.n_obs();
    assert_eq!(sys.rows.len(), n_rows);
    assert_eq!(sys.gb.len(), term.beta_dim());
    for row in 0..n_rows {
        assert_eq!(
            sys.rows[row].gt.len(),
            1,
            "{label}: K=1 softmax d=1 should expose exactly one row coordinate gradient"
        );
    }

    // Loss after moving one latent coordinate to `value`.
    let loss_with_coord = |row: usize, value: f64| -> f64 {
        let mut probe = term.clone();
        sae_fd_set_coord(&mut probe, row, value);
        sae_fd_total_loss(&probe, &target, &rho)
    };
    let base_coords = term.assignment.coords[0].as_flat().clone();
    let mut coord = SaeFdWorst::new();
    for row in 0..n_rows {
        let loss_plus = loss_with_coord(row, base_coords[row] + epsilon);
        let loss_minus = loss_with_coord(row, base_coords[row] - epsilon);
        let finite_difference = (loss_plus - loss_minus) / (2.0 * epsilon);
        coord.observe(row, sys.rows[row].gt[0], finite_difference);
    }

    let beta = term.flatten_beta();
    let mut decoder = SaeFdWorst::new();
    for beta_idx in 0..beta.len() {
        let mut beta_plus = beta.clone();
        beta_plus[beta_idx] += epsilon;
        let mut plus = term.clone();
        plus.set_flat_beta(beta_plus.view()).unwrap();
        sae_fd_refresh(&mut plus);
        let loss_plus = sae_fd_total_loss(&plus, &target, &rho);

        let mut beta_minus = beta.clone();
        beta_minus[beta_idx] -= epsilon;
        let mut minus = term.clone();
        minus.set_flat_beta(beta_minus.view()).unwrap();
        sae_fd_refresh(&mut minus);
        let loss_minus = sae_fd_total_loss(&minus, &target, &rho);

        let finite_difference = (loss_plus - loss_minus) / (2.0 * epsilon);
        decoder.observe(beta_idx, sys.gb[beta_idx], finite_difference);
    }

    SaeFdBlockReport {
        label: label.to_string(),
        base_loss,
        coord,
        decoder,
    }
}
/// Selects which single-atom fixture `sae_pen_term` builds.
#[derive(Clone, Copy)]
enum SaePenCaseKind {
// Euclidean-patch atom with one latent coordinate per row.
EuclideanD1,
// Periodic-harmonic atom with one latent coordinate per row.
PeriodicD1,
// Euclidean-patch atom with two latent coordinates per row.
EuclideanD2,
}
/// Selects which analytic penalty `sae_pen_registry` puts into the registry.
#[derive(Clone, Copy)]
enum SaePenKind {
// `IsometryPenalty::new_euclidean` on the coordinate slice.
Isometry,
// `ARDPenalty` on the coordinate slice.
Ard,
// `ScadMcpPenalty` (built with `PenaltyConcavity::Mcp`) on the coordinate slice.
ScadMcp,
// `NuclearNormPenalty` on a slice spanning the flattened decoder.
NuclearNorm,
// `DecoderIncoherencePenalty` over two equally sized decoder blocks.
DecoderIncoherence,
}
/// Builds the single-atom fixture for penalty gradient checks: a 12-row,
/// 3-output term for the requested case, its target matrix, the
/// `SaeManifoldRho`, and a `PsiSlice` spanning all `n_obs * latent_dim`
/// coordinate entries.
fn sae_pen_term(
    kind: SaePenCaseKind,
) -> (SaeManifoldTerm, Array2<f64>, SaeManifoldRho, PsiSlice) {
    let n_obs = 12usize;
    let p_out = 3usize;
    let (coords, latent_dim, atom): (Array2<f64>, usize, SaeManifoldAtom) = match kind {
        SaePenCaseKind::PeriodicD1 => {
            let coords = Array2::<f64>::from_shape_fn((n_obs, 1), |(row, _)| {
                let x = row as f64;
                0.11 + 0.037 * x + 0.004 * (1.3 * x).sin()
            });
            let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
            let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
            let n_basis = phi.ncols();
            let built = SaeManifoldAtom::new(
                "periodic_d1",
                SaeAtomBasisKind::Periodic,
                1,
                phi,
                jet,
                sae_fd_decoder(n_basis, p_out),
                Array2::<f64>::eye(n_basis),
            )
            .unwrap()
            .with_basis_second_jet(evaluator);
            (coords, 1, built)
        }
        SaePenCaseKind::EuclideanD1 => {
            let coords = Array2::<f64>::from_shape_fn((n_obs, 1), |(row, _)| {
                let x = row as f64;
                -0.41 + 0.052 * x + 0.006 * (1.7 * x).cos()
            });
            let evaluator = Arc::new(EuclideanPatchEvaluator::new(1, 2).unwrap());
            let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
            let n_basis = phi.ncols();
            let built = SaeManifoldAtom::new(
                "euclidean_d1",
                SaeAtomBasisKind::EuclideanPatch,
                1,
                phi,
                jet,
                sae_fd_decoder(n_basis, p_out),
                Array2::<f64>::eye(n_basis),
            )
            .unwrap()
            .with_basis_second_jet(evaluator);
            (coords, 1, built)
        }
        SaePenCaseKind::EuclideanD2 => {
            let coords = Array2::<f64>::from_shape_fn((n_obs, 2), |(row, axis)| {
                let x = row as f64;
                if axis == 0 {
                    -0.33 + 0.041 * x + 0.005 * (1.1 * x).cos()
                } else {
                    0.27 - 0.036 * x + 0.004 * (0.9 * x).sin()
                }
            });
            let evaluator = Arc::new(EuclideanPatchEvaluator::new(2, 2).unwrap());
            let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
            let n_basis = phi.ncols();
            let built = SaeManifoldAtom::new(
                "euclidean_d2",
                SaeAtomBasisKind::EuclideanPatch,
                2,
                phi,
                jet,
                sae_fd_decoder(n_basis, p_out),
                Array2::<f64>::eye(n_basis),
            )
            .unwrap()
            .with_basis_second_jet(evaluator);
            (coords, 2, built)
        }
    };
    let manifold = atom.basis_kind.latent_manifold(latent_dim);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n_obs, 1)),
        vec![coords],
        vec![manifold],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let rho = SaeManifoldRho::new(
        0.0,
        1.0e-4_f64.ln(),
        vec![Array1::from_elem(latent_dim, -30.0_f64)],
    );
    let coord_slice = PsiSlice {
        range: 0..n_obs * latent_dim,
        latent_dim: Some(latent_dim),
    };
    (term, sae_fd_target(n_obs, p_out), rho, coord_slice)
}
/// Builds a two-atom fixture (both Euclidean-patch, one latent coordinate
/// each) with 12 rows and 3 outputs, plus its target and `SaeManifoldRho`.
///
/// The second atom's decoder is the shared `sae_fd_decoder` output with an
/// extra `0.07 * cos(basis + out_col + 1)` added to each entry.
fn sae_pen_term_k2() -> (SaeManifoldTerm, Array2<f64>, SaeManifoldRho) {
    let n_obs = 12usize;
    let p_out = 3usize;
    let mut atoms = Vec::with_capacity(2);
    let mut coord_blocks = Vec::with_capacity(2);
    for atom_idx in 0..2usize {
        let is_first = atom_idx == 0;
        let coords = Array2::<f64>::from_shape_fn((n_obs, 1), |(row, _)| {
            let x = row as f64;
            if is_first {
                -0.41 + 0.052 * x + 0.006 * (1.7 * x).cos()
            } else {
                0.18 + 0.039 * x + 0.005 * (1.1 * x).sin()
            }
        });
        let evaluator = Arc::new(EuclideanPatchEvaluator::new(1, 2).unwrap());
        let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
        let n_basis = phi.ncols();
        let mut decoder = sae_fd_decoder(n_basis, p_out);
        if !is_first {
            // Offset the second decoder so the two atoms are not identical.
            for ((basis, out_col), entry) in decoder.indexed_iter_mut() {
                *entry += 0.07 * ((basis + out_col) as f64 + 1.0).cos();
            }
        }
        atoms.push(
            SaeManifoldAtom::new(
                "euclidean_d1",
                SaeAtomBasisKind::EuclideanPatch,
                1,
                phi,
                jet,
                decoder,
                Array2::<f64>::eye(n_basis),
            )
            .unwrap()
            .with_basis_second_jet(evaluator),
        );
        coord_blocks.push(coords);
    }
    let manifold = LatentManifold::Euclidean;
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::from_elem((n_obs, 2), 0.2),
        coord_blocks,
        vec![manifold.clone(), manifold],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(atoms, assignment).unwrap();
    let log_ard = vec![
        Array1::from_elem(1, -30.0_f64),
        Array1::from_elem(1, -30.0_f64),
    ];
    let rho = SaeManifoldRho::new(0.0, 1.0e-4_f64.ln(), log_ard);
    (term, sae_fd_target(n_obs, p_out), rho)
}
/// Builds a registry holding exactly one analytic penalty of the requested
/// kind.
///
/// Coordinate penalties (`Isometry`, `Ard`, `ScadMcp`) use `coord_slice`.
/// Decoder penalties (`NuclearNorm`, `DecoderIncoherence`) ignore it and build
/// their own slice over `0..beta_len` with `latent_dim = beta_len / p_out`.
/// Panics if a fallible penalty constructor returns an error.
fn sae_pen_registry(
    pen: SaePenKind,
    coord_slice: &PsiSlice,
    n_obs: usize,
    latent_dim: usize,
    beta_len: usize,
    p_out: usize,
) -> AnalyticPenaltyRegistry {
    use crate::terms::analytic_penalties::PenaltyConcavity;
    use crate::terms::analytic_penalties::ScadMcpPenalty;
    // Slice covering the whole flattened decoder, shared by both decoder penalties.
    let decoder_slice = || PsiSlice {
        range: 0..beta_len,
        latent_dim: Some(beta_len / p_out),
    };
    let entry = match pen {
        SaePenKind::Isometry => AnalyticPenaltyKind::Isometry(Arc::new(
            IsometryPenalty::new_euclidean(coord_slice.clone(), latent_dim),
        )),
        SaePenKind::Ard => {
            AnalyticPenaltyKind::Ard(Arc::new(ARDPenalty::new(coord_slice.clone(), latent_dim)))
        }
        SaePenKind::ScadMcp => {
            let scad_mcp = ScadMcpPenalty::new(
                coord_slice.clone(),
                0.5,
                n_obs,
                3.0,
                1.0e-4,
                PenaltyConcavity::Mcp,
                false,
            )
            .unwrap();
            AnalyticPenaltyKind::ScadMcp(Arc::new(scad_mcp))
        }
        SaePenKind::NuclearNorm => {
            let nuclear =
                NuclearNormPenalty::new(decoder_slice(), 0.7, p_out, 1.0e-4, None, false).unwrap();
            AnalyticPenaltyKind::NuclearNorm(Arc::new(nuclear))
        }
        SaePenKind::DecoderIncoherence => {
            // Two equally sized decoder blocks.
            let m_per = beta_len / (2 * p_out);
            let incoherence = DecoderIncoherencePenalty::new(
                decoder_slice(),
                vec![m_per, m_per],
                p_out,
                Array2::<f64>::from_elem((2, 2), 0.5),
                0.6,
                false,
            )
            .unwrap();
            AnalyticPenaltyKind::DecoderIncoherence(Arc::new(incoherence))
        }
    };
    let mut registry = AnalyticPenaltyRegistry::new();
    registry.push(entry);
    registry
}
/// Compares the gradients assembled with `registry` against central
/// differences (step `1e-6`) of `penalized_objective_total`.
///
/// Every latent coordinate of every atom is perturbed, with that atom's basis
/// refreshed afterwards, and so is every flattened decoder entry. The worst
/// coordinate and decoder comparisons are returned; `base_loss` holds the
/// unperturbed penalized objective.
fn sae_pen_fd_check(
    label: &str,
    term: &SaeManifoldTerm,
    target: &Array2<f64>,
    rho: &SaeManifoldRho,
    registry: &AnalyticPenaltyRegistry,
) -> SaeFdBlockReport {
    let epsilon = 1.0e-6;
    let objective = |candidate: &SaeManifoldTerm| -> f64 {
        candidate
            .penalized_objective_total(target.view(), rho, Some(registry), 1.0)
            .unwrap()
    };
    let base_obj = objective(term);
    assert!(base_obj.is_finite(), "{label}: base objective not finite");
    let mut assembled = term.clone();
    let sys = assembled
        .assemble_arrow_schur(target.view(), rho, Some(registry))
        .unwrap();

    let mut coord = SaeFdWorst::new();
    let coord_offsets = term.assignment.coord_offsets();
    for atom_idx in 0..term.k_atoms() {
        let off = coord_offsets[atom_idx];
        let d = term.assignment.coords[atom_idx].latent_dim();
        let base_flat = term.assignment.coords[atom_idx].as_flat().clone();
        let n_atom = base_flat.len() / d;
        // Objective after shifting one flat coordinate of this atom by `delta`.
        let objective_shifted = |lin: usize, delta: f64| -> f64 {
            let mut probe = term.clone();
            let mut flat = base_flat.clone();
            flat[lin] += delta;
            probe.assignment.coords[atom_idx].set_flat(flat.view());
            let probe_coords = probe.assignment.coords[atom_idx].as_matrix();
            probe.atoms[atom_idx]
                .refresh_basis(probe_coords.view())
                .unwrap();
            objective(&probe)
        };
        for row in 0..n_atom {
            for axis in 0..d {
                let lin = row * d + axis;
                let obj_p = objective_shifted(lin, epsilon);
                let obj_m = objective_shifted(lin, -epsilon);
                let finite_difference = (obj_p - obj_m) / (2.0 * epsilon);
                coord.observe(lin, sys.rows[row].gt[off + axis], finite_difference);
            }
        }
    }

    let mut decoder = SaeFdWorst::new();
    let beta = term.flatten_beta();
    // Objective after shifting one flattened decoder entry by `delta`.
    let objective_with_beta = |beta_idx: usize, delta: f64| -> f64 {
        let mut shifted = beta.clone();
        shifted[beta_idx] += delta;
        let mut probe = term.clone();
        probe.set_flat_beta(shifted.view()).unwrap();
        objective(&probe)
    };
    for beta_idx in 0..beta.len() {
        let obj_p = objective_with_beta(beta_idx, epsilon);
        let obj_m = objective_with_beta(beta_idx, -epsilon);
        let finite_difference = (obj_p - obj_m) / (2.0 * epsilon);
        decoder.observe(beta_idx, sys.gb[beta_idx], finite_difference);
    }

    SaeFdBlockReport {
        label: label.to_string(),
        base_loss: base_obj,
        coord,
        decoder,
    }
}
/// For each penalty/fixture pairing, checks that the gradients assembled by
/// `assemble_arrow_schur` agree with central differences of the penalized
/// objective (see `sae_pen_fd_check`).
///
/// Five single-atom cases and one two-atom `DecoderIncoherence` case are run.
/// A block passes when its worst relative error is at most `1e-5` or its worst
/// absolute error is at most `1e-7`. One diagnostic line per case goes to
/// stderr before the final assertion.
#[test]
fn sae_assembled_gradient_matches_penalized_objective_central_fd() {
    let p_out = 3usize;
    let mut reports: Vec<SaeFdBlockReport> = Vec::new();
    let single_cases: &[(&str, SaePenCaseKind, SaePenKind)] = &[
        (
            "isometry_circle_d1",
            SaePenCaseKind::PeriodicD1,
            SaePenKind::Isometry,
        ),
        (
            "isometry_euclid_d2",
            SaePenCaseKind::EuclideanD2,
            SaePenKind::Isometry,
        ),
        ("ard_circle_d1", SaePenCaseKind::PeriodicD1, SaePenKind::Ard),
        (
            "scadmcp_euclid_d1",
            SaePenCaseKind::EuclideanD1,
            SaePenKind::ScadMcp,
        ),
        (
            "nuclearnorm_euclid_d1",
            SaePenCaseKind::EuclideanD1,
            SaePenKind::NuclearNorm,
        ),
    ];
    for (label, case_kind, pen_kind) in single_cases {
        let (term, target, rho, slice) = sae_pen_term(*case_kind);
        let n_obs = term.n_obs();
        let latent_dim = term.assignment.coords[0].latent_dim();
        let beta_len = term.beta_dim();
        let registry = sae_pen_registry(*pen_kind, &slice, n_obs, latent_dim, beta_len, p_out);
        term.validate_analytic_penalty_registry(&registry)
            .expect("penalty registry must validate for the SAE term");
        reports.push(sae_pen_fd_check(label, &term, &target, &rho, &registry));
    }
    {
        // Two-atom case: the incoherence penalty couples the two decoder blocks.
        let (term, target, rho) = sae_pen_term_k2();
        let beta_len = term.beta_dim();
        let slice = PsiSlice {
            range: 0..beta_len,
            latent_dim: Some(beta_len / p_out),
        };
        let registry = sae_pen_registry(
            SaePenKind::DecoderIncoherence,
            &slice,
            term.n_obs(),
            1,
            beta_len,
            p_out,
        );
        term.validate_analytic_penalty_registry(&registry)
            .expect("DecoderIncoherence registry must validate for the K=2 SAE term");
        reports.push(sae_pen_fd_check(
            "decoder_incoherence_k2",
            &term,
            &target,
            &rho,
            &registry,
        ));
    }
    let relative_tolerance = 1.0e-5;
    let absolute_tolerance = 1.0e-7;
    let mut all_blocks_match = true;
    for report in &reports {
        let coord_ok = report.coord.relative_error <= relative_tolerance
            || report.coord.absolute_error <= absolute_tolerance;
        let decoder_ok = report.decoder.relative_error <= relative_tolerance
            || report.decoder.absolute_error <= absolute_tolerance;
        all_blocks_match = all_blocks_match && coord_ok && decoder_ok;
        let coord_status = if coord_ok { "MATCH" } else { "MISMATCH" };
        let decoder_status = if decoder_ok { "MATCH" } else { "MISMATCH" };
        eprintln!(
            "{label}: base={base:.12e}; coord_gt={coord_status} max_rel={coord_rel:.6e} \
             max_abs={coord_abs:.6e} worst_row={coord_idx} analytic={coord_an:.12e} \
             fd={coord_fd:.12e}; decoder_gb={decoder_status} max_rel={decoder_rel:.6e} \
             max_abs={decoder_abs:.6e} worst_beta={decoder_idx} analytic={decoder_an:.12e} \
             fd={decoder_fd:.12e}",
            label = report.label,
            base = report.base_loss,
            coord_rel = report.coord.relative_error,
            coord_abs = report.coord.absolute_error,
            coord_idx = report.coord.index,
            coord_an = report.coord.analytic,
            coord_fd = report.coord.finite_difference,
            decoder_rel = report.decoder.relative_error,
            decoder_abs = report.decoder.absolute_error,
            decoder_idx = report.decoder.index,
            decoder_an = report.decoder.analytic,
            decoder_fd = report.decoder.finite_difference,
        );
    }
    assert!(
        all_blocks_match,
        "SAE assembled gradient does not match central FD of the penalized objective"
    );
}
/// With only an isometry penalty registered on the periodic d=1 fixture,
/// checks that the isometry value is positive, the analytic decoder penalty
/// value is zero, and the REML extra penalty value equals the isometry value
/// (all comparisons to within `1e-12`).
#[test]
fn sae_reml_extra_penalty_energy_counts_live_isometry_once() {
    let p_out = 3usize;
    let (term, _target, _rho, slice) = sae_pen_term(SaePenCaseKind::PeriodicD1);
    let registry = sae_pen_registry(
        SaePenKind::Isometry,
        &slice,
        term.n_obs(),
        term.assignment.coords[0].latent_dim(),
        term.beta_dim(),
        p_out,
    );
    let isometry_energy = term
        .isometry_penalty_value_total(&registry)
        .expect("live isometry value");
    assert!(
        isometry_energy > 0.0,
        "fixture must carry nonzero isometry energy"
    );
    let decoder_energy = term
        .analytic_decoder_penalty_value_total(&registry)
        .expect("decoder penalty value");
    assert_abs_diff_eq!(decoder_energy, 0.0, epsilon = 1.0e-12);
    // The extra energy equals the isometry energy exactly once: not zero and
    // not doubled.
    let extra_energy = term
        .reml_extra_penalty_value_total(&registry)
        .expect("REML extra penalty value");
    assert_abs_diff_eq!(extra_energy, isometry_energy, epsilon = 1.0e-12);
}
/// Runs `sae_fd_check_case` for the Euclidean and periodic d=1 fixtures and
/// requires every coordinate and decoder block to agree with central
/// differences of the loss (relative error at most `3e-5` or absolute error at
/// most `3e-7`). One diagnostic line per case is written to stderr.
#[test]
fn sae_d1_assembled_gradient_matches_loss_central_fd() {
    let relative_tolerance = 3.0e-5;
    let absolute_tolerance = 3.0e-7;
    // All cases are evaluated before anything is printed.
    let reports: Vec<SaeFdBlockReport> = ["euclidean_d1", "periodic_d1"]
        .iter()
        .map(|case_label| sae_fd_check_case(case_label))
        .collect();
    let within_tolerance = |worst: &SaeFdWorst| -> bool {
        worst.relative_error <= relative_tolerance || worst.absolute_error <= absolute_tolerance
    };
    let status_of = |ok: bool| -> &'static str {
        if ok {
            "MATCH"
        } else {
            "MISMATCH"
        }
    };
    let mut all_blocks_match = true;
    for report in reports.iter() {
        let coord_ok = within_tolerance(&report.coord);
        let decoder_ok = within_tolerance(&report.decoder);
        all_blocks_match &= coord_ok && decoder_ok;
        let coord_status = status_of(coord_ok);
        let decoder_status = status_of(decoder_ok);
        eprintln!(
            "{label}: base={base:.12e}; coord_gt={coord_status} max_rel={coord_rel:.6e} \
             max_abs={coord_abs:.6e} worst_row={coord_idx} analytic={coord_an:.12e} \
             fd={coord_fd:.12e}; decoder_gb={decoder_status} max_rel={decoder_rel:.6e} \
             max_abs={decoder_abs:.6e} worst_beta={decoder_idx} analytic={decoder_an:.12e} \
             fd={decoder_fd:.12e}",
            label = report.label,
            base = report.base_loss,
            coord_rel = report.coord.relative_error,
            coord_abs = report.coord.absolute_error,
            coord_idx = report.coord.index,
            coord_an = report.coord.analytic,
            coord_fd = report.coord.finite_difference,
            decoder_rel = report.decoder.relative_error,
            decoder_abs = report.decoder.absolute_error,
            decoder_idx = report.decoder.index,
            decoder_an = report.decoder.analytic,
            decoder_fd = report.decoder.finite_difference,
        );
    }
    assert!(
        all_blocks_match,
        "SAE d=1 assembled gradient does not match central finite difference"
    );
}
/// Asserts that the first jet returned by `evaluator.evaluate` matches a
/// central difference (step `1e-6`) of the basis values, entry by entry, to
/// within `tolerance`.
///
/// Also asserts that the jet has shape `(rows, basis, latent_dim)` at the base
/// point and at every perturbed point.
fn assert_jacobian_matches_central_difference<E: SaeBasisEvaluator>(
    evaluator: &E,
    coords: Array2<f64>,
    tolerance: f64,
) {
    let epsilon = 1.0e-6;
    let (base_phi, base_jet) = evaluator.evaluate(coords.view()).unwrap();
    let (n_rows, n_basis) = base_phi.dim();
    let latent_dim = coords.ncols();
    let jet_shape = base_jet.dim();
    assert_eq!(jet_shape, (n_rows, n_basis, latent_dim));
    for row in 0..n_rows {
        for axis in 0..latent_dim {
            // Perturb a single coordinate in each direction.
            let mut forward = coords.clone();
            let mut backward = coords.clone();
            forward[[row, axis]] += epsilon;
            backward[[row, axis]] -= epsilon;
            let (phi_forward, jet_forward) = evaluator.evaluate(forward.view()).unwrap();
            let (phi_backward, jet_backward) = evaluator.evaluate(backward.view()).unwrap();
            assert_eq!(jet_forward.dim(), jet_shape);
            assert_eq!(jet_backward.dim(), jet_shape);
            for basis in 0..n_basis {
                let finite_difference =
                    (phi_forward[[row, basis]] - phi_backward[[row, basis]]) / (2.0 * epsilon);
                let analytic = base_jet[[row, basis, axis]];
                let error = (analytic - finite_difference).abs();
                assert!(
                    error <= tolerance,
                    "row={row} basis={basis} axis={axis}: analytic={analytic:.12e}, \
                     finite_difference={finite_difference:.12e}, error={error:.12e}, \
                     tolerance={tolerance:.12e}"
                );
            }
        }
    }
}
/// Checks the analytic first jets of the periodic, raw-periodic-circle,
/// sphere-chart, affine and torus evaluators against central differences.
///
/// Also pins a few structural facts: sphere basis/jet shapes and three
/// longitude derivatives, the torus basis/jet shapes, and that torus basis
/// column 0 is the constant 1 with zero derivative.
#[test]
fn sae_basis_evaluator_jacobians_match_central_differences() {
    let fd_tolerance = 1.0e-6;
    assert_jacobian_matches_central_difference(
        &PeriodicHarmonicEvaluator::new(7).unwrap(),
        array![[-0.37], [0.0], [0.125], [0.41]],
        fd_tolerance,
    );
    assert_jacobian_matches_central_difference(
        &RawPeriodicCircleEvaluator::new(3).unwrap(),
        array![[-1.2, 0.3, 2.0], [0.0, -0.4, 0.8], [2.4, 1.1, -0.7]],
        fd_tolerance,
    );

    let sphere_coords = array![[-0.7, -1.2], [-0.25, 0.0], [0.35, 0.9], [0.8, 2.1]];
    assert_jacobian_matches_central_difference(
        &SphereChartEvaluator,
        sphere_coords.clone(),
        fd_tolerance,
    );
    let (sphere_phi, sphere_jet) = SphereChartEvaluator.evaluate(sphere_coords.view()).unwrap();
    let sphere_rows = sphere_coords.nrows();
    assert_eq!(sphere_phi.dim(), (sphere_rows, 7));
    assert_eq!(sphere_jet.dim(), (sphere_rows, 7, 2));
    for row in 0..sphere_rows {
        let lat = sphere_coords[[row, 0]];
        let lon = sphere_coords[[row, 1]];
        let clat = lat.cos();
        let slat = lat.sin();
        let clon = lon.cos();
        let slon = lon.sin();
        let z = slat;
        let dx_dlon = -clat * slon;
        let dy_dlon = clat * clon;
        // Column 3 has an exactly zero longitude derivative.
        assert_eq!(sphere_jet[[row, 3, 1]], 0.0);
        assert!((sphere_jet[[row, 5, 1]] - dy_dlon * z).abs() <= 1.0e-12);
        assert!((sphere_jet[[row, 6, 1]] - dx_dlon * z).abs() <= 1.0e-12);
    }

    assert_jacobian_matches_central_difference(
        &AffineCoordinateEvaluator::new(3),
        array![[0.0, -1.0, 2.0], [3.5, 0.25, -0.75]],
        fd_tolerance,
    );

    let torus_coords = array![[0.1, 0.7], [0.42, 0.0], [0.95, 0.33], [0.5, 0.5]];
    assert_jacobian_matches_central_difference(
        &TorusHarmonicEvaluator::new(2, 3).unwrap(),
        torus_coords.clone(),
        fd_tolerance,
    );
    let torus_evaluator = TorusHarmonicEvaluator::new(2, 3).unwrap();
    let (torus_phi, torus_jet) = torus_evaluator.evaluate(torus_coords.view()).unwrap();
    let torus_rows = torus_coords.nrows();
    assert_eq!(torus_phi.dim(), (torus_rows, 49));
    assert_eq!(torus_jet.dim(), (torus_rows, 49, 2));
    for row in 0..torus_rows {
        assert!((torus_phi[[row, 0]] - 1.0).abs() <= 1.0e-12);
        assert!(torus_jet[[row, 0, 0]].abs() <= 1.0e-12);
        assert!(torus_jet[[row, 0, 1]].abs() <= 1.0e-12);
    }
}
/// Checks `projection_seed_grid` for three basis kinds: the periodic grid is
/// `i / 16` for `i` in `0..16`, the sphere grid has `r * r` rows with latitude
/// strictly inside `(-π/2, π/2)` and longitude in `[-π, π)`, and the
/// Euclidean-patch kind returns `None`.
#[test]
fn projection_seed_grid_spans_each_compact_manifold() {
    use std::f64::consts::PI;

    let periodic_grid = SaeAtomBasisKind::Periodic
        .projection_seed_grid(1, 16)
        .unwrap();
    assert_eq!(periodic_grid.dim(), (16, 1));
    for step in 0..16 {
        assert_abs_diff_eq!(periodic_grid[[step, 0]], step as f64 / 16.0, epsilon = 1e-12);
    }
    assert!(periodic_grid.iter().all(|&t| (0.0..1.0).contains(&t)));

    let r = 6usize;
    let sphere_grid = SaeAtomBasisKind::Sphere.projection_seed_grid(2, r).unwrap();
    let n_sphere_seeds = r * r;
    assert_eq!(sphere_grid.dim(), (n_sphere_seeds, 2));
    let half_pi = PI / 2.0;
    for seed in 0..n_sphere_seeds {
        let lat = sphere_grid[[seed, 0]];
        let lon = sphere_grid[[seed, 1]];
        assert!(
            lat > -half_pi && lat < half_pi,
            "sphere seed latitude {lat} is not strictly interior to the chart"
        );
        assert!(
            (-PI..PI).contains(&lon),
            "sphere seed longitude {lon} is outside [-π, π)"
        );
    }

    let euclidean_grid = SaeAtomBasisKind::EuclideanPatch.projection_seed_grid(2, 64);
    assert!(
        euclidean_grid.is_none(),
        "Euclidean-patch (unbounded) atoms must not expose a projection seed grid"
    );
}
/// Checks the size of the torus seed grid for several latent dimensions at a
/// requested resolution of 256: `d = 1` gives 256 rows, `d = 3` gives `16³`
/// rows with 16 distinct phases per axis, `d = 12` gives `2¹²` rows, and
/// `d = 13` gives `None`. No grid may exceed 4096 rows.
#[test]
fn torus_projection_seed_grid_caps_total_points() {
    let torus = SaeAtomBasisKind::Torus;

    let grid_d1 = torus.projection_seed_grid(1, 256).unwrap();
    assert_eq!(grid_d1.dim(), (256, 1));

    let grid_d3 = torus.projection_seed_grid(3, 256).unwrap();
    assert_eq!(grid_d3.ncols(), 3);
    assert_eq!(grid_d3.nrows(), 16 * 16 * 16);
    assert!(
        grid_d3.nrows() <= 4096,
        "torus d=3 seed grid has {} points, over the 4096 cap",
        grid_d3.nrows()
    );
    assert!(
        grid_d3.iter().all(|&t| (0.0..1.0).contains(&t)),
        "every torus seed coordinate must be a phase on [0, 1)"
    );
    for axis in 0..3 {
        // Count distinct phases along this axis.
        let mut phases: Vec<f64> = grid_d3.column(axis).iter().copied().collect();
        phases.sort_by(|a, b| a.partial_cmp(b).unwrap());
        phases.dedup();
        assert_eq!(
            phases.len(),
            16,
            "torus seed axis {axis} should take 16 distinct phases"
        );
    }

    let grid_d12 = torus.projection_seed_grid(12, 256).unwrap();
    assert_eq!(grid_d12.nrows(), 1usize << 12);
    assert!(grid_d12.nrows() <= 4096);

    let grid_d13 = torus.projection_seed_grid(13, 256);
    assert!(
        grid_d13.is_none(),
        "torus d=13 seed grid (2^13 > 4096) must fall back to None, not blow up the cap"
    );
}
/// Builds a periodic atom with decoder rows `[0,0]`, `[1,0]`, `[0,1]` and a
/// target of `(sin 2πt, cos 2πt)` at grid phases `t = 3/8` and `t = 6/8`.
///
/// After `seed_coords_by_decoder_projection` at resolution 8, the seeded
/// coordinates must equal those phases, and the atom's stored basis values
/// must equal the evaluator's output there (both to `1e-12`).
#[test]
fn seed_coords_by_decoder_projection_lands_on_grid_minimiser() {
    use std::f64::consts::PI;

    let resolution = 8usize;
    let init_coords = array![[0.05], [0.05]];
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
    let (phi0, jet0) = evaluator.evaluate(init_coords.view()).unwrap();
    let decoder = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi0,
        jet0,
        decoder,
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(evaluator.clone());
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((2, 1)),
        vec![init_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    // Grid indices whose phases generate the target rows.
    let phases = [3usize, 6usize];
    let mut target = Array2::<f64>::zeros((2, 2));
    for (row, &grid_idx) in phases.iter().enumerate() {
        let angle = 2.0 * PI * (grid_idx as f64 / resolution as f64);
        target[[row, 0]] = angle.sin();
        target[[row, 1]] = angle.cos();
    }

    term.seed_coords_by_decoder_projection(target.view(), resolution)
        .unwrap();

    let seeded = term.assignment.coords[0].as_matrix();
    let mut expected_coords = Array2::<f64>::zeros((2, 1));
    for (row, &grid_idx) in phases.iter().enumerate() {
        let expected_phase = grid_idx as f64 / resolution as f64;
        assert_abs_diff_eq!(seeded[[row, 0]], expected_phase, epsilon = 1e-12);
        expected_coords[[row, 0]] = expected_phase;
    }
    let (phi_expected, _) = evaluator.evaluate(expected_coords.view()).unwrap();
    let basis_l1_gap = (&term.atoms[0].basis_values - &phi_expected)
        .mapv(f64::abs)
        .sum();
    assert_abs_diff_eq!(basis_l1_gap, 0.0, epsilon = 1e-12);
}
/// Passes a `(2, 3)` target to a term whose decoder has two output columns
/// and expects `seed_coords_by_decoder_projection` to return an error whose
/// text contains `"target shape"`.
#[test]
fn seed_coords_by_decoder_projection_rejects_shape_mismatch() {
    let init_coords = array![[0.05], [0.05]];
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(3).unwrap());
    let (phi0, jet0) = evaluator.evaluate(init_coords.view()).unwrap();
    let decoder = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi0,
        jet0,
        decoder,
        Array2::<f64>::eye(3),
    )
    .unwrap()
    .with_basis_evaluator(evaluator);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((2, 1)),
        vec![init_coords],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    // Three target columns against a two-column decoder.
    let bad_target = Array2::<f64>::zeros((2, 3));
    let outcome = term.seed_coords_by_decoder_projection(bad_target.view(), 8);
    let err = outcome.unwrap_err();
    assert!(
        err.contains("target shape"),
        "expected a target-shape error, got: {err}"
    );
}
/// The engine-level `sphere_chart_basis_jet` and the `SphereChartEvaluator`
/// adapter must return identical arrays, and both must agree with the
/// closed-form chart spelled out in this test.
///
/// With `x = cos(lat) cos(lon)`, `y = cos(lat) sin(lon)` and `z = sin(lat)`,
/// the seven basis columns are compared with `1, x, y, z, xy, yz, xz`, and jet
/// columns 1..=6 with the hand-derived partials along latitude (axis 0) and
/// longitude (axis 1). The penalty diagonal constant is pinned as well.
#[test]
fn sphere_chart_basis_jet_is_single_source_of_truth() {
    use std::f64::consts::FRAC_PI_2;
    const TOL: f64 = 1.0e-12;

    let coords = array![
        [-1.2, -2.4],
        [0.35, 0.9],
        [FRAC_PI_2, 0.4],
        [-FRAC_PI_2, -1.1],
        [2.3, 0.7],
        [-3.0, 1.9],
    ];
    let (engine_phi, engine_jet) = sphere_chart_basis_jet(coords.view()).unwrap();
    let (adapter_phi, adapter_jet) = SphereChartEvaluator.evaluate(coords.view()).unwrap();
    assert_eq!(engine_phi, adapter_phi);
    assert_eq!(engine_jet, adapter_jet);

    for (row, point) in coords.outer_iter().enumerate() {
        let (lat, lon) = (point[0], point[1]);
        let clat = lat.cos();
        let slat = lat.sin();
        let clon = lon.cos();
        let slon = lon.sin();
        let x = clat * clon;
        let y = clat * slon;
        let z = slat;

        // Basis values, one expected entry per column.
        let basis_expected = [1.0, x, y, z, x * y, y * z, x * z];
        for (col, &want) in basis_expected.iter().enumerate() {
            assert!(
                (engine_phi[[row, col]] - want).abs() <= TOL,
                "row {row}: basis column {col} deviates from the closed form"
            );
        }

        // Longitude partials (axis 1). Column 3 (`z`) does not depend on the
        // longitude and is required to be exactly zero.
        let dx_dlon = -clat * slon;
        let dy_dlon = clat * clon;
        let lon_expected = [
            (1usize, dx_dlon),
            (2, dy_dlon),
            (4, dx_dlon * y + x * dy_dlon),
            (5, dy_dlon * z),
            (6, dx_dlon * z),
        ];
        for &(col, want) in lon_expected.iter() {
            assert!(
                (engine_jet[[row, col, 1]] - want).abs() <= TOL,
                "row {row}: longitude partial of column {col} deviates from the closed form"
            );
        }
        assert_eq!(engine_jet[[row, 3, 1]], 0.0);

        // Latitude partials (axis 0).
        let dx_dlat = -slat * clon;
        let dy_dlat = -slat * slon;
        let dz_dlat = clat;
        let lat_expected = [
            (1usize, dx_dlat),
            (2, dy_dlat),
            (3, dz_dlat),
            (4, dx_dlat * y + x * dy_dlat),
            (5, dy_dlat * z + y * dz_dlat),
            (6, dx_dlat * z + x * dz_dlat),
        ];
        for &(col, want) in lat_expected.iter() {
            assert!(
                (engine_jet[[row, col, 0]] - want).abs() <= TOL,
                "row {row}: latitude partial of column {col} deviates from the closed form"
            );
        }
    }

    assert_eq!(
        SPHERE_CHART_PENALTY_DIAGONAL,
        [1e-8, 1.0, 1.0, 1.0, 4.0, 4.0, 4.0]
    );
}
/// The analytic sphere-chart jet must agree with a central difference of the
/// basis values, including at latitudes equal to or beyond `±π/2`, and the
/// basis itself must not jump across `lat = π/2`.
#[test]
fn sphere_chart_jet_matches_fd_at_clamp_boundary() {
    use std::f64::consts::FRAC_PI_2;

    // Only the basis values are needed from the perturbed evaluations.
    let basis_at = |points: &Array2<f64>| sphere_chart_basis_jet(points.view()).unwrap().0;

    let coords = array![
        [FRAC_PI_2, 0.4],
        [-FRAC_PI_2, -1.1],
        [1.45, 2.0],
        [1.69, -0.3],
        [2.3, 0.7],
        [0.35, 0.9],
    ];
    let (_, jet) = sphere_chart_basis_jet(coords.view()).unwrap();
    let h = 1.0e-6;

    // Every (row, latent axis) pair gets its own symmetric perturbation.
    let probes = (0..coords.nrows()).flat_map(|row| (0..2usize).map(move |axis| (row, axis)));
    for (row, axis) in probes {
        let mut plus = coords.clone();
        plus[[row, axis]] += h;
        let mut minus = coords.clone();
        minus[[row, axis]] -= h;
        let phi_p = basis_at(&plus);
        let phi_m = basis_at(&minus);
        for col in 0..7 {
            let fd = (phi_p[[row, col]] - phi_m[[row, col]]) / (2.0 * h);
            let an = jet[[row, col, axis]];
            assert!(
                (fd - an).abs() <= 1.0e-7,
                "row {row} col {col} axis {axis}: analytic {an} vs FD {fd}"
            );
        }
    }

    // Continuity of the basis values when stepping across lat = π/2.
    let eps = 1.0e-8;
    let lon = 0.4;
    let phi_below = basis_at(&array![[FRAC_PI_2 - eps, lon]]);
    let phi_above = basis_at(&array![[FRAC_PI_2 + eps, lon]]);
    for col in 0..7 {
        let jump = (phi_below[[0, col]] - phi_above[[0, col]]).abs();
        assert!(
            jump <= 1.0e-6,
            "basis discontinuous across lat = π/2 at col {col}: \
             {} vs {}",
            phi_below[[0, col]],
            phi_above[[0, col]]
        );
    }
}
/// Checks an evaluator's analytic second jet against central differences of
/// its first jet, then checks that the second jet is symmetric.
///
/// For every row and latent axis `axis_c`, the coordinate is shifted by
/// `±1e-4`, the first jet is re-evaluated, and the centred quotient is
/// compared with `second[[row, basis, axis_a, axis_c]]` under the mixed
/// tolerance `abs_tol + rel_tol * max(|analytic|, |fd|)`. Afterwards
/// `second[.., a, b]` and `second[.., b, a]` must agree to `1e-12`.
///
/// # Errors
/// Returns any error reported by `second_jet` or `evaluate`, including the
/// evaluations at the perturbed coordinates.
///
/// # Panics
/// Panics when the jet shapes are inconsistent, when a finite-difference
/// comparison exceeds its threshold, or when the symmetry check fails.
fn assert_second_jet_matches_central_difference<E: SaeBasisSecondJet>(
    evaluator: &E,
    coords: Array2<f64>,
    abs_tol: f64,
    rel_tol: f64,
) -> Result<(), String> {
    let epsilon = 1.0e-4;
    let second = evaluator.second_jet(coords.view())?;
    let (_phi, jet) = evaluator.evaluate(coords.view())?;
    let (n_rows, n_basis, latent_dim, latent_dim_b) = second.dim();
    assert_eq!(latent_dim, latent_dim_b);
    assert_eq!((n_rows, n_basis, latent_dim), jet.dim());

    for row in 0..n_rows {
        for axis_c in 0..latent_dim {
            let mut plus = coords.clone();
            let mut minus = coords.clone();
            plus[[row, axis_c]] += epsilon;
            minus[[row, axis_c]] -= epsilon;
            // Propagate evaluator failures like every other call in this
            // helper instead of panicking on them.
            let (_, jet_plus) = evaluator.evaluate(plus.view())?;
            let (_, jet_minus) = evaluator.evaluate(minus.view())?;
            for basis in 0..n_basis {
                for axis_a in 0..latent_dim {
                    let fd = (jet_plus[[row, basis, axis_a]] - jet_minus[[row, basis, axis_a]])
                        / (2.0 * epsilon);
                    let analytic = second[[row, basis, axis_a, axis_c]];
                    let error = (analytic - fd).abs();
                    let threshold = abs_tol + rel_tol * analytic.abs().max(fd.abs());
                    assert!(
                        error <= threshold,
                        "row={row} basis={basis} axis_a={axis_a} axis_c={axis_c}: \
                         analytic={analytic:.12e}, fd={fd:.12e}, error={error:.12e}, \
                         threshold={threshold:.12e}"
                    );
                }
            }
        }
    }

    // The analytic tensor must be symmetric in its two latent axes.
    for row in 0..n_rows {
        for basis in 0..n_basis {
            for axis_a in 0..latent_dim {
                for axis_b in 0..latent_dim {
                    let h_ab = second[[row, basis, axis_a, axis_b]];
                    let h_ba = second[[row, basis, axis_b, axis_a]];
                    assert!(
                        (h_ab - h_ba).abs() <= 1.0e-12,
                        "second_jet not symmetric: row={row} basis={basis} \
                         ({axis_a},{axis_b})={h_ab:.6e} vs ({axis_b},{axis_a})={h_ba:.6e}"
                    );
                }
            }
        }
    }
    Ok(())
}
/// Checks an evaluator's analytic third jet against central differences of
/// its second jet, then checks invariance under index permutations.
///
/// Each coordinate entry is shifted by `±1e-4` along `axis_e`; the centred
/// quotient of the second jet is compared with
/// `third[[row, basis, axis_a, axis_c, axis_e]]` under the tolerance
/// `abs_tol + rel_tol * max(|analytic|, |fd|)`. Finally every entry is
/// compared with its five non-identity index permutations at `1e-10`.
///
/// Evaluator errors are returned; failed comparisons panic.
fn assert_third_jet_matches_central_difference<E: SaeBasisThirdJet>(
    evaluator: &E,
    coords: Array2<f64>,
    abs_tol: f64,
    rel_tol: f64,
) -> Result<(), String> {
    let step = 1.0e-4;
    let third = evaluator.third_jet(coords.view())?;
    let second_here = evaluator.second_jet(coords.view())?;

    // Shape bookkeeping: all three trailing axes share the latent dimension.
    let (n_rows, n_basis, latent_dim, ld_b, ld_c) = third.dim();
    assert_eq!(latent_dim, ld_b);
    assert_eq!(latent_dim, ld_c);
    assert_eq!(
        (n_rows, n_basis, latent_dim, latent_dim),
        second_here.dim()
    );

    for row in 0..n_rows {
        for axis_e in 0..latent_dim {
            let mut forward = coords.clone();
            forward[[row, axis_e]] += step;
            let mut backward = coords.clone();
            backward[[row, axis_e]] -= step;
            let second_forward = evaluator.second_jet(forward.view())?;
            let second_backward = evaluator.second_jet(backward.view())?;

            for basis in 0..n_basis {
                for axis_a in 0..latent_dim {
                    for axis_c in 0..latent_dim {
                        let at = [row, basis, axis_a, axis_c];
                        let fd = (second_forward[at] - second_backward[at]) / (2.0 * step);
                        let analytic = third[[row, basis, axis_a, axis_c, axis_e]];
                        let error = (analytic - fd).abs();
                        let threshold = abs_tol + rel_tol * analytic.abs().max(fd.abs());
                        assert!(
                            error <= threshold,
                            "row={row} basis={basis} a={axis_a} c={axis_c} e={axis_e}: \
                             analytic={analytic:.12e}, fd={fd:.12e}, error={error:.6e}, \
                             threshold={threshold:.6e}"
                        );
                    }
                }
            }
        }
    }

    // Full symmetry: every reordering of (a, b, c) must address the same value.
    for row in 0..n_rows {
        for basis in 0..n_basis {
            for a in 0..latent_dim {
                for b in 0..latent_dim {
                    for c in 0..latent_dim {
                        let reference = third[[row, basis, a, b, c]];
                        let reorderings = [[a, c, b], [b, a, c], [b, c, a], [c, a, b], [c, b, a]];
                        for [i, j, k] in reorderings {
                            let permuted = third[[row, basis, i, j, k]];
                            assert!(
                                (reference - permuted).abs() <= 1.0e-10,
                                "third_jet not symmetric: row={row} basis={basis} \
                                 ({a},{b},{c})={reference:.6e} vs ({},{},{})={permuted:.6e}",
                                i,
                                j,
                                k
                            );
                        }
                    }
                }
            }
        }
    }
    Ok(())
}
/// Second jet of the 7-column periodic harmonic evaluator versus finite
/// differences at four scalar coordinates.
#[test]
fn isometry_periodic_second_jet_matches_fd() -> Result<(), String> {
    let evaluator = PeriodicHarmonicEvaluator::new(7).unwrap();
    let coords = array![[-0.37], [0.0], [0.125], [0.41]];
    assert_second_jet_matches_central_difference(&evaluator, coords, 1.0e-6, 1.0e-5)
}
/// Second jet of the sphere chart evaluator versus finite differences at four
/// two-dimensional coordinates.
#[test]
fn isometry_sphere_second_jet_matches_fd() -> Result<(), String> {
    let evaluator = SphereChartEvaluator;
    let coords = array![[-0.7, -1.2], [-0.25, 0.0], [0.35, 0.9], [0.8, 2.1]];
    assert_second_jet_matches_central_difference(&evaluator, coords, 1.0e-6, 1.0e-5)
}
/// Second jet of `TorusHarmonicEvaluator::new(2, 3)` versus finite
/// differences; the evaluator must also report a non-empty basis.
#[test]
fn isometry_torus_second_jet_matches_fd() -> Result<(), String> {
    let evaluator = TorusHarmonicEvaluator::new(2, 3).unwrap();
    assert!(evaluator.basis_size() > 0);
    let coords = array![[0.1, 0.7], [0.42, 0.0], [0.95, 0.33], [0.5, 0.5]];
    assert_second_jet_matches_central_difference(&evaluator, coords, 1.0e-6, 1.0e-5)
}
/// Third jet of the 7-column periodic harmonic evaluator versus finite
/// differences of its second jet.
#[test]
fn isometry_periodic_third_jet_matches_fd() -> Result<(), String> {
    let evaluator = PeriodicHarmonicEvaluator::new(7).unwrap();
    let coords = array![[-0.37], [0.0], [0.125], [0.41]];
    assert_third_jet_matches_central_difference(&evaluator, coords, 1.0e-6, 1.0e-5)
}
/// Third jet of the sphere chart evaluator versus finite differences of its
/// second jet.
#[test]
fn isometry_sphere_third_jet_matches_fd() -> Result<(), String> {
    let evaluator = SphereChartEvaluator;
    let coords = array![[-0.7, -1.2], [-0.25, 0.0], [0.35, 0.9], [0.8, 2.1]];
    assert_third_jet_matches_central_difference(&evaluator, coords, 1.0e-6, 1.0e-5)
}
/// Third jet of `TorusHarmonicEvaluator::new(2, 3)` versus finite differences
/// of its second jet; the evaluator must also report a non-empty basis.
#[test]
fn isometry_torus_third_jet_matches_fd() -> Result<(), String> {
    let evaluator = TorusHarmonicEvaluator::new(2, 3).unwrap();
    assert!(evaluator.basis_size() > 0);
    let coords = array![[0.1, 0.7], [0.42, 0.0], [0.95, 0.33], [0.5, 0.5]];
    assert_third_jet_matches_central_difference(&evaluator, coords, 1.0e-6, 1.0e-5)
}
/// The affine coordinate evaluator with three latent dimensions must return a
/// third jet of shape `(rows, 4, 3, 3, 3)` whose entries are all exactly zero.
#[test]
fn isometry_affine_third_jet_is_trivial_zero() -> Result<(), String> {
    let coords = array![[0.2, -0.3, 0.7], [1.1, 0.0, -0.4]];
    let evaluator = AffineCoordinateEvaluator { latent_dim: 3 };
    let third = evaluator.third_jet(coords.view())?;

    let expected_dim = (coords.nrows(), 4, 3, 3, 3);
    assert_eq!(third.dim(), expected_dim);

    let identically_zero = third.iter().all(|&entry| entry == 0.0);
    assert!(
        identically_zero,
        "affine third jet must vanish identically, got {third:?}"
    );
    Ok(())
}
/// Third jet of `EuclideanPatchEvaluator::new(2, 4)` versus finite differences
/// of its second jet at three planar coordinates.
#[test]
fn isometry_euclidean_patch_third_jet_matches_fd() -> Result<(), String> {
    let patch = EuclideanPatchEvaluator::new(2, 4)?;
    let sample_points = array![[0.2, -0.3], [0.7, 0.4], [-0.5, 0.9]];
    assert_third_jet_matches_central_difference(&patch, sample_points, 1.0e-6, 1.0e-5)
}
/// For one- and two-dimensional centre sets, the Duchon evaluator's basis
/// matrix and jet must have the same number of basis columns, and the jet must
/// be shaped `(rows, columns, d)`.
#[test]
fn duchon_coordinate_evaluator_phi_and_jet_share_column_count() {
    let cases = [
        (1usize, array![[-1.0], [-0.4], [0.1], [0.6], [1.2], [1.9]]),
        (
            2usize,
            array![
                [-1.0, -0.8],
                [-0.3, 0.4],
                [0.2, -0.5],
                [0.7, 0.9],
                [1.1, -0.2],
                [1.6, 0.6],
            ],
        ),
    ];
    for (d, centers) in cases {
        let evaluator = DuchonCoordinateEvaluator::new(centers, 2).unwrap();
        // Query points carry the same latent dimension as the centres.
        let coords = if d == 1 {
            array![[-0.5], [0.0], [0.3], [0.8]]
        } else {
            array![[-0.5, 0.2], [0.0, -0.3], [0.3, 0.7], [0.8, -0.1]]
        };
        let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
        let jet_shape = jet.shape();
        assert_eq!(
            phi.ncols(),
            jet_shape[1],
            "Duchon d={d}: Phi has {} columns but jet has {}",
            phi.ncols(),
            jet_shape[1]
        );
        assert_eq!(jet_shape[0], coords.nrows());
        assert_eq!(jet_shape[2], d);
    }
}
/// First jet of a six-centre planar Duchon evaluator versus central
/// differences at four query points.
#[test]
fn duchon_coordinate_evaluator_jacobian_matches_fd() {
    let knots = array![
        [-1.0, -0.8],
        [-0.3, 0.4],
        [0.2, -0.5],
        [0.7, 0.9],
        [1.1, -0.2],
        [1.6, 0.6],
    ];
    let query = array![[-0.5, 0.2], [0.05, -0.35], [0.45, 0.75], [1.3, 0.1]];
    let duchon = DuchonCoordinateEvaluator::new(knots, 2).unwrap();
    assert_jacobian_matches_central_difference(&duchon, query, 1.0e-4);
}
/// Second jet of a six-centre planar Duchon evaluator versus central
/// differences of its first jet at four query points.
#[test]
fn duchon_coordinate_evaluator_second_jet_matches_fd() -> Result<(), String> {
    let knots = array![
        [-1.0, -0.8],
        [-0.3, 0.4],
        [0.2, -0.5],
        [0.7, 0.9],
        [1.1, -0.2],
        [1.6, 0.6],
    ];
    let query = array![[-0.5, 0.2], [0.05, -0.35], [0.45, 0.75], [1.3, 0.1]];
    let duchon = DuchonCoordinateEvaluator::new(knots, 2).unwrap();
    assert_second_jet_matches_central_difference(&duchon, query, 1.0e-4, 1.0e-4)
}
/// Third jet of a six-centre planar Duchon evaluator versus central
/// differences of its second jet at four query points.
#[test]
fn duchon_coordinate_evaluator_third_jet_matches_fd() -> Result<(), String> {
    let knots = array![
        [-1.0, -0.8],
        [-0.3, 0.4],
        [0.2, -0.5],
        [0.7, 0.9],
        [1.1, -0.2],
        [1.6, 0.6],
    ];
    let query = array![[-0.5, 0.2], [0.05, -0.35], [0.45, 0.75], [1.3, 0.1]];
    let duchon = DuchonCoordinateEvaluator::new(knots, 2).unwrap();
    assert_third_jet_matches_central_difference(&duchon, query, 1.0e-4, 1.0e-4)
}
/// First and second jets of `EuclideanPatchEvaluator::new(2, 2)` versus
/// central differences, plus a check that the evaluator exposes six basis
/// columns at the origin.
///
/// Constructor and evaluator errors are returned through `?`, matching the
/// other `EuclideanPatchEvaluator` tests in this file.
#[test]
fn euclidean_patch_evaluator_jets_match_fd() -> Result<(), String> {
    let evaluator = EuclideanPatchEvaluator::new(2, 2)?;
    let coords = array![[0.0, -1.0], [3.5, 0.25], [-0.75, 1.2], [0.4, 0.9]];
    assert_jacobian_matches_central_difference(&evaluator, coords.clone(), 1.0e-6);
    assert_second_jet_matches_central_difference(&evaluator, coords, 1.0e-5, 1.0e-5)?;
    let (phi, _jet) = evaluator.evaluate(array![[0.0, 0.0]].view())?;
    assert_eq!(phi.ncols(), 6);
    Ok(())
}
/// Canonicalizing the affine gauge of a one-dimensional Euclidean patch atom
/// must leave the fitted reconstruction unchanged (to `1e-10`) while bringing
/// the live coordinate column to zero mean and unit RMS (to `1e-12`).
///
/// The starting coordinates are the affine image `2.75 + 4 t` of a reference
/// grid `t`.
#[test]
fn euclidean_affine_gauge_canonicalization_preserves_reconstruction() -> Result<(), String> {
    let evaluator = Arc::new(EuclideanPatchEvaluator::new(1, 2)?);
    let canonical = array![[-1.0_f64], [-0.35], [0.1], [0.65], [1.2]];
    // Shift and stretch the reference grid so the gauge is visibly off.
    let coords = canonical.mapv(|t| 2.75 + 4.0 * t);
    let n_rows = coords.nrows();

    let (phi, jet) = evaluator.evaluate(coords.view())?;
    let smooth_penalty = Array2::<f64>::eye(evaluator.basis_size());
    let atom = SaeManifoldAtom::new(
        "euclidean_patch",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        array![[0.25, -0.4], [1.2, 0.3], [-0.15, 0.5]],
        smooth_penalty,
    )?
    .with_basis_evaluator(evaluator);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n_rows, 1)),
        vec![coords],
        vec![LatentManifold::Euclidean],
        AssignmentMode::softmax(1.0),
    )?;
    let mut term = SaeManifoldTerm::new(vec![atom], assignment)?;

    let before = term.fitted();
    term.canonicalize_affine_gauge_after_accept()?;
    let after = term.fitted();
    let max_abs = before
        .iter()
        .zip(after.iter())
        .map(|(&a, &b)| (a - b).abs())
        .fold(0.0_f64, f64::max);
    assert!(
        max_abs <= 1.0e-10,
        "canonicalization changed reconstruction by {max_abs:.3e}"
    );

    let live = term.assignment.coords[0].as_matrix();
    let count = live.nrows() as f64;
    let mean = live.column(0).sum() / count;
    let mean_square = live.column(0).iter().map(|v| v * v).sum::<f64>() / count;
    let rms = mean_square.sqrt();
    assert_abs_diff_eq!(mean, 0.0, epsilon = 1.0e-12);
    assert_abs_diff_eq!(rms, 1.0, epsilon = 1.0e-12);
    Ok(())
}
/// A step that is exactly the second dense gauge vector of a Euclidean patch
/// term must have (numerically) zero quotient norm.
///
/// The gauge vector is split into its leading coordinate part and trailing
/// remainder, and the quotient norm squared is required to be at most
/// `max(raw, 1) * 1e-20`, where `raw` is the plain squared norm.
#[test]
fn quotient_step_norm_removes_pure_euclidean_affine_gauge() -> Result<(), String> {
    let evaluator = Arc::new(EuclideanPatchEvaluator::new(1, 2)?);
    let coords = array![[-1.0_f64], [-0.4], [0.2], [0.8], [1.3]];
    let n_rows = coords.nrows();
    let (phi, jet) = evaluator.evaluate(coords.view())?;
    let smooth_penalty = Array2::<f64>::eye(evaluator.basis_size());
    let atom = SaeManifoldAtom::new(
        "euclidean_patch",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        array![[0.1, -0.2], [1.0, 0.4], [0.25, -0.3]],
        smooth_penalty,
    )?
    .with_basis_evaluator(evaluator);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n_rows, 1)),
        vec![coords],
        vec![LatentManifold::Euclidean],
        AssignmentMode::softmax(1.0),
    )?;
    let term = SaeManifoldTerm::new(vec![atom], assignment)?;

    let gauges = term.dense_step_gauge_vectors()?;
    assert!(
        gauges.len() >= 2,
        "expected translation and scale gauge generators"
    );

    // Use the second generator as the step itself.
    let gauge = &gauges[1];
    let n_coord = term.n_obs() * term.assignment.row_block_dim();
    let (delta_t, delta_beta) = (gauge.slice(s![..n_coord]), gauge.slice(s![n_coord..]));
    let raw = gauge.iter().map(|v| v * v).sum::<f64>();
    let quotient = term.quotient_newton_step_norm_sq(delta_t, delta_beta, raw)?;
    let ceiling = raw.max(1.0) * 1.0e-20;
    assert!(
        quotient <= ceiling,
        "pure affine gauge step left quotient norm squared {quotient:.3e} from raw {raw:.3e}"
    );
    Ok(())
}
/// A single torus atom, started at the true coordinates with a zero decoder,
/// must reach an in-sample R² of at least 0.5 on a four-channel signal built
/// from two angles after at most ten single-iteration joint fits.
#[test]
fn sae_torus_atom_recovers_two_frequency_synthetic() {
    use std::f64::consts::PI;
    let (n, p, h, d) = (96usize, 4usize, 3usize, 2usize);
    let evaluator = TorusHarmonicEvaluator::new(d, h).unwrap();
    let m = evaluator.basis_size();

    // Two incommensurate coordinate sequences wrapped into [0, 1).
    let true_coords = Array2::<f64>::from_shape_fn((n, d), |(i, axis)| {
        let step = i as f64;
        if axis == 0 {
            (step * 0.137).rem_euclid(1.0)
        } else {
            (step * 0.241 + 0.13).rem_euclid(1.0)
        }
    });
    // Target channels mix the two angles, including sum and difference terms.
    let z = Array2::<f64>::from_shape_fn((n, p), |(i, channel)| {
        let t1 = 2.0 * PI * true_coords[[i, 0]];
        let t2 = 2.0 * PI * true_coords[[i, 1]];
        match channel {
            0 => t1.sin() + 0.3 * t2.cos(),
            1 => t1.cos() + 0.2 * (t1 + t2).sin(),
            2 => t2.sin(),
            _ => 0.5 * (t1 - t2).cos(),
        }
    });
    let sst: f64 = z.iter().map(|v| v * v).sum::<f64>();

    let (phi0, jet0) = evaluator.evaluate(true_coords.view()).unwrap();
    let penalty = Array2::<f64>::eye(m) * 1.0e-4;
    let atom = SaeManifoldAtom::new(
        "torus_atom",
        SaeAtomBasisKind::Torus,
        d,
        phi0,
        jet0,
        Array2::<f64>::zeros((m, p)),
        penalty,
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(TorusHarmonicEvaluator::new(d, h).unwrap()));
    let circle_pair = LatentManifold::Product(vec![
        LatentManifold::Circle { period: 1.0 },
        LatentManifold::Circle { period: 1.0 },
    ]);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![true_coords],
        vec![circle_pair],
        AssignmentMode::softmax(0.5),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let mut rho = SaeManifoldRho::new(0.0, -4.0, vec![Array1::<f64>::zeros(d)]);

    // Up to ten single-iteration fits; stop early on a non-finite loss.
    let ridge = 1.0e-6;
    for _ in 0..10 {
        let total = term
            .run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, ridge, ridge)
            .unwrap()
            .total();
        if !total.is_finite() {
            break;
        }
    }

    let fitted = term.fitted();
    assert_eq!(fitted.dim(), (n, p));
    let sse = fitted
        .indexed_iter()
        .fold(0.0_f64, |acc, ((row, col), v)| {
            let r = v - z[[row, col]];
            acc + r * r
        });
    let r2 = 1.0 - sse / sst.max(1.0e-12);
    assert!(
        r2 >= 0.5,
        "torus atom R² too low: {r2:.4} (sst={sst:.4}, sse={sse:.4})"
    );
}
/// A single sphere-chart atom, started at the true latitude/longitude with a
/// zero decoder, must reach an in-sample R² of at least 0.5 on the embedded
/// `(x, y, z)` signal after at most ten single-iteration joint fits.
#[test]
fn sae_sphere_atom_recovers_synthetic_signal() {
    use std::f64::consts::{FRAC_PI_2, PI, TAU};
    let (n, p, d) = (96usize, 3usize, 2usize);

    // Latitude sweeps [-0.5, 0.5) while longitude sweeps [-π, π).
    let true_coords = Array2::<f64>::from_shape_fn((n, d), |(i, axis)| {
        let t = (i as f64) / (n as f64);
        if axis == 0 {
            -0.5 + 1.0 * t
        } else {
            -PI + 2.0 * PI * t
        }
    });
    // Cartesian embedding of each (lat, lon) pair.
    let z = Array2::<f64>::from_shape_fn((n, p), |(i, channel)| {
        let lat = true_coords[[i, 0]];
        let lon = true_coords[[i, 1]];
        match channel {
            0 => lat.cos() * lon.cos(),
            1 => lat.cos() * lon.sin(),
            _ => lat.sin(),
        }
    });
    let sst: f64 = z.iter().map(|v| v * v).sum::<f64>();

    let (phi0, jet0) = SphereChartEvaluator.evaluate(true_coords.view()).unwrap();
    let m = phi0.ncols();
    let penalty = Array2::<f64>::eye(m) * 1.0e-4;
    let atom = SaeManifoldAtom::new(
        "sphere_atom",
        SaeAtomBasisKind::Sphere,
        d,
        phi0,
        jet0,
        Array2::<f64>::zeros((m, p)),
        penalty,
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(SphereChartEvaluator));
    let chart_domain = LatentManifold::Product(vec![
        LatentManifold::Interval {
            lo: -FRAC_PI_2,
            hi: FRAC_PI_2,
        },
        LatentManifold::Circle { period: TAU },
    ]);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![true_coords],
        vec![chart_domain],
        AssignmentMode::softmax(0.5),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let mut rho = SaeManifoldRho::new(0.0, -4.0, vec![Array1::<f64>::zeros(2)]);

    // Up to ten single-iteration fits; stop early on a non-finite loss.
    let ridge = 1.0e-6;
    for _ in 0..10 {
        let total = term
            .run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, ridge, ridge)
            .unwrap()
            .total();
        if !total.is_finite() {
            break;
        }
    }

    let fitted = term.fitted();
    assert_eq!(fitted.dim(), (n, p));
    let sse = fitted
        .indexed_iter()
        .fold(0.0_f64, |acc, ((row, col), v)| {
            let r = v - z[[row, col]];
            acc + r * r
        });
    let r2 = 1.0 - sse / sst.max(1.0e-12);
    assert!(
        r2 >= 0.5,
        "sphere atom R² too low: {r2:.4} (sst={sst:.4}, sse={sse:.4})"
    );
}
/// A single periodic atom, started a quarter period away from the true phase
/// with a zero decoder, must reach an in-sample R² of at least 0.95 on a
/// one-harmonic scalar signal within ten single-iteration joint fits.
///
/// The outer loop stops early when the loss turns non-finite or its relative
/// change drops below `1e-6`.
#[test]
fn sae_manifold_fit_10_steps_one_harmonic_reaches_high_r2() {
    use std::f64::consts::PI;
    let (n, m, p) = (64usize, 3usize, 1usize);
    let true_t: Vec<f64> = (0..n).map(|i| (i as f64) / (n as f64)).collect();

    let z = Array2::<f64>::from_shape_fn((n, p), |(i, _)| {
        let angle = 2.0 * PI * true_t[i];
        0.7 * angle.sin() + 0.3 * angle.cos()
    });
    let sst: f64 = z.iter().map(|v| v * v).sum::<f64>();

    // Initial coordinates are the true ones shifted by 0.25 and wrapped.
    let coords0_data =
        Array2::<f64>::from_shape_fn((n, 1), |(i, _)| (true_t[i] + 0.25).rem_euclid(1.0));
    let evaluator = PeriodicHarmonicEvaluator::new(m).unwrap();
    let (phi0, jet0) = evaluator.evaluate(coords0_data.view()).unwrap();
    let atom = SaeManifoldAtom::new(
        "periodic_atom",
        SaeAtomBasisKind::Periodic,
        1,
        phi0,
        jet0,
        Array2::<f64>::zeros((m, p)),
        Array2::<f64>::eye(m),
    )
    .unwrap()
    .with_basis_evaluator(Arc::new(PeriodicHarmonicEvaluator::new(m).unwrap()));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![coords0_data],
        vec![LatentManifold::Circle { period: 1.0 }],
        AssignmentMode::softmax(0.5),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let mut rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1)]);

    let max_iter = 10usize;
    let learning_rate = 1.0;
    let ridge = 1.0e-6;
    let mut previous = f64::INFINITY;
    for _ in 0..max_iter {
        let total = term
            .run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, learning_rate, ridge, ridge)
            .unwrap()
            .total();
        if !total.is_finite() {
            break;
        }
        let rel = (previous - total).abs() / previous.abs().max(1.0e-12);
        previous = total;
        if rel < 1.0e-6 {
            break;
        }
    }

    let fitted = term.fitted();
    assert_eq!(fitted.dim(), (n, p));
    let ssr = (0..n).fold(0.0, |acc, i| {
        let r = z[[i, 0]] - fitted[[i, 0]];
        acc + r * r
    });
    let r2 = 1.0 - ssr / sst.max(1.0e-12);
    assert!(
        r2 >= 0.95,
        "10-step in-sample R² = {r2:.4} (ssr={ssr:.6}, sst={sst:.6}) should be >= 0.95"
    );
}
/// For a two-atom softmax assignment, `assignment_prior_grad_hdiag` must
/// succeed and return finite gradient and Hessian-diagonal vectors of length
/// `n * k`.
#[test]
fn softmax_assignment_hessian_diag_is_available_for_k2() {
    let (n, k) = (4usize, 2usize);
    let logits =
        Array2::<f64>::from_shape_fn((n, k), |(i, j)| 0.1 * (i as f64) - 0.2 * (j as f64));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        vec![Array2::<f64>::zeros((n, 1)); k],
        vec![LatentManifold::Circle { period: 1.0 }; k],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1); k]);

    let (grad, diag) = assignment_prior_grad_hdiag(&assignment, &rho)
        .expect("softmax assignment Hessian diagonal must be available");

    let expected_len = n * k;
    assert_eq!(grad.len(), expected_len);
    assert_eq!(diag.len(), expected_len);
    assert!(grad.iter().all(|v| v.is_finite()));
    assert!(diag.iter().all(|v| v.is_finite()));
}
/// Over a grid of twelve logits, every entry of the JumpReLU assignment-prior
/// Hessian diagonal must be finite, non-negative, and equal (to `1e-12`) to
/// the value recomputed here: `λ · (σ(1-σ))² / τ²` when
/// `jumprelu_in_optimization_band` accepts the logit, and zero otherwise,
/// with `σ = stable_logistic((logit - threshold) / τ)` and
/// `λ = exp(rho.log_lambda_sparse)`.
#[test]
fn jumprelu_assignment_prior_hessian_diag_is_psd_over_logit_sweep() {
    let (n, k) = (6usize, 2usize);
    let temperature = 0.35_f64;
    let threshold = 0.1_f64;
    let logit_grid = vec![
        -2.0, -0.2, 0.0, 0.05, 0.1, 0.15, 0.4, 0.9, 1.5, 2.5, 4.0, 6.0,
    ];
    let logits = Array2::<f64>::from_shape_vec((n, k), logit_grid).expect("valid logit grid");
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits.clone(),
        vec![Array2::<f64>::zeros((n, 1)); k],
        vec![LatentManifold::Circle { period: 1.0 }; k],
        AssignmentMode::jumprelu(temperature, threshold),
    )
    .expect("valid JumpReLU assignment");
    let rho = SaeManifoldRho::new(0.7_f64.ln(), -6.0, vec![Array1::<f64>::zeros(1); k]);

    let (grad, diag) = assignment_prior_grad_hdiag(&assignment, &rho)
        .expect("JumpReLU assignment prior hessian diag");
    assert_eq!(grad.len(), n * k);
    assert_eq!(diag.len(), n * k);

    let inv_tau = 1.0 / temperature;
    let inv_tau2 = inv_tau * inv_tau;
    let sparsity_strength = rho.log_lambda_sparse.exp();
    for (idx, &entry) in diag.iter().enumerate() {
        // The flat index walks the logit matrix row by row.
        let (row, col) = (idx / k, idx % k);
        let logit = logits[[row, col]];
        let expected = if !jumprelu_in_optimization_band(logit, threshold, temperature) {
            0.0
        } else {
            let activation = crate::linalg::utils::stable_logistic((logit - threshold) * inv_tau);
            let slope = activation * (1.0 - activation);
            sparsity_strength * slope * slope * inv_tau2
        };
        assert!(
            entry.is_finite() && entry >= 0.0,
            "JumpReLU gated hessian_diag majorizer must be finite and non-negative at index \
             {idx}; entry={entry}"
        );
        assert_abs_diff_eq!(entry, expected, epsilon = 1e-12);
    }
}
// End-to-end check of a two-atom periodic model under the IBP-MAP assignment
// mode: after seeding both decoders with a joint ridge-stabilised least-squares
// solve, up to 30 single-iteration joint fits must give an in-sample R² above
// 0.5 and a mean per-row assignment mass above 0.2.
#[test]
fn ibp_map_k2_periodic_torus_recovers_signal_with_lsq_init() {
use crate::linalg::faer_ndarray::{FaerCholesky, fast_ata, fast_atb};
use faer::Side as FaerSide;
let n = 200usize;
let p = 8usize;
let k = 2usize;
let m = 5usize;
// Two latent angle sequences in [0, 1), one per atom.
let mut theta = Array2::<f64>::zeros((n, 2));
for i in 0..n {
theta[[i, 0]] = ((i as f64) * 0.07) % 1.0;
theta[[i, 1]] = ((i as f64) * 0.13 + 0.31) % 1.0;
}
// Raw features: cosine and sine of each angle.
let mut raw = Array2::<f64>::zeros((n, 4));
for i in 0..n {
let a1 = 2.0 * std::f64::consts::PI * theta[[i, 0]];
let a2 = 2.0 * std::f64::consts::PI * theta[[i, 1]];
raw[[i, 0]] = a1.cos();
raw[[i, 1]] = a1.sin();
raw[[i, 2]] = a2.cos();
raw[[i, 3]] = a2.sin();
}
// Fixed deterministic 4 x p mixing matrix.
let mix = Array2::<f64>::from_shape_fn((4, p), |(i, j)| {
((i as f64 + 1.0) * 0.37 + (j as f64) * 0.21).sin()
});
// Observations z = raw * mix, accumulated entry by entry.
let mut z = Array2::<f64>::zeros((n, p));
for i in 0..n {
for j in 0..p {
let mut acc = 0.0;
for r in 0..4 {
acc += raw[[i, r]] * mix[[r, j]];
}
z[[i, j]] = acc;
}
}
// Centre every output column of z.
let mut col_mean = Array1::<f64>::zeros(p);
for j in 0..p {
let mut acc = 0.0;
for i in 0..n {
acc += z[[i, j]];
}
col_mean[j] = acc / n as f64;
}
for i in 0..n {
for j in 0..p {
z[[i, j]] -= col_mean[j];
}
}
// Initial coordinates: the true angles offset by 0.05 and 0.07, wrapped.
let mut coords_k = vec![Array2::<f64>::zeros((n, 1)); k];
for i in 0..n {
coords_k[0][[i, 0]] = (theta[[i, 0]] + 0.05).rem_euclid(1.0);
coords_k[1][[i, 0]] = (theta[[i, 1]] + 0.07).rem_euclid(1.0);
}
let evaluator = PeriodicHarmonicEvaluator::new(m).unwrap();
let mut phi_k = Vec::with_capacity(k);
let mut jet_k = Vec::with_capacity(k);
for atom_idx in 0..k {
let (phi, jet) = evaluator.evaluate(coords_k[atom_idx].view()).unwrap();
phi_k.push(phi);
jet_k.push(jet);
}
// Joint design matrix: the atoms' bases side by side, each scaled by 0.5.
// NOTE(review): 0.5 is presumably the initial per-atom assignment weight at
// zero logits — confirm against the IBP-MAP assignment code.
let m_total = k * m;
let mut x = Array2::<f64>::zeros((n, m_total));
for atom_idx in 0..k {
for i in 0..n {
for col in 0..m {
x[[i, atom_idx * m + col]] = 0.5 * phi_k[atom_idx][[i, col]];
}
}
}
// Normal equations with a diagonal jitter scaled by the mean diagonal entry
// (floored at 1) times 1e-8, added before the Cholesky factorisation.
let mut xtx = fast_ata(&x);
let mut trace = 0.0_f64;
for i in 0..m_total {
trace += xtx[[i, i]];
}
let jitter = (trace / m_total as f64).max(1.0) * 1.0e-8;
for i in 0..m_total {
xtx[[i, i]] += jitter;
}
let xtz = fast_atb(&x, &z);
let b_joint = xtx
.cholesky(FaerSide::Lower)
.expect("LSQ Cholesky")
.solve_mat(&xtz);
// Split the joint solution into one m x p decoder per atom.
let mut atoms = Vec::with_capacity(k);
for atom_idx in 0..k {
let mut b = Array2::<f64>::zeros((m, p));
for col in 0..m {
for j in 0..p {
b[[col, j]] = b_joint[[atom_idx * m + col, j]];
}
}
let atom = SaeManifoldAtom::new(
format!("torus_atom_{atom_idx}"),
SaeAtomBasisKind::Periodic,
1,
phi_k[atom_idx].clone(),
jet_k[atom_idx].clone(),
b,
Array2::<f64>::eye(m),
)
.unwrap()
.with_basis_evaluator(Arc::new(PeriodicHarmonicEvaluator::new(m).unwrap()));
atoms.push(atom);
}
let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
Array2::<f64>::zeros((n, k)),
coords_k,
vec![LatentManifold::Circle { period: 1.0 }; k],
AssignmentMode::ibp_map(0.7, 1.0, false),
)
.unwrap();
let mut term = SaeManifoldTerm::new(atoms, assignment).unwrap();
let mut rho = SaeManifoldRho::new((0.02_f64).ln(), -6.0, vec![Array1::<f64>::zeros(1); k]);
// Up to 30 single-iteration fits; stop on a non-finite loss or once the
// relative change of the total loss falls below 1e-6.
let mut prev_total = f64::INFINITY;
for _ in 0..30 {
let loss = term
.run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, 1.0e-6, 1.0e-6)
.unwrap();
let total = loss.total();
if !total.is_finite() {
break;
}
let denom = prev_total.abs().max(1.0e-12);
let rel = (prev_total - total).abs() / denom;
prev_total = total;
if rel < 1.0e-6 {
break;
}
}
// In-sample R² against the centred observations.
let fitted = term.fitted();
let mut ssr = 0.0;
let mut sst = 0.0;
for i in 0..n {
for j in 0..p {
let r = z[[i, j]] - fitted[[i, j]];
ssr += r * r;
sst += z[[i, j]] * z[[i, j]];
}
}
let r2 = 1.0 - ssr / sst.max(1.0e-12);
assert!(
r2 > 0.5,
"K=2 periodic torus IBP-MAP R² = {r2:.4} (ssr={ssr:.4}, sst={sst:.4}) should be > 0.5 with LSQ-seeded decoder"
);
// Total assignment mass divided by the number of rows must stay above 0.2.
let assignments = term.assignment.assignments();
let mean_active: f64 = assignments.iter().copied().sum::<f64>() / (n as f64);
assert!(
mean_active > 0.2,
"mean active mass across rows = {mean_active:.4} should exceed 0.2; assignment did not collapse"
);
}
/// Smoke test: a two-atom periodic model under softmax assignment must finish
/// two consecutive single-iteration joint fits, each with a finite total loss.
#[test]
fn softmax_k2_periodic_completes_joint_fit_step() {
    use std::f64::consts::PI;
    let (n, p, k, m) = (64usize, 4usize, 2usize, 3usize);

    // First and second harmonic of one angle spread over four channels.
    let z = Array2::<f64>::from_shape_fn((n, p), |(i, channel)| {
        let a = 2.0 * PI * (i as f64) / (n as f64);
        match channel {
            0 => a.sin(),
            1 => a.cos(),
            2 => (2.0 * a).sin(),
            _ => (2.0 * a).cos(),
        }
    });

    // Atom 0 follows the base phase, atom 1 the doubled phase wrapped to [0, 1).
    let mut coords_k = vec![Array2::<f64>::zeros((n, 1)); k];
    for i in 0..n {
        let step = i as f64;
        coords_k[0][[i, 0]] = step / (n as f64);
        coords_k[1][[i, 0]] = (step * 2.0 / (n as f64)).rem_euclid(1.0);
    }

    let evaluator = PeriodicHarmonicEvaluator::new(m).unwrap();
    let atoms: Vec<_> = (0..k)
        .map(|atom_idx| {
            let (phi, jet) = evaluator.evaluate(coords_k[atom_idx].view()).unwrap();
            let b = Array2::<f64>::from_shape_fn((m, p), |(i, j)| {
                0.1 * ((i as f64 + 1.0) * (j as f64 + 1.0)).sin()
            });
            SaeManifoldAtom::new(
                format!("a_{atom_idx}"),
                SaeAtomBasisKind::Periodic,
                1,
                phi,
                jet,
                b,
                Array2::<f64>::eye(m),
            )
            .unwrap()
            .with_basis_evaluator(Arc::new(PeriodicHarmonicEvaluator::new(m).unwrap()))
        })
        .collect();

    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, k)),
        coords_k,
        vec![LatentManifold::Circle { period: 1.0 }; k],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(atoms, assignment).unwrap();
    let mut rho = SaeManifoldRho::new(0.0, -6.0, vec![Array1::<f64>::zeros(1); k]);

    let loss0 = term
        .run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, 1.0e-6, 1.0e-6)
        .expect("softmax K=2 must complete first joint-fit step");
    assert!(loss0.total().is_finite());

    let loss1 = term
        .run_joint_fit_arrow_schur(z.view(), &mut rho, None, 1, 1.0, 1.0e-6, 1.0e-6)
        .expect("softmax K=2 must complete second joint-fit step");
    assert!(loss1.total().is_finite());
}
/// Shared oracle: an `IsometryPenalty` whose caches are refreshed from an
/// atom must report a gradient that matches a central difference of its own
/// value.
///
/// The check proceeds in four stages:
/// 1. Build an atom from `evaluator` at `coords` with a deterministic
///    `m x 3` decoder and an identity smoothness penalty.
/// 2. Before any cache is installed, `value` must equal
///    `IsometryPenalty::DEFAULT_VALUE_ON_MISSING_CACHE` and `grad_target`
///    must be all zeros.
/// 3. After `refresh_isometry_caches_from_atom` (which must report that the
///    second jet was installed), the value and the largest gradient magnitude
///    must both exceed `1e-6`.
/// 4. Coordinate `[0, 0]` is shifted by `±1e-5`, the caches are refreshed at
///    each shifted point before the value is read, the caches are restored at
///    the base point, and the base gradient's first component is compared
///    with the centred quotient (`1e-3` absolute plus `1e-4` relative).
///
/// NOTE(review): the atom is always tagged `SaeAtomBasisKind::Periodic`, even
/// when called with sphere or torus evaluators; presumably the kind does not
/// influence the cache refresh — confirm.
///
/// # Panics
/// Panics on any evaluator/refresh error and on any failed check.
fn assert_isometry_wiring_matches_fd(
    evaluator: Arc<dyn SaeBasisSecondJet>,
    coords: Array2<f64>,
) {
    let n_obs = coords.nrows();
    let latent_dim = coords.ncols();
    let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
    let m = phi.ncols();
    let p: usize = 3;
    let mut decoder = Array2::<f64>::zeros((m, p));
    for i in 0..m {
        for j in 0..p {
            let x = (i as f64) * 0.371 + (j as f64) * 0.193 + 0.5;
            decoder[[i, j]] = (x.sin() * 0.9) + 0.1 * ((i + j) as f64).cos();
        }
    }
    let smooth = Array2::<f64>::eye(m);
    // `phi`, `jet` and `decoder` are not needed afterwards, so they are moved
    // into the atom rather than cloned.
    let atom = SaeManifoldAtom::new(
        "iso_wire_test",
        SaeAtomBasisKind::Periodic,
        latent_dim,
        phi,
        jet,
        decoder,
        smooth,
    )
    .unwrap()
    .with_basis_second_jet(evaluator);
    let target_slice = PsiSlice::full(n_obs * latent_dim, Some(latent_dim));
    let penalty = IsometryPenalty::new_euclidean(target_slice, p);
    let rho = Array1::<f64>::zeros(1);
    let target_flat: Array1<f64> = coords.iter().copied().collect();

    // Stage 2: behaviour without any installed cache.
    let v0 = penalty.value(target_flat.view(), rho.view());
    assert_eq!(v0, IsometryPenalty::DEFAULT_VALUE_ON_MISSING_CACHE);
    let g0 = penalty.grad_target(target_flat.view(), rho.view());
    assert!(
        g0.iter().all(|x| *x == 0.0),
        "grad_target without cache must be all zeros, got {g0:?}"
    );

    // Stage 3: install the caches and require a non-trivial value and gradient.
    let installed_second =
        refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap();
    assert!(
        installed_second,
        "evaluator must implement second_jet for this oracle to run"
    );
    let value = penalty.value(target_flat.view(), rho.view());
    assert!(
        value > 1.0e-6,
        "expected non-trivial isometry loss after cache refresh, got {value}"
    );
    let grad = penalty.grad_target(target_flat.view(), rho.view());
    assert_eq!(grad.len(), target_flat.len());
    let max_abs = grad.iter().fold(0.0_f64, |acc, x| acc.max(x.abs()));
    assert!(
        max_abs > 1.0e-6,
        "expected non-zero isometry gradient on at least one component, max |grad|={max_abs}"
    );

    // Stage 4: central difference on the first flattened coordinate. The
    // caches must be refreshed at each perturbed point before reading `value`.
    let h_fd = 1.0e-5;
    let probe_idx = 0usize;
    let mut coords_plus = coords.clone();
    coords_plus[[0, 0]] += h_fd;
    let mut coords_minus = coords.clone();
    coords_minus[[0, 0]] -= h_fd;
    refresh_isometry_caches_from_atom(&penalty, &atom, coords_plus.view()).unwrap();
    let target_plus: Array1<f64> = coords_plus.iter().copied().collect();
    let v_plus = penalty.value(target_plus.view(), rho.view());
    refresh_isometry_caches_from_atom(&penalty, &atom, coords_minus.view()).unwrap();
    let target_minus: Array1<f64> = coords_minus.iter().copied().collect();
    let v_minus = penalty.value(target_minus.view(), rho.view());
    // Restore the caches at the base point before reading the analytic gradient.
    refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap();
    let grad_base = penalty.grad_target(target_flat.view(), rho.view());
    let fd = (v_plus - v_minus) / (2.0 * h_fd);
    let analytic = grad_base[probe_idx];
    assert!(
        (analytic - fd).abs() <= 1.0e-3 + 1.0e-4 * analytic.abs().max(fd.abs()),
        "isometry grad/FD mismatch at coord 0: analytic={analytic:.6e}, fd={fd:.6e}"
    );
}
/// Runs the shared isometry wiring / finite-difference check with a
/// `PeriodicHarmonicEvaluator` at four single-coordinate sample points.
#[test]
fn isometry_wiring_periodic_matches_fd() {
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(5).unwrap());
    let coords = array![[0.12], [0.37], [0.58], [0.81]];
    assert_isometry_wiring_matches_fd(evaluator, coords);
}
/// Runs the shared isometry wiring / finite-difference check with a
/// `SphereChartEvaluator` at three two-coordinate sample points.
#[test]
fn isometry_wiring_sphere_matches_fd() {
    let evaluator = Arc::new(SphereChartEvaluator);
    let coords = array![[-0.5, 0.3], [0.2, -1.1], [0.7, 0.9]];
    assert_isometry_wiring_matches_fd(evaluator, coords);
}
/// Runs the shared isometry wiring / finite-difference check with a
/// `TorusHarmonicEvaluator` at three two-coordinate sample points.
#[test]
fn isometry_wiring_torus_matches_fd() {
    let evaluator = Arc::new(TorusHarmonicEvaluator::new(2, 2).unwrap());
    let coords = array![[0.13, 0.42], [0.66, 0.19], [0.88, 0.55]];
    assert_isometry_wiring_matches_fd(evaluator, coords);
}
/// Builds a reproducible `(n_basis, p_out)` matrix for use as a test decoder.
///
/// Entry `(i, j)` is a fixed trigonometric function of the phase
/// `seed + 0.371*i - 0.193*j + 0.047*(i*j + 1)`, so different seeds give
/// different matrices without any random number generator.
fn deterministic_decoder(n_basis: usize, p_out: usize, seed: f64) -> Array2<f64> {
    Array2::<f64>::from_shape_fn((n_basis, p_out), |(basis_idx, out_idx)| {
        let row_term = 0.371 * (basis_idx as f64);
        let col_term = 0.193 * (out_idx as f64);
        let cross_term = 0.047 * ((basis_idx * out_idx + 1) as f64);
        // Keep the left-to-right summation order so results stay bit-identical.
        let phase = seed + row_term - col_term + cross_term;
        0.8 * phase.sin() + 0.35 * (1.7 * phase).cos()
    })
}
fn build_isometry_atom_for_evaluator(
evaluator: Arc<dyn SaeBasisSecondJet>,
kind: SaeAtomBasisKind,
coords: &Array2<f64>,
p_out: usize,
seed: f64,
) -> (SaeManifoldAtom, IsometryPenalty, Array1<f64>) {
let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
let m = phi.ncols();
let decoder = deterministic_decoder(m, p_out, seed);
let atom = SaeManifoldAtom::new(
"exact_hvp_atom",
kind,
coords.ncols(),
phi,
jet,
decoder,
Array2::<f64>::eye(m),
)
.unwrap()
.with_basis_second_jet(evaluator);
let target_flat: Array1<f64> = coords.iter().copied().collect();
let penalty = IsometryPenalty::new_euclidean(
PsiSlice::full(target_flat.len(), Some(coords.ncols())),
p_out,
);
(atom, penalty, target_flat)
}
/// Checks the penalty's `hvp` against a central finite difference of
/// `grad_target` taken along `direction`.
///
/// The fixture comes from `build_isometry_atom_for_evaluator` with decoder seed
/// `0.91`. The isometry caches are refreshed at `coords` before the HVP is
/// evaluated, refreshed at each perturbed coordinate set before the matching
/// gradient evaluation, and refreshed back at `coords` at the end.
///
/// Panics if the second-jet cache does not install, if no third-decoder
/// derivative is available after the refresh, if the HVP has no component above
/// `1e-7` in magnitude, or if any flat component differs from the finite
/// difference by more than `2e-4 + 3e-5 * max(|exact|, |fd|)`.
fn assert_exact_isometry_hvp_matches_grad_fd(
    evaluator: Arc<dyn SaeBasisSecondJet>,
    kind: SaeAtomBasisKind,
    coords: Array2<f64>,
    p_out: usize,
    direction: Array2<f64>,
) {
    let rho = array![0.0_f64];
    let (atom, penalty, target_flat) =
        build_isometry_atom_for_evaluator(evaluator, kind, &coords, p_out, 0.91);
    // Every gradient/HVP evaluation below reads caches that must match the point
    // it is evaluated at, so each evaluation is preceded by a refresh there.
    let refresh_at = |points: &Array2<f64>| {
        refresh_isometry_caches_from_atom(&penalty, &atom, points.view()).unwrap()
    };

    let installed = refresh_at(&coords);
    assert!(
        installed,
        "second-jet cache must be installed for exact HVP test"
    );
    assert!(
        penalty.third_decoder_derivative().is_some(),
        "non-Duchon exact HVP requires a live refreshed third-decoder-jet cache"
    );

    let v: Array1<f64> = direction.iter().copied().collect();
    let exact = penalty.hvp(target_flat.view(), rho.view(), v.view());
    assert!(
        exact.iter().any(|x| x.abs() > 1.0e-7),
        "exact isometry HVP should be nonzero after K refresh; got {exact:?}"
    );

    let eps = 1.0e-6;
    let step = direction.mapv(|x| eps * x);
    let coords_plus = &coords + &step;
    let coords_minus = &coords - &step;
    let target_plus: Array1<f64> = coords_plus.iter().copied().collect();
    let target_minus: Array1<f64> = coords_minus.iter().copied().collect();

    refresh_at(&coords_plus);
    let grad_plus = penalty.grad_target(target_plus.view(), rho.view());
    refresh_at(&coords_minus);
    let grad_minus = penalty.grad_target(target_minus.view(), rho.view());
    // Leave the caches at the base point, as they were before the perturbation.
    refresh_at(&coords);

    let fd = (&grad_plus - &grad_minus).mapv(|x| x / (2.0 * eps));
    for i in 0..exact.len() {
        let (exact_i, fd_i) = (exact[i], fd[i]);
        let err = (exact_i - fd_i).abs();
        let tol = 2.0e-4 + 3.0e-5 * exact_i.abs().max(fd_i.abs());
        assert!(
            err <= tol,
            "exact isometry HVP/grad-FD mismatch at flat index {i}: exact={:.12e}, fd={:.12e}, err={:.6e}, tol={:.6e}",
            exact_i,
            fd_i,
            err,
            tol
        );
    }
}
/// Checks that `hvp` and `psd_majorizer_hvp` agree to `1e-10` when the penalty's
/// reference metric is set to its own trace-normalized pullback metric.
///
/// After a cache refresh at `coords`, the pullback metric is read, divided by
/// the mean of its per-row diagonal entries (columns `axis * d + axis`), and
/// installed as an `IsometryReference::UserSupplied` reference. The two
/// products are then compared component-wise along `direction`.
///
/// Panics if the pullback metric or third-decoder derivative is unavailable, if
/// the majorizer product has no component above `1e-8` in magnitude, or if the
/// two products disagree.
fn assert_exact_isometry_hvp_collapses_to_gn_at_zero_residual(
    evaluator: Arc<dyn SaeBasisSecondJet>,
    kind: SaeAtomBasisKind,
    coords: Array2<f64>,
    p_out: usize,
    direction: Array2<f64>,
) {
    let d = coords.ncols();
    let rho = array![0.0_f64];
    let (atom, penalty, target_flat) =
        build_isometry_atom_for_evaluator(evaluator, kind, &coords, p_out, 1.37);
    refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap();

    let mut g_ref = penalty
        .pullback_metric(d)
        .expect("pullback metric is available after the cache refresh");
    // Sum the diagonal of every row's flattened d x d block, in row-major order.
    let n_rows = g_ref.nrows();
    let trace_sum = (0..n_rows)
        .flat_map(|row| (0..d).map(move |axis| (row, axis)))
        .fold(0.0_f64, |acc, (row, axis)| acc + g_ref[[row, axis * d + axis]]);
    let normalizer = trace_sum / (g_ref.nrows() * d) as f64;
    g_ref.iter_mut().for_each(|value| *value /= normalizer);

    let penalty = penalty.with_reference(IsometryReference::UserSupplied(Arc::new(g_ref)));
    assert!(
        penalty.third_decoder_derivative().is_some(),
        "zero-residual exact/GN test must still carry the real refreshed K cache"
    );

    let v: Array1<f64> = direction.iter().copied().collect();
    let exact = penalty.hvp(target_flat.view(), rho.view(), v.view());
    let gn = penalty.psd_majorizer_hvp(target_flat.view(), rho.view(), v.view());
    assert!(
        gn.iter().any(|x| x.abs() > 1.0e-8),
        "GN block should be nonzero so exact/GN equality is not vacuous"
    );
    for i in 0..exact.len() {
        assert_abs_diff_eq!(exact[i], gn[i], epsilon = 1.0e-10);
    }
}
/// Exact-HVP vs. gradient finite-difference check for a `SphereChartEvaluator`
/// atom: four two-coordinate points, `p_out = 4`.
#[test]
fn isometry_exact_hvp_sphere_matches_grad_fd_and_uses_refreshed_k() {
    let coords = array![[-0.61, 0.23], [-0.18, -1.07], [0.42, 0.81], [0.73, -0.39]];
    let direction = array![[0.31, -0.27], [-0.18, 0.22], [0.14, 0.19], [-0.25, -0.11]];
    let evaluator = Arc::new(SphereChartEvaluator);
    assert_exact_isometry_hvp_matches_grad_fd(
        evaluator,
        SaeAtomBasisKind::Sphere,
        coords,
        4,
        direction,
    );
}
/// Exact-HVP vs. gradient finite-difference check for a `TorusHarmonicEvaluator`
/// atom: three two-coordinate points, `p_out = 3`.
#[test]
fn isometry_exact_hvp_torus_matches_grad_fd_and_uses_refreshed_k() {
    let coords = array![[0.13, 0.42], [0.66, 0.19], [0.88, 0.55]];
    let direction = array![[0.21, -0.16], [-0.24, 0.18], [0.13, 0.27]];
    let evaluator = Arc::new(TorusHarmonicEvaluator::new(2, 2).unwrap());
    assert_exact_isometry_hvp_matches_grad_fd(
        evaluator,
        SaeAtomBasisKind::Torus,
        coords,
        3,
        direction,
    );
}
/// Runs the zero-residual exact/GN agreement check first on a sphere-chart atom
/// and then on a torus-harmonic atom.
#[test]
fn isometry_exact_hvp_sphere_and_torus_collapse_to_gn_at_zero_residual() {
    // Sphere case.
    {
        let coords = array![[-0.52, 0.17], [-0.11, -0.93], [0.39, 0.74]];
        let direction = array![[0.17, -0.21], [-0.13, 0.08], [0.22, 0.19]];
        assert_exact_isometry_hvp_collapses_to_gn_at_zero_residual(
            Arc::new(SphereChartEvaluator),
            SaeAtomBasisKind::Sphere,
            coords,
            4,
            direction,
        );
    }
    // Torus case.
    {
        let coords = array![[0.19, 0.31], [0.57, 0.73], [0.84, 0.12]];
        let direction = array![[0.11, -0.14], [-0.20, 0.07], [0.16, 0.23]];
        assert_exact_isometry_hvp_collapses_to_gn_at_zero_residual(
            Arc::new(TorusHarmonicEvaluator::new(2, 2).unwrap()),
            SaeAtomBasisKind::Torus,
            coords,
            3,
            direction,
        );
    }
}
/// Checks that `psd_majorizer_hvp` is zero before the isometry caches are
/// installed and is a nonzero, symmetric, positive-semidefinite operator after
/// `refresh_isometry_caches_from_atom`.
///
/// Steps, in order:
/// 1. apply the majorizer to the first unit vector with no cache and require an
///    all-zero result;
/// 2. refresh the caches and require the second-jet cache to install;
/// 3. require the trace-normalized pullback metric to differ from the identity
///    (L1 mass above `1e-3`), so the test point has a nonzero residual;
/// 4. materialize the operator column by column from unit vectors and require
///    it to be nonzero and symmetric to `1e-10`;
/// 5. require `v . (B v) >= -1e-9` for every matrix in `probes` (flattened).
///
/// Panics on any failed requirement, or if a probe's flattened length differs
/// from the flattened target length.
fn assert_isometry_psd_majorizer_live_after_atom_refresh(
    evaluator: Arc<dyn SaeBasisSecondJet>,
    kind: SaeAtomBasisKind,
    coords: Array2<f64>,
    p_out: usize,
    probes: &[Array2<f64>],
) {
    let (atom, penalty, target_flat) =
        build_isometry_atom_for_evaluator(evaluator, kind, &coords, p_out, 0.53);
    let rho = array![0.0_f64];
    let n = target_flat.len();
    let unit_vector = |k: usize| {
        let mut e = Array1::<f64>::zeros(n);
        e[k] = 1.0;
        e
    };
    let apply_majorizer =
        |v: &Array1<f64>| penalty.psd_majorizer_hvp(target_flat.view(), rho.view(), v.view());

    // Step 1: no cache yet, so the majorizer must be the zero block.
    let pre = apply_majorizer(&unit_vector(0));
    assert!(
        pre.iter().all(|x| *x == 0.0),
        "psd_majorizer_hvp without a cache must be the zero block; got {pre:?}"
    );

    // Step 2: install the caches.
    let installed = refresh_isometry_caches_from_atom(&penalty, &atom, coords.view()).unwrap();
    assert!(
        installed,
        "second-jet cache must install for the PSD-majorizer liveness test"
    );

    // Step 3: residual of the normalized pullback metric against the identity.
    let d = coords.ncols();
    let g = penalty
        .pullback_metric(d)
        .expect("pullback metric available after refresh");
    let n_rows = g.nrows();
    let mut trace_sum = 0.0_f64;
    for row in 0..n_rows {
        for axis in 0..d {
            trace_sum += g[[row, axis * d + axis]];
        }
    }
    let normalizer = trace_sum / (g.nrows() * d) as f64;
    let mut residual_mass = 0.0_f64;
    for row in 0..n_rows {
        for a in 0..d {
            for b in 0..d {
                let identity_entry = if a == b { 1.0 } else { 0.0 };
                let normalized = g[[row, a * d + b]] / normalizer;
                residual_mass += (normalized - identity_entry).abs();
            }
        }
    }
    assert!(
        residual_mass > 1.0e-3,
        "Euclidean-reference residual must be nonzero for a real curvature test; \
         got residual mass {residual_mass:.3e}"
    );

    // Step 4: dense copy of the operator, one unit-vector product per column.
    let mut bmat = Array2::<f64>::zeros((n, n));
    for k in 0..n {
        let col = apply_majorizer(&unit_vector(k));
        for r in 0..n {
            bmat[[r, k]] = col[r];
        }
    }
    let max_abs = bmat.iter().fold(0.0_f64, |acc, x| acc.max(x.abs()));
    assert!(
        max_abs > 1.0e-6,
        "isometry GN majorizer must be nonzero for a non-Duchon basis after refresh; \
         max |B| = {max_abs:.3e}"
    );
    for r in 0..n {
        for c in 0..n {
            assert_abs_diff_eq!(bmat[[r, c]], bmat[[c, r]], epsilon = 1.0e-10);
        }
    }

    // Step 5: quadratic form along each caller-supplied probe.
    for probe in probes {
        let v: Array1<f64> = probe.iter().copied().collect();
        assert_eq!(v.len(), n, "probe must match the flattened target length");
        let bv = apply_majorizer(&v);
        let quad = v.dot(&bv);
        assert!(
            quad >= -1.0e-9,
            "isometry GN majorizer must be PSD; got vᵀBv = {quad:.3e}"
        );
    }
}
/// PSD-majorizer liveness check for a `SphereChartEvaluator` atom with three
/// probe directions.
#[test]
fn isometry_psd_majorizer_live_after_sphere_refresh() {
    let coords = array![[-0.61, 0.23], [-0.18, -1.07], [0.42, 0.81]];
    let probes = [
        array![[0.31, -0.27], [-0.18, 0.22], [0.14, 0.19]],
        array![[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
        array![[-2.3, 0.6], [-0.1, 1.4], [0.8, -1.7]],
    ];
    assert_isometry_psd_majorizer_live_after_atom_refresh(
        Arc::new(SphereChartEvaluator),
        SaeAtomBasisKind::Sphere,
        coords,
        4,
        &probes,
    );
}
/// PSD-majorizer liveness check for a `PeriodicHarmonicEvaluator` atom with
/// three probe directions.
#[test]
fn isometry_psd_majorizer_live_after_circle_refresh() {
    let coords = array![[0.12], [0.37], [0.58], [0.81]];
    let probes = [
        array![[0.4], [-1.1], [0.7], [0.3]],
        array![[1.0], [1.0], [1.0], [1.0]],
        array![[-2.3], [0.6], [-0.1], [1.4]],
    ];
    assert_isometry_psd_majorizer_live_after_atom_refresh(
        Arc::new(PeriodicHarmonicEvaluator::new(5).unwrap()),
        SaeAtomBasisKind::Periodic,
        coords,
        3,
        &probes,
    );
}
/// PSD-majorizer liveness check for a `TorusHarmonicEvaluator` atom with three
/// probe directions.
#[test]
fn isometry_psd_majorizer_live_after_torus_refresh() {
    let coords = array![[0.13, 0.42], [0.66, 0.19], [0.88, 0.55]];
    let probes = [
        array![[0.21, -0.16], [-0.24, 0.18], [0.13, 0.27]],
        array![[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
        array![[-1.2, 0.5], [0.3, -0.9], [0.7, 0.2]],
    ];
    assert_isometry_psd_majorizer_live_after_atom_refresh(
        Arc::new(TorusHarmonicEvaluator::new(2, 2).unwrap()),
        SaeAtomBasisKind::Torus,
        coords,
        3,
        &probes,
    );
}
/// Regression test: `refresh_isometry_caches_from_term` must refresh the
/// penalty at registry index `i` against atom `i` of the term.
///
/// Two periodic atoms are built at different coordinates with different
/// decoders. A "control" penalty is refreshed directly against each atom to
/// record the Jacobian cache that atom produces. Two fresh penalties are then
/// pushed into a registry, refreshed through the term-level entry point, and
/// each resulting cache is compared with the control for the same index.
#[test]
fn refresh_isometry_caches_pairs_each_penalty_to_its_own_atom() {
    let latent_dim = 1usize;
    let p_out = 3usize;
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(5).unwrap());
    let coords0 = array![[0.05], [0.20], [0.55], [0.80]];
    let coords1 = array![[0.13], [0.41], [0.62], [0.91]];
    // Builds one periodic atom; `seed` shifts the decoder entries so the two
    // atoms do not share a decoder.
    let build_atom = |name: &str, coords: &Array2<f64>, seed: f64| {
        let (phi, jet) = evaluator.evaluate(coords.view()).unwrap();
        let m = phi.ncols();
        let mut decoder = Array2::<f64>::zeros((m, p_out));
        for i in 0..m {
            for j in 0..p_out {
                let x = (i as f64) * 0.371 + (j as f64) * 0.193 + seed;
                decoder[[i, j]] = (x.sin() * 0.9) + 0.1 * ((i + j) as f64).cos();
            }
        }
        let smooth = Array2::<f64>::eye(m);
        SaeManifoldAtom::new(
            name,
            SaeAtomBasisKind::Periodic,
            latent_dim,
            phi,
            jet,
            decoder,
            smooth,
        )
        .unwrap()
        .with_basis_second_jet(evaluator.clone() as Arc<dyn SaeBasisSecondJet>)
    };
    let atom0 = build_atom("atom0", &coords0, 0.5);
    let atom1 = build_atom("atom1", &coords1, 1.7);

    // Control caches: each penalty refreshed directly against its own atom.
    let slice0 = PsiSlice::full(coords0.nrows() * latent_dim, Some(latent_dim));
    let control0 = IsometryPenalty::new_euclidean(slice0, p_out);
    refresh_isometry_caches_from_atom(&control0, &atom0, coords0.view()).unwrap();
    let expected0 = control0
        .jacobian_cache()
        .expect("control penalty 0 must have a Jacobian cache");
    let slice1 = PsiSlice::full(coords1.nrows() * latent_dim, Some(latent_dim));
    let control1 = IsometryPenalty::new_euclidean(slice1, p_out);
    refresh_isometry_caches_from_atom(&control1, &atom1, coords1.view()).unwrap();
    let expected1 = control1
        .jacobian_cache()
        .expect("control penalty 1 must have a Jacobian cache");
    // The comparison below is only meaningful if the two atoms differ.
    assert_ne!(
        *expected0, *expected1,
        "atom 0 and atom 1 must produce distinct Jacobian caches"
    );

    let logits = Array2::<f64>::zeros((coords0.nrows(), 2));
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        vec![coords0.clone(), coords1.clone()],
        vec![
            LatentManifold::Circle { period: 1.0 },
            LatentManifold::Circle { period: 1.0 },
        ],
        AssignmentMode::ibp_map(0.7, 1.0, true),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(vec![atom0, atom1], assignment).unwrap();

    // Registry under test: two fresh penalties, pushed in atom order.
    let mut registry = AnalyticPenaltyRegistry::new();
    let pslice0 = PsiSlice::full(coords0.nrows() * latent_dim, Some(latent_dim));
    let pslice1 = PsiSlice::full(coords1.nrows() * latent_dim, Some(latent_dim));
    registry.push(AnalyticPenaltyKind::Isometry(Arc::new(
        IsometryPenalty::new_euclidean(pslice0, p_out),
    )));
    registry.push(AnalyticPenaltyKind::Isometry(Arc::new(
        IsometryPenalty::new_euclidean(pslice1, p_out),
    )));
    let coords_per_atom = vec![coords0.clone(), coords1.clone()];
    let refreshed =
        refresh_isometry_caches_from_term(&registry, &term, &coords_per_atom).unwrap();
    assert_eq!(refreshed, 2, "both penalties should install second caches");

    let cache0 = match &registry.penalties[0] {
        AnalyticPenaltyKind::Isometry(p) => p
            .jacobian_cache()
            .expect("penalty 0 cache must be populated"),
        _ => panic!("expected isometry penalty at index 0"),
    };
    let cache1 = match &registry.penalties[1] {
        AnalyticPenaltyKind::Isometry(p) => p
            .jacobian_cache()
            .expect("penalty 1 cache must be populated"),
        _ => panic!("expected isometry penalty at index 1"),
    };
    assert_eq!(
        *cache0, *expected0,
        "penalty 0 must be refreshed against atom 0"
    );
    assert_eq!(
        *cache1, *expected1,
        "penalty 1 must be refreshed against atom 1 (regression: old find() paired it to atom 0)"
    );
    assert_ne!(
        *cache0, *cache1,
        "the two penalties must not collapse onto the same atom"
    );
}
/// Builds the small single-atom objective used by the `seed_inner_state` tests.
///
/// One periodic atom (basis from `periodic_basis`, a `3 x 1` decoder, identity
/// smoothness matrix) is evaluated at four one-dimensional coordinates under a
/// softmax assignment, with a `4 x 1` target and a `SaeManifoldRho` built from
/// zeros.
fn warmstart_test_objective() -> SaeManifoldOuterObjective {
    let coords = array![[0.10], [0.35], [0.62], [0.88]];
    let (phi, jet) = periodic_basis(&coords);
    let decoder = array![[0.30], [-0.20], [0.15]];
    let smooth = Array2::<f64>::eye(3);
    let atom = SaeManifoldAtom::new(
        "periodic",
        SaeAtomBasisKind::Periodic,
        1,
        phi,
        jet,
        decoder,
        smooth,
    )
    .unwrap();

    let logits = array![[0.9_f64], [0.8], [0.7], [0.6]];
    let mode = AssignmentMode::softmax(0.7);
    let assignment = SaeAssignment::from_blocks_with_mode(logits, vec![coords], mode).unwrap();

    let term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    let target = array![[0.20_f64], [-0.10], [0.30], [0.05]];
    let rho = SaeManifoldRho::new(0.0, 0.0, vec![Array1::<f64>::zeros(1)]);
    SaeManifoldOuterObjective::new(term, target, None, rho, 8, 1.0, 1.0e-6, 1.0e-6)
}
/// Hand-built `ArrowFactorCache` with a badly scaled row factor.
///
/// The only row block is `diag(1.0, 1.0e-7)` and the Schur factor is the
/// `1 x 1` matrix `[1.0]`; ridges are zero and the remaining fields take
/// disabled / default / zero values.
fn near_singular_outer_gradient_cache() -> ArrowFactorCache {
    let row_factor = array![[1.0_f64, 0.0], [0.0, 1.0e-7]];
    let schur = array![[1.0_f64]];
    ArrowFactorCache {
        htt_factors: ArrowFactorSlab::from_blocks(vec![row_factor]),
        htt_factors_undamped: ArrowUndampedFactors::SameAsDamped,
        schur_factor: Some(schur),
        solver_mode: ArrowSolverMode::Direct,
        ridge_t: 0.0,
        ridge_beta: 0.0,
        htbeta: ArrowHtbetaCache::Disabled { estimated_bytes: 0 },
        d: 2,
        row_dims: Arc::from(vec![2usize].into_boxed_slice()),
        row_offsets: Arc::from(vec![0usize, 2usize].into_boxed_slice()),
        k: 1,
        manifold_mode_fingerprint: 0,
        row_hessian_fingerprint: 0,
        pcg_diagnostics: PcgDiagnostics::default(),
    }
}
/// The conditioning guard must return an error for the near-singular cache, and
/// the error text must name the undefined analytic gradient and report the
/// pivot ratio together with its floor.
#[test]
fn outer_gradient_conditioning_guard_rejects_near_singular_cache() {
    let cache = near_singular_outer_gradient_cache();
    let verdict = SaeManifoldOuterObjective::ensure_outer_gradient_factor_well_conditioned(&cache);
    let err = verdict
        .expect_err("near-singular evidence factor must reject the analytic outer gradient");

    let names_condition = err.contains("analytic outer gradient undefined at this rho");
    assert!(
        names_condition,
        "guard error must name the undefined analytic-gradient condition; got: {err}"
    );

    let reports_ratio = err.contains("min/max pivot ratio");
    let reports_floor = err.contains("floor");
    assert!(
        reports_ratio && reports_floor,
        "guard error must report the pivot ratio and floor; got: {err}"
    );
}
/// Seeding with a zero-length β must succeed and report `SeedOutcome::NoSlot`
/// rather than being rejected.
#[test]
fn seed_inner_state_accepts_empty_beta_as_noslot() {
    let mut obj = warmstart_test_objective();
    let empty = Array1::<f64>::zeros(0);
    let outcome = obj
        .seed_inner_state(&empty)
        .expect("empty-β seed must be accepted as a no-op, not rejected (gam#577/#579)");
    let is_noslot = matches!(outcome, SeedOutcome::NoSlot);
    assert!(
        is_noslot,
        "empty-β seed must report NoSlot (proceed cold); got {outcome:?}"
    );
}
/// A β seed of the right length must install, and with `inner_max_iter = 0` the
/// next evaluation must publish that same β as its `inner_beta_hint`.
#[test]
fn seed_inner_state_installs_and_reuses_matching_beta() {
    let mut obj = warmstart_test_objective();
    let dim = obj.term.beta_dim();
    let pristine = obj.term.flatten_beta();

    // Offset every coordinate so the seed is distinguishable from the term's own β.
    let seed = Array1::<f64>::from_shape_fn(dim, |i| pristine[i] + 0.5 + 0.01 * (i as f64));
    let seed_is_distinct = (&seed - &pristine).iter().any(|d| d.abs() > 1e-6);
    assert!(
        seed_is_distinct,
        "seed must differ from the pristine β for the reuse check to be meaningful"
    );

    let outcome = obj
        .seed_inner_state(&seed)
        .expect("a length-matching β must install");
    let was_installed = matches!(outcome, SeedOutcome::Installed);
    assert!(
        was_installed,
        "matching β must report Installed; got {outcome:?}"
    );

    // With zero inner iterations allowed, the evaluation is made at the seeded β.
    obj.inner_max_iter = 0;
    let rho_flat = obj.baseline_rho.to_flat();
    let eval = OuterObjective::eval(&mut obj, &rho_flat)
        .expect("eval at the warm-started β must succeed");
    let hint = eval
        .inner_beta_hint
        .expect("the SAE objective must publish inner_beta_hint for continuation reuse");
    assert_eq!(
        hint.len(),
        dim,
        "published hint must have decoder dimension"
    );
    for (i, (&h, &s)) in hint.iter().zip(seed.iter()).enumerate() {
        let gap = (h - s).abs();
        assert!(
            gap < 1e-12,
            "warm-started β must be reused verbatim by the inner solve at coord {i}: \
             hint {h} != seed {s} (gam#577/#579)"
        );
    }
}
/// A non-empty β whose length is one more than `beta_dim()` must be rejected
/// with `EstimationError::RemlOptimizationFailed` and a message mentioning
/// "decoder dim".
#[test]
fn seed_inner_state_rejects_wrong_length_populated_beta() {
    let mut obj = warmstart_test_objective();
    let wrong_len = obj.term.beta_dim() + 1;
    let wrong = Array1::<f64>::zeros(wrong_len);
    let err = obj
        .seed_inner_state(&wrong)
        .expect_err("a populated β of the wrong length must be rejected");
    match err {
        EstimationError::RemlOptimizationFailed(msg) => {
            let names_mismatch = msg.contains("decoder dim");
            assert!(
                names_mismatch,
                "error must name the decoder-dim mismatch; got: {msg}"
            );
        }
        other => panic!("expected RemlOptimizationFailed, got {other:?}"),
    }
}
/// Builds the five-basis, one-dimensional `EuclideanPatch` atom used by the
/// intrinsic-smoothness tests.
///
/// The basis matrix is the `5 x 5` identity, the basis jet is diagonal with
/// entry `jacobian_scale * (1 + mu)` at position `mu`, the decoder is a column
/// of ones, and the raw smoothness matrix comes from
/// `create_difference_penalty_matrix(5, 2, None)`.
fn intrinsic_test_atom(jacobian_scale: f64) -> SaeManifoldAtom {
    let m = 5usize;
    let n = m;
    let p = 1usize;
    let phi = Array2::<f64>::eye(m);
    let jet = Array3::<f64>::from_shape_fn((n, m, 1), |(row, col, _)| {
        if row == col {
            jacobian_scale * (1.0 + row as f64)
        } else {
            0.0
        }
    });
    let decoder = Array2::<f64>::from_elem((m, p), 1.0);
    let s_raw = crate::basis::create_difference_penalty_matrix(m, 2, None).unwrap();
    SaeManifoldAtom::new(
        "intrinsic-1d",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        decoder,
        s_raw,
    )
    .unwrap()
}
/// The atom built from a second-order difference penalty must report
/// `smooth_penalty_order == 2`.
#[test]
fn intrinsic_penalty_recovers_order_two_from_nullity() {
    let recovered_order = intrinsic_test_atom(1.0).smooth_penalty_order;
    assert_eq!(recovered_order, 2);
}
/// `restore_mutable_state` must bring `smooth_penalty` back to its
/// pre-snapshot value after the decoder was changed and the intrinsic penalty
/// was refreshed.
#[test]
fn line_search_snapshot_restores_intrinsic_smooth_penalty() {
    let atom = intrinsic_test_atom(1.0);
    let n = atom.n_obs();
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::zeros((n, 1)),
        vec![Array2::<f64>::zeros((n, 1))],
        vec![LatentManifold::Euclidean],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();

    let original = term.atoms[0].smooth_penalty.clone();
    let snapshot = term.snapshot_mutable_state();

    // Perturb one decoder entry and recompute the penalty from the new decoder.
    {
        let live_atom = &mut term.atoms[0];
        live_atom.decoder_coefficients[[0, 0]] *= 3.0;
        live_atom.refresh_intrinsic_smooth_penalty();
    }
    let changed = (&term.atoms[0].smooth_penalty - &original)
        .mapv(f64::abs)
        .sum();
    assert!(
        changed > 1e-6,
        "test setup must perturb the live intrinsic smoothness Gram"
    );

    term.restore_mutable_state(&snapshot);
    let restored = (&term.atoms[0].smooth_penalty - &original)
        .mapv(f64::abs)
        .sum();
    assert!(
        restored < 1e-12,
        "line-search restore left a stale intrinsic smoothness Gram: {restored}"
    );
}
/// Two atoms that differ only by a global jet scale (1.0 vs. 7.5) must share
/// the same raw penalty and the same reweighted `smooth_penalty`.
#[test]
fn intrinsic_penalty_is_invariant_to_speed_rescaling() {
    let a1 = intrinsic_test_atom(1.0);
    let a2 = intrinsic_test_atom(7.5);

    let raw_gap = (&a1.smooth_penalty_raw - &a2.smooth_penalty_raw)
        .mapv(f64::abs)
        .sum();
    assert_abs_diff_eq!(raw_gap, 0.0, epsilon = 1e-12);

    let diff = (&a1.smooth_penalty - &a2.smooth_penalty)
        .mapv(f64::abs)
        .sum();
    assert!(
        diff < 1e-9,
        "intrinsic Gram changed under a global speed rescale (gauge leak): {diff}"
    );
}
/// For the varying-jet test atom, `smooth_penalty` must differ from
/// `smooth_penalty_raw` and must remain symmetric to `1e-12`.
#[test]
fn intrinsic_penalty_differs_from_raw_under_varying_speed() {
    let atom = intrinsic_test_atom(1.0);
    let diff = (&atom.smooth_penalty - &atom.smooth_penalty_raw)
        .mapv(f64::abs)
        .sum();
    assert!(
        diff > 1e-6,
        "intrinsic reweighting was a no-op on a non-constant-speed curve: {diff}"
    );

    let size = atom.basis_size();
    for i in 0..size {
        for j in 0..size {
            let upper = atom.smooth_penalty[[i, j]];
            let lower = atom.smooth_penalty[[j, i]];
            assert_abs_diff_eq!(upper, lower, epsilon = 1e-12);
        }
    }
}
/// When every diagonal jet entry is the same constant (2.0), `smooth_penalty`
/// must equal `smooth_penalty_raw` to within `1e-9` in L1 norm.
#[test]
fn intrinsic_penalty_leaves_constant_speed_atom_unchanged() {
    let m = 6usize;
    let n = m;
    let phi = Array2::<f64>::eye(m);
    let jet = Array3::<f64>::from_shape_fn((n, m, 1), |(row, col, _)| {
        if row == col { 2.0 } else { 0.0 }
    });
    let decoder = Array2::<f64>::from_elem((m, 1), 1.0);
    let s_raw = crate::basis::create_difference_penalty_matrix(m, 2, None).unwrap();
    let atom = SaeManifoldAtom::new(
        "constant-speed",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        decoder,
        s_raw,
    )
    .unwrap();

    let diff = (&atom.smooth_penalty - &atom.smooth_penalty_raw)
        .mapv(f64::abs)
        .sum();
    assert!(
        diff < 1e-9,
        "constant-speed atom's penalty was reweighted (should be identity): {diff}"
    );
}
/// A `2 x 2` input whose entries are all `1e308` must yield finite seed
/// coordinates of shape `(1, 2, 1)`.
#[test]
fn pca_seed_handles_huge_equal_finite_columns_without_mean_overflow() {
    let z = array![[1.0e308_f64, 1.0e308], [1.0e308, 1.0e308]];
    let kinds = [SaeAtomBasisKind::Periodic];
    let coords = sae_pca_seed_initial_coords(z.view(), &kinds, &[1]).unwrap();
    assert_eq!(coords.dim(), (1, 2, 1));
    let any_non_finite = coords.iter().any(|value| !value.is_finite());
    assert!(
        !any_non_finite,
        "huge finite equal columns must not overflow the PCA seed mean: {coords:?}"
    );
}
/// A column holding `+1e308` and `-1e308` must make the PCA seed return an
/// error that mentions either non-finite centering or an SVD failure.
#[test]
fn pca_seed_rejects_huge_finite_span_that_overflows_centering() {
    let z = array![[1.0e308_f64, 0.0], [-1.0e308, 0.0]];
    let kinds = [SaeAtomBasisKind::Periodic];
    let err = sae_pca_seed_initial_coords(z.view(), &kinds, &[1])
        .expect_err("opposite huge finite values exceed f64 centering range");
    let blames_centering = err.contains("centered Z is non-finite");
    let blames_svd = err.contains("SVD failed");
    assert!(
        blames_centering || blames_svd,
        "unexpected PCA seed error: {err}"
    );
}
/// Span-recovery and polar-update checks on a planted rank-3 frame in 12
/// output dimensions.
///
/// The planted frame has ones at `(j, j)` for `j < 3`; 200 coordinate rows are
/// generated from a fixed modular formula. The test requires a zero recovered
/// span angle, a zero principal angle between the polar-updated frame and the
/// planted one, and an identity Gram matrix for the updated frame (all to
/// `1e-9`).
#[test]
fn planted_low_rank_frame_recovered_by_polar() {
    let p = 12usize;
    let r = 3usize;
    let n = 200usize;
    let planted =
        Array2::<f64>::from_shape_fn((p, r), |(row, col)| if row == col { 1.0 } else { 0.0 });
    let coords = Array2::<f64>::from_shape_fn((n, r), |(i, j)| {
        ((i * 7 + j * 13 + 1) % 97) as f64 / 97.0 - 0.5
    });
    let targets = fast_abt(&coords, &planted);

    let angle =
        grassmann_recover_planted_span_angle(targets.view(), coords.view(), planted.view())
            .expect("span recovery");
    assert_abs_diff_eq!(angle, 0.0, epsilon = 1.0e-9);

    let frame = GrassmannFrame::polar_update(planted.view()).expect("polar");
    let recovered_angle = frame
        .max_principal_angle(planted.view())
        .expect("principal angle");
    assert_abs_diff_eq!(recovered_angle, 0.0, epsilon = 1.0e-9);

    // The frame's Gram matrix must be the r x r identity.
    let frame_matrix = frame.frame().to_owned();
    let gram = fast_atb(&frame_matrix, &frame_matrix);
    for i in 0..r {
        for j in 0..r {
            let expect = if i == j { 1.0 } else { 0.0 };
            assert_abs_diff_eq!(gram[[i, j]], expect, epsilon = 1.0e-9);
        }
    }
}
/// Exercises decoder-frame activation on a rank-2 decoder (`m = 6` basis
/// functions, `p = 16` outputs) and checks the resulting dimension bookkeeping
/// and round trips:
/// - activation reports rank `r = 2`, and the frame manifold dimension is
///   `r * (p - r)`;
/// - the factored coordinates have shape `(m, r)` and reconstruct the original
///   decoder to `1e-9`;
/// - at term level the factored border has `m * r` entries, and flattening then
///   scattering the border leaves the decoder coefficients unchanged to `1e-9`.
#[test]
fn factored_border_dim_invariant_and_reconstruction() {
let m = 6usize;
let p = 16usize;
let r = 2usize;
// Frame columns are the first two coordinate axes of the p-dimensional output.
let mut frame = Array2::<f64>::zeros((p, r));
frame[[0, 0]] = 1.0;
frame[[1, 1]] = 1.0;
let mut c0 = Array2::<f64>::zeros((m, r));
for mu in 0..m {
c0[[mu, 0]] = 1.0 + mu as f64;
c0[[mu, 1]] = 0.5 * mu as f64 - 1.0;
}
// NOTE(review): `fast_abt` presumably computes c0 * frame^T, giving an (m, p)
// decoder of rank r — confirm against the helper's definition.
let decoder = fast_abt(&c0, &frame);
// Identity basis and unit diagonal jet.
let mut phi = Array2::<f64>::zeros((m, m));
let mut jet = Array3::<f64>::zeros((m, m, 1));
for mu in 0..m {
phi[[mu, mu]] = 1.0;
jet[[mu, mu, 0]] = 1.0;
}
let s_raw = crate::basis::create_difference_penalty_matrix(m, 2, None).unwrap();
let mut atom = SaeManifoldAtom::new(
"lowrank",
SaeAtomBasisKind::EuclideanPatch,
1,
phi,
jet,
decoder.clone(),
s_raw,
)
.unwrap();
let activated = atom.maybe_activate_decoder_frame().expect("activate");
assert_eq!(
activated,
Some(r),
"rank-{r} decoder should profile to r={r}"
);
assert_eq!(atom.border_frame_rank(), r);
assert_eq!(atom.frame_manifold_dimension(), r * (p - r));
// Atom-level round trip: factored coordinates -> full decoder.
let coords = atom.factored_coordinates().unwrap().expect("coords");
assert_eq!(coords.dim(), (m, r));
let reconstructed = atom
.reconstruct_decoder_coefficients(coords.view())
.unwrap();
for mu in 0..m {
for j in 0..p {
assert_abs_diff_eq!(reconstructed[[mu, j]], decoder[[mu, j]], epsilon = 1.0e-9);
}
}
let term = SaeManifoldTerm::new(
vec![atom],
SaeAssignment::from_blocks_with_mode(
Array2::<f64>::zeros((m, 1)),
vec![Array2::<f64>::zeros((m, 1))],
AssignmentMode::softmax(0.7),
)
.unwrap(),
)
.unwrap();
grassmann_assert_border_dim_invariant(&term).expect("border invariant");
assert_eq!(term.factored_border_dim(), m * r);
assert_eq!(term.grassmann_evidence_dimension(), r * (p - r));
// Rebind mutably for the scatter call below.
let mut term = term;
// Term-level round trip: flatten the border, scatter it back, and compare the
// decoder coefficients with the copy saved before the scatter.
let border = term.flatten_factored_border().unwrap();
assert_eq!(border.len(), m * r);
let saved = term.atoms[0].decoder_coefficients.clone();
term.scatter_factored_border(border.view()).unwrap();
for mu in 0..m {
for j in 0..p {
assert_abs_diff_eq!(
term.atoms[0].decoder_coefficients[[mu, j]],
saved[[mu, j]],
epsilon = 1.0e-9
);
}
}
}
/// Checks that the β-penalty curvature built directly in factored coordinates
/// equals the dense curvature projected into those coordinates.
///
/// Two rank-2 atoms (`m = 4`, `p = 24`, five observations) with active decoder
/// frames are combined into one term. A nuclear-norm and a decoder-incoherence
/// penalty are registered over the full β range. The test then:
/// 1. assembles the dense penalty contributions into an `ArrowSchurSystem` and
///    requires the dense block to have been written (and not deferred);
/// 2. compares `project_dense_penalty_to_factored` with
///    `build_factored_beta_penalty_curvature` entry by entry to `1e-10`;
/// 3. assembles through the probe-threshold entry point on a clone of the term
///    and requires `k` to equal the factored border dimension with an empty
///    `hbb`.
#[test]
fn factored_beta_penalty_probing_matches_projected_dense_curvature() {
    let k_atoms = 2usize;
    let m = 4usize;
    let p = 24usize;
    let r = 2usize;
    let n_obs = 5usize;
    let mut atoms = Vec::with_capacity(k_atoms);
    let mut coord_blocks = Vec::with_capacity(k_atoms);
    for atom_idx in 0..k_atoms {
        // Each atom's frame uses its own pair of output axes.
        let mut frame = Array2::<f64>::zeros((p, r));
        frame[[atom_idx * r, 0]] = 1.0;
        frame[[atom_idx * r + 1, 1]] = 1.0;
        let mut coords = Array2::<f64>::zeros((n_obs, 1));
        for row in 0..n_obs {
            coords[[row, 0]] = row as f64;
        }
        let mut phi = Array2::<f64>::zeros((n_obs, m));
        let mut jet = Array3::<f64>::zeros((n_obs, m, 1));
        for row in 0..n_obs {
            for basis_col in 0..m {
                let x = (row + 1) as f64 * (basis_col + 1) as f64;
                phi[[row, basis_col]] = 0.05 * x + if row == basis_col { 1.0 } else { 0.0 };
                jet[[row, basis_col, 0]] = 0.01 * x;
            }
        }
        let mut c = Array2::<f64>::zeros((m, r));
        for basis_col in 0..m {
            c[[basis_col, 0]] = 0.3 + 0.07 * (basis_col + atom_idx) as f64;
            c[[basis_col, 1]] = -0.2 + 0.05 * (basis_col * 2 + atom_idx) as f64;
        }
        let decoder = fast_abt(&c, &frame);
        let mut atom = SaeManifoldAtom::new(
            "factored_probe",
            SaeAtomBasisKind::EuclideanPatch,
            1,
            phi,
            jet,
            decoder,
            Array2::<f64>::eye(m),
        )
        .unwrap();
        atom.maybe_activate_decoder_frame()
            .expect("frame activation")
            .expect("rank-2 atom should activate a frame");
        atoms.push(atom);
        coord_blocks.push(coords);
    }
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        Array2::<f64>::from_elem((n_obs, k_atoms), 0.25),
        coord_blocks,
        vec![LatentManifold::Euclidean, LatentManifold::Euclidean],
        AssignmentMode::softmax(1.0),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(atoms, assignment).unwrap();
    assert!(term.frames_active());
    assert_eq!(term.factored_border_dim(), k_atoms * m * r);

    // Register both β penalties over the whole β vector.
    let beta_len = term.beta_dim();
    let mut registry = AnalyticPenaltyRegistry::new();
    let nuclear = NuclearNormPenalty::new(
        PsiSlice {
            range: 0..beta_len,
            latent_dim: Some(beta_len / p),
        },
        0.7,
        p,
        1.0e-4,
        None,
        false,
    )
    .unwrap();
    registry.push(AnalyticPenaltyKind::NuclearNorm(Arc::new(nuclear)));
    let incoherence = DecoderIncoherencePenalty::new(
        PsiSlice {
            range: 0..beta_len,
            latent_dim: Some(beta_len / p),
        },
        vec![m, m],
        p,
        Array2::<f64>::from_elem((k_atoms, k_atoms), 0.5),
        0.6,
        false,
    )
    .unwrap();
    registry.push(AnalyticPenaltyKind::DecoderIncoherence(Arc::new(
        incoherence,
    )));

    // Step 1: dense assembly in full-β coordinates.
    let mut dense_sys = ArrowSchurSystem::new(0, 0, beta_len);
    let dense_assembly = term
        .add_sae_analytic_penalty_contributions(
            &mut dense_sys,
            &registry,
            1.0,
            None,
            true,
            None,
        )
        .unwrap();
    assert!(dense_assembly.dense_written);
    assert!(!dense_assembly.deferred_factored);

    // Step 2: projected dense curvature vs. directly built factored curvature.
    let projection = FrameProjection::new(&term);
    let border_dim = term.factored_border_dim();
    let projected = term.project_dense_penalty_to_factored(dense_sys.hbb.view(), &projection);
    let direct = term.build_factored_beta_penalty_curvature(&registry, 1.0, &projection);
    for row in 0..border_dim {
        for col in 0..border_dim {
            assert_abs_diff_eq!(direct[[row, col]], projected[[row, col]], epsilon = 1.0e-10);
        }
    }

    // Step 3: assembly through the probe-threshold entry point.
    let mut deferred_term = term.clone();
    let rho = SaeManifoldRho::new(
        0.0,
        -20.0,
        vec![Array1::<f64>::zeros(1), Array1::<f64>::zeros(1)],
    );
    let target = Array2::<f64>::zeros((n_obs, p));
    let sys = deferred_term
        .assemble_arrow_schur_scaled_with_beta_penalty_probe_threshold(
            target.view(),
            &rho,
            Some(&registry),
            1.0,
            1,
        )
        .unwrap();
    assert_eq!(sys.k, border_dim);
    assert!(sys.hbb.is_empty());
}
/// Materializes the dense `(row_dims[row_idx], k)` cross block for one row of
/// `sys`.
///
/// The stored `row.htbeta` is used as the starting value only when the system
/// has the dense-supplement flag set or no matvec operator, and the stored
/// block has the expected shape; otherwise the start is all zeros. If a matvec
/// operator is present, it is applied to each unit vector of length `k` and the
/// result is added into the corresponding column.
fn materialize_row_htbeta_for_test(sys: &ArrowSchurSystem, row_idx: usize) -> Array2<f64> {
    let di = sys.row_dims[row_idx];
    let k = sys.k;
    let row = &sys.rows[row_idx];
    let use_dense = sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none();
    let stored_block_fits = use_dense && row.htbeta.dim() == (di, k);
    let mut out = if stored_block_fits {
        row.htbeta.clone()
    } else {
        Array2::<f64>::zeros((di, k))
    };

    let op = match sys.htbeta_matvec.as_ref() {
        Some(op) => op,
        None => return out,
    };
    // Probe the operator one unit vector at a time, reusing both buffers.
    let mut basis = Array1::<f64>::zeros(k);
    let mut col = Array1::<f64>::zeros(di);
    for beta_col in 0..k {
        basis.fill(0.0);
        basis[beta_col] = 1.0;
        col.fill(0.0);
        op(row_idx, basis.view(), &mut col);
        for row_col in 0..di {
            out[[row_col, beta_col]] += col[row_col];
        }
    }
    out
}
/// Projects a full-β cross block into factored coordinates by building a
/// `FrameProjection` for `term` and calling its `project_rows`.
fn project_row_htbeta_to_factored_for_test(
    term: &SaeManifoldTerm,
    htbeta_b: ArrayView2<'_, f64>,
) -> Array2<f64> {
    let projection = FrameProjection::new(term);
    projection.project_rows(htbeta_b)
}
/// Compares two routes to the Newton step of a term with active decoder frames:
/// - "native": the system assembled directly by the factored term;
/// - "projected": the same factored system with its cross-block operators
///   cleared and each row's `htbeta` replaced by the full-β cross block (from a
///   copy of the term with frames deactivated) projected into factored
///   coordinates.
/// Both systems are solved with the same options and ridges, and the resulting
/// steps must agree component-wise to `1e-10`.
#[test]
fn factored_row_htbeta_native_solve_matches_full_b_then_project() {
let k_atoms = 2usize;
let m = 4usize;
let p = 24usize;
let r = 2usize;
let n_obs = 5usize;
let mut atoms = Vec::with_capacity(k_atoms);
let mut coord_blocks = Vec::with_capacity(k_atoms);
for atom_idx in 0..k_atoms {
// Each atom's frame uses its own pair of output axes.
let mut frame = Array2::<f64>::zeros((p, r));
frame[[atom_idx * r, 0]] = 1.0;
frame[[atom_idx * r + 1, 1]] = 1.0;
let coords = Array2::from_shape_fn((n_obs, 1), |(row, _)| 0.1 * (row + 1) as f64);
let mut phi = Array2::<f64>::zeros((n_obs, m));
let mut jet = Array3::<f64>::zeros((n_obs, m, 1));
for row in 0..n_obs {
for basis_col in 0..m {
let x = (row + 1) as f64 * (basis_col + 1) as f64;
phi[[row, basis_col]] = 0.03 * x + if row % m == basis_col { 1.0 } else { 0.0 };
jet[[row, basis_col, 0]] = 0.02 * x;
}
}
let c = Array2::from_shape_fn((m, r), |(basis_col, frame_col)| {
0.2 + 0.04 * (basis_col + 2 * frame_col + atom_idx) as f64
});
// NOTE(review): `fast_abt` presumably computes c * frame^T, giving an (m, p)
// decoder — confirm against the helper's definition.
let decoder = fast_abt(&c, &frame);
let mut atom = SaeManifoldAtom::new(
"factored_row_native",
SaeAtomBasisKind::EuclideanPatch,
1,
phi,
jet,
decoder,
Array2::<f64>::eye(m),
)
.unwrap();
atom.maybe_activate_decoder_frame()
.expect("frame activation")
.expect("rank-2 atom should activate a frame");
atoms.push(atom);
coord_blocks.push(coords);
}
let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
Array2::<f64>::from_shape_fn((n_obs, k_atoms), |(row, atom)| {
0.15 * (row + 1) as f64 - 0.07 * atom as f64
}),
coord_blocks,
vec![LatentManifold::Euclidean, LatentManifold::Euclidean],
AssignmentMode::softmax(0.9),
)
.unwrap();
let mut factored_term = SaeManifoldTerm::new(atoms, assignment).unwrap();
assert!(factored_term.frames_active());
let border_dim = factored_term.factored_border_dim();
// The factored border must be strictly smaller than the full β dimension.
assert!(border_dim < factored_term.beta_dim());
// Full-β reference: a clone of the term with every decoder frame deactivated.
let mut full_term = factored_term.clone();
for atom in &mut full_term.atoms {
atom.deactivate_decoder_frame();
}
let rho = SaeManifoldRho::new(
0.0,
-0.2,
vec![Array1::<f64>::zeros(1), Array1::<f64>::zeros(1)],
);
let target = Array2::<f64>::from_shape_fn((n_obs, p), |(row, col)| {
0.01 * (row + 1) as f64 - 0.002 * (col + 1) as f64
});
let native_sys = factored_term
.assemble_arrow_schur(target.view(), &rho, None)
.unwrap();
assert_eq!(native_sys.k, border_dim);
for row in &native_sys.rows {
assert_eq!(row.htbeta.ncols(), border_dim);
}
let full_sys = full_term
.assemble_arrow_schur(target.view(), &rho, None)
.unwrap();
// Start the projected system from a second factored assembly, then replace
// its cross blocks with projected full-β blocks.
let mut projected_sys = factored_term
.assemble_arrow_schur(target.view(), &rho, None)
.unwrap();
// Clear the operator-based cross block so only the dense `htbeta` rows set
// below are present on this system.
projected_sys.htbeta_matvec = None;
projected_sys.htbeta_transpose_matvec = None;
projected_sys.htbeta_dense_supplement = false;
for row_idx in 0..n_obs {
let htbeta_b = materialize_row_htbeta_for_test(&full_sys, row_idx);
projected_sys.rows[row_idx].htbeta =
project_row_htbeta_to_factored_for_test(&factored_term, htbeta_b.view());
}
// The rows were edited in place, so the fingerprint is refreshed before solving.
projected_sys.refresh_row_hessian_fingerprint();
let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
let (native_dt, native_db, _) =
solve_arrow_newton_step_with_options(&native_sys, 1.0e-8, 1.0e-8, &options).unwrap();
let (projected_dt, projected_db, _) =
solve_arrow_newton_step_with_options(&projected_sys, 1.0e-8, 1.0e-8, &options).unwrap();
assert_eq!(native_dt.len(), projected_dt.len());
assert_eq!(native_db.len(), projected_db.len());
for idx in 0..native_dt.len() {
assert_abs_diff_eq!(native_dt[idx], projected_dt[idx], epsilon = 1.0e-10);
}
for idx in 0..native_db.len() {
assert_abs_diff_eq!(native_db[idx], projected_db[idx], epsilon = 1.0e-10);
}
}
#[test]
fn factored_evidence_matches_full_b_at_small_p() {
    // With a full-rank decoder and a tiny output width the atom must decline
    // frame activation, so every factored quantity has to coincide with the
    // full-B one and the Occam term reduces to a closed form.
    let m = 5usize;
    let p = 2usize;
    let decoder = Array2::<f64>::from_shape_fn((m, p), |(mu, col)| {
        if col == 0 {
            1.0 + mu as f64
        } else {
            (mu as f64) - 2.0
        }
    });
    // Identity basis, and a first jet that is the identity along its single axis.
    let phi = Array2::<f64>::eye(m);
    let jet =
        Array3::<f64>::from_shape_fn((m, m, 1), |(i, j, _)| if i == j { 1.0 } else { 0.0 });
    let smooth_raw = crate::basis::create_difference_penalty_matrix(m, 2, None).unwrap();
    let mut atom = SaeManifoldAtom::new(
        "fullrank",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        decoder,
        smooth_raw,
    )
    .unwrap();

    // Atom-level checks: no frame, border rank equals p, no manifold dimension.
    let activated = atom.maybe_activate_decoder_frame().expect("activate");
    assert_eq!(
        activated, None,
        "full-rank small-p must stay on full-B path"
    );
    assert!(atom.decoder_frame.is_none());
    assert_eq!(atom.border_frame_rank(), p);
    assert_eq!(atom.frame_manifold_dimension(), 0);

    // Term-level checks on a single-atom softmax assignment.
    let assignment = SaeAssignment::from_blocks_with_mode(
        Array2::<f64>::zeros((m, 1)),
        vec![Array2::<f64>::zeros((m, 1))],
        AssignmentMode::softmax(0.7),
    )
    .unwrap();
    let mut term = SaeManifoldTerm::new(vec![atom], assignment).unwrap();
    assert!(!term.frames_active());
    assert_eq!(term.factored_border_dim(), term.beta_dim());
    assert_eq!(term.grassmann_evidence_dimension(), 0);
    let activated_n = term.auto_activate_decoder_frames().expect("auto");
    assert_eq!(activated_n, 0, "small-p auto-activation must be a no-op");

    // Occam term: 0.5 * p * rank(S) * log_lambda_smooth.
    let rho = SaeManifoldRho::new(0.0, 0.37, vec![array![0.0_f64]]);
    let occam = term.reml_occam_term(&rho).expect("occam");
    let penalty_rank = SaeManifoldTerm::symmetric_rank(&term.atoms[0].smooth_penalty).unwrap();
    let half_dof = 0.5 * (p as f64) * (penalty_rank as f64);
    let expected = half_dof * rho.log_lambda_smooth;
    assert_abs_diff_eq!(occam, expected, epsilon = 1.0e-12);
}
#[test]
fn streaming_polar_refresh_reorients_frame() {
    // Start from a rank-2 decoder whose frame spans output axes {0, 1}, then
    // refresh from a cross moment supported on axes {2, 3}. The refreshed
    // frame must stay orthonormal and span the new axes.
    let m = 4usize;
    let p = 8usize;
    let r = 2usize;
    let frame0 = Array2::<f64>::from_shape_fn((p, r), |(i, j)| if i == j { 1.0 } else { 0.0 });
    let c0 = Array2::<f64>::from_shape_fn((m, r), |(mu, col)| {
        if col == 0 {
            1.0 + mu as f64
        } else {
            0.5 - mu as f64
        }
    });
    let decoder = fast_abt(&c0, &frame0);
    let phi = Array2::<f64>::eye(m);
    let jet =
        Array3::<f64>::from_shape_fn((m, m, 1), |(i, j, _)| if i == j { 1.0 } else { 0.0 });
    let smooth_raw = crate::basis::create_difference_penalty_matrix(m, 2, None).unwrap();
    let mut atom = SaeManifoldAtom::new(
        "stream",
        SaeAtomBasisKind::EuclideanPatch,
        1,
        phi,
        jet,
        decoder,
        smooth_raw,
    )
    .unwrap();
    atom.maybe_activate_decoder_frame().expect("activate");

    // Cross moment with unequal weights on output axes 2 and 3.
    let mut cross = Array2::<f64>::zeros((p, r));
    cross[[2, 0]] = 3.0;
    cross[[3, 1]] = 2.0;
    atom.refresh_frame_from_cross_moment(cross.view())
        .expect("refresh");

    // Orthonormality: the Gram matrix of the frame equals the identity.
    let frame = atom.decoder_frame.as_ref().expect("frame");
    let basis = frame.frame().to_owned();
    let gram = fast_atb(&basis, &basis);
    for i in 0..r {
        for j in 0..r {
            let expect = if i == j { 1.0 } else { 0.0 };
            assert_abs_diff_eq!(gram[[i, j]], expect, epsilon = 1.0e-9);
        }
    }

    // Span: zero principal angle against the coordinate plane of axes {2, 3}.
    let target_span =
        Array2::<f64>::from_shape_fn((p, r), |(i, j)| if i == j + 2 { 1.0 } else { 0.0 });
    let angle = frame
        .max_principal_angle(target_span.view())
        .expect("angle");
    assert_abs_diff_eq!(angle, 0.0, epsilon = 1.0e-9);
}
/// Builds a small two-atom periodic fixture (10 rows, 3 outputs, 3 basis
/// functions per atom) used by the finite-difference tests.
///
/// The target is generated from the same decoder weights, coordinates and
/// softmax assignments (temperature argument 0.9) that the returned term is
/// initialised with. Every log-lambda in the returned rho is `-6.0`.
fn gamma_fd_tiny_fixture() -> (SaeManifoldTerm, Array2<f64>, SaeManifoldRho) {
    let n = 10usize;
    let p = 3usize;
    let k_atoms = 2usize;
    let m = 3usize;
    let evaluator = Arc::new(PeriodicHarmonicEvaluator::new(m).unwrap());
    // weights[atom][basis_col][out_col]
    let weights = [
        [
            [0.10, -0.05, 0.03],
            [0.35, -0.20, 0.12],
            [-0.16, 0.18, 0.08],
        ],
        [
            [-0.08, 0.04, 0.06],
            [0.22, 0.10, -0.18],
            [0.11, -0.24, 0.15],
        ],
    ];

    // Atom 0 sits at `phase`, atom 1 at the same phase shifted by 0.21 and wrapped.
    let phase_of = |row: usize| (row as f64 + 0.35) / n as f64;
    let coords = vec![
        Array2::<f64>::from_shape_fn((n, 1), |(row, _)| phase_of(row)),
        Array2::<f64>::from_shape_fn((n, 1), |(row, _)| (phase_of(row) + 0.21).fract()),
    ];
    // Only the first logit column alternates; the second stays at zero.
    let logits = Array2::<f64>::from_shape_fn((n, k_atoms), |(row, atom)| {
        if atom != 0 {
            0.0
        } else if row % 2 == 0 {
            0.8
        } else {
            -0.6
        }
    });

    let mut target = Array2::<f64>::zeros((n, p));
    for row in 0..n {
        let assignments = softmax_row(logits.row(row), 0.9);
        for atom in 0..k_atoms {
            let theta = std::f64::consts::TAU * coords[atom][[row, 0]];
            let harmonics = [1.0, theta.sin(), theta.cos()];
            for out_col in 0..p {
                for basis_col in 0..m {
                    target[[row, out_col]] += assignments[atom]
                        * harmonics[basis_col]
                        * weights[atom][basis_col][out_col];
                }
            }
        }
    }

    let atoms: Vec<_> = (0..k_atoms)
        .map(|atom| {
            let (phi, jet) = evaluator.evaluate(coords[atom].view()).unwrap();
            let decoder = Array2::from_shape_fn((m, p), |(basis_col, out_col)| {
                weights[atom][basis_col][out_col]
            });
            SaeManifoldAtom::new(
                format!("gamma_{atom}"),
                SaeAtomBasisKind::Periodic,
                1,
                phi,
                jet,
                decoder,
                Array2::<f64>::eye(m),
            )
            .unwrap()
            .with_basis_second_jet(evaluator.clone())
        })
        .collect();

    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        logits,
        coords,
        vec![LatentManifold::Circle { period: 1.0 }; k_atoms],
        AssignmentMode::softmax(0.9),
    )
    .unwrap();
    let term = SaeManifoldTerm::new(atoms, assignment).unwrap();
    let per_atom_log_lambda = vec![Array1::from_vec(vec![-6.0]), Array1::from_vec(vec![-6.0])];
    let rho = SaeManifoldRho::new(-6.0, -6.0, per_atom_log_lambda);
    (term, target, rho)
}
/// Evaluates the REML criterion cache for `term` with an iteration argument of
/// `0` and returns the sum of the two log-determinant pieces reported by
/// `cache.arrow_log_det()`.
///
/// Panics if the cache cannot be built or if the second (optional) piece is
/// absent.
fn fixed_state_logdet(
    mut term: SaeManifoldTerm,
    target: &Array2<f64>,
    rho: &SaeManifoldRho,
) -> f64 {
    let (_, _, cache) = term
        .reml_criterion_with_cache(target.view(), rho, None, 0, 0.4, 1.0e-6, 1.0e-6)
        .expect("fixed-state cache");
    let (logdet_tt, logdet_beta) = cache.arrow_log_det();
    let logdet_beta = logdet_beta.expect("dense Schur logdet");
    logdet_tt + logdet_beta
}
#[test]
fn sae_logdet_theta_adjoint_matches_dense_fd_on_tiny_fixture() {
    // Compares entries of the analytic log-det adjoint with a central finite
    // difference of `fixed_state_logdet` under a +/- h perturbation of one
    // logit and one coordinate.
    let (mut term, target, rho) = gamma_fd_tiny_fixture();
    let (_, _, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 5, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache");
    let gamma = term.logdet_theta_adjoint(&rho, &cache).expect("Gamma");
    let h = 1.0e-5;
    // (row, position inside the row's local block, perturbed variable)
    let probes = [
        (0usize, 0usize, SaeLocalRowVar::Logit { atom: 0 }),
        (3usize, 1usize, SaeLocalRowVar::Coord { atom: 0, axis: 0 }),
    ];
    for (row, local_pos, var) in probes {
        let mut plus = term.clone();
        let mut minus = term.clone();
        // Apply the same perturbation to both clones, with opposite signs.
        for (shifted, delta) in [(&mut plus, h), (&mut minus, -h)] {
            match var {
                SaeLocalRowVar::Logit { atom } => {
                    shifted.assignment.logits[[row, atom]] += delta;
                }
                SaeLocalRowVar::Coord { atom, axis } => {
                    let block = &mut shifted.assignment.coords[atom];
                    let mut flat = block.as_flat().clone();
                    let idx = row * block.latent_dim() + axis;
                    flat[idx] += delta;
                    block.set_flat(flat.view());
                }
            }
        }
        let logdet_plus = fixed_state_logdet(plus, &target, &rho);
        let logdet_minus = fixed_state_logdet(minus, &target, &rho);
        let fd = (logdet_plus - logdet_minus) / (2.0 * h);
        let analytic = gamma.t[cache.row_offsets[row] + local_pos];
        // Mixed absolute/relative tolerance.
        let tol = 2.0e-3 * (1.0 + fd.abs().max(analytic.abs()));
        assert!(
            (fd - analytic).abs() <= tol,
            "Gamma row={row} local_pos={local_pos}: fd={fd:.8e}, analytic={analytic:.8e}"
        );
    }
}
#[test]
fn sae_logdet_theta_adjoint_matches_dense_fd_ibp_map() {
    // Same finite-difference comparison as the softmax variant, but with the
    // assignment switched to IBP-MAP mode and the sparse log-lambda set to -1.
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.7, 0.9, false);
    rho.log_lambda_sparse = -1.0;
    let (_, _, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 5, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache");
    let gamma = term.logdet_theta_adjoint(&rho, &cache).expect("Gamma");
    let h = 1.0e-5;
    // (row, position inside the row's local block, perturbed variable)
    let probes = [
        (0usize, 0usize, SaeLocalRowVar::Logit { atom: 0 }),
        (4usize, 1usize, SaeLocalRowVar::Logit { atom: 1 }),
        (7usize, 0usize, SaeLocalRowVar::Logit { atom: 0 }),
    ];
    for (row, local_pos, var) in probes {
        let mut plus = term.clone();
        let mut minus = term.clone();
        // Apply the same perturbation to both clones, with opposite signs.
        for (shifted, delta) in [(&mut plus, h), (&mut minus, -h)] {
            match var {
                SaeLocalRowVar::Logit { atom } => {
                    shifted.assignment.logits[[row, atom]] += delta;
                }
                SaeLocalRowVar::Coord { atom, axis } => {
                    let block = &mut shifted.assignment.coords[atom];
                    let mut flat = block.as_flat().clone();
                    let idx = row * block.latent_dim() + axis;
                    flat[idx] += delta;
                    block.set_flat(flat.view());
                }
            }
        }
        let logdet_plus = fixed_state_logdet(plus, &target, &rho);
        let logdet_minus = fixed_state_logdet(minus, &target, &rho);
        let fd = (logdet_plus - logdet_minus) / (2.0 * h);
        let analytic = gamma.t[cache.row_offsets[row] + local_pos];
        // Mixed absolute/relative tolerance.
        let tol = 3.0e-3 * (1.0 + fd.abs().max(analytic.abs()));
        assert!(
            (fd - analytic).abs() <= tol,
            "IBP Gamma row={row} local_pos={local_pos}: fd={fd:.8e}, analytic={analytic:.8e}"
        );
    }
}
#[test]
fn ibp_map_outer_objective_advertises_analytic_gradient() {
    // An outer objective built on an IBP-MAP assignment must report an
    // analytic gradient in its capability descriptor.
    let (mut term, target, rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.9, 1.0, false);
    let objective =
        SaeManifoldOuterObjective::new(term, target, None, rho, 5, 0.4, 1.0e-6, 1.0e-6);
    let capability = objective.capability();
    assert_eq!(capability.gradient, Derivative::Analytic);
}
}
/// Seeds per-atom latent coordinates from an SVD of the column-centred data `z`.
///
/// The result has shape `(basis_kinds.len(), z.nrows(), d_max)`, where `d_max`
/// is the largest entry of `atom_dim` (at least 1). Every entry that none of
/// the rules below writes is left at `0.0`. If `z` has no rows or no columns
/// the all-zero array is returned immediately.
///
/// For atom `a` with `d = atom_dim[a]` (atoms with `d == 0` are skipped), the
/// rows of `Vt` are referred to as principal components (PCs):
///
/// * `Periodic` — axis 0 is `atan2(PC_{2j+1}, PC_{2j}) / TAU` wrapped by
///   subtracting its floor, with `j = a % (n_pc / 2)`; it is written only when
///   at least two PCs exist. Each further axis `k < min(d, n_pc)` is the
///   projection onto PC `k`, min–max rescaled to `[-0.5, 0.5]`.
/// * `Sphere` — the projections onto the first (up to three) PCs are
///   normalised to unit length (`(1, 0, 0)` when the norm is zero); axis 0
///   receives `asin` of the third component and, when `d >= 2`, axis 1
///   receives `atan2(second, first)`.
/// * `Torus` — axis `k` is the wrapped angle of the PC pair `(2k, 2k + 1)`;
///   filling stops at the first axis whose pair is not available.
/// * any other kind — axis `c < min(d, U.ncols(), S.len())` is column `c` of
///   `U * diag(S)`, min–max rescaled to `[-0.5, 0.5]`.
///
/// An axis whose values have zero span is left at `0.0` by the rescaling.
///
/// # Errors
///
/// Returns `Err` when `atom_dim` has fewer entries than `basis_kinds`, when
/// `z` or its centred copy contains a non-finite value, or when the SVD fails
/// or does not return `U` / `Vt`.
pub fn sae_pca_seed_initial_coords(
    z: ArrayView2<'_, f64>,
    basis_kinds: &[SaeAtomBasisKind],
    atom_dim: &[usize],
) -> Result<Array3<f64>, String> {
    /// Writes `values` (one per observation) to `out[atom_idx, .., axis]`
    /// after an affine min–max map onto `[-0.5, 0.5]`; a zero span writes nothing.
    fn write_rescaled_axis(out: &mut Array3<f64>, atom_idx: usize, axis: usize, values: &[f64]) {
        let (lo, hi) = values
            .iter()
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
                (lo.min(v), hi.max(v))
            });
        let span = hi - lo;
        if span > 0.0 {
            for (row, &v) in values.iter().enumerate() {
                out[[atom_idx, row, axis]] = (v - lo) / span - 0.5;
            }
        }
    }

    let k_atoms = basis_kinds.len();
    let (n_obs, p_out) = z.dim();
    let d_max = atom_dim.iter().copied().max().unwrap_or(1).max(1);
    let mut out = Array3::<f64>::zeros((k_atoms, n_obs, d_max));
    if n_obs == 0 || p_out == 0 {
        return Ok(out);
    }
    // Every atom needs a dimension. Report a mismatch as an error instead of
    // letting the per-atom lookup below panic on an out-of-bounds index.
    if atom_dim.len() < k_atoms {
        return Err(format!(
            "sae_pca_seed: atom_dim has {} entries but basis_kinds has {}; \
             one dimension per atom is required",
            atom_dim.len(),
            k_atoms
        ));
    }
    if let Some(((row, col), value)) = z.indexed_iter().find(|(_, v)| !v.is_finite()) {
        return Err(format!(
            "sae_pca_seed: Z must be finite; Z[{row}, {col}] = {value}"
        ));
    }

    // Running (incremental) column means, then centre a private copy of Z.
    let mut col_means = Array1::<f64>::zeros(p_out);
    for col in 0..p_out {
        let mut mean = 0.0_f64;
        for row in 0..n_obs {
            mean += (z[[row, col]] - mean) / (row as f64 + 1.0);
        }
        col_means[col] = mean;
    }
    let mut centered = z.to_owned();
    for ((_, col), value) in centered.indexed_iter_mut() {
        *value -= col_means[col];
    }
    // Subtracting the mean can overflow when the data span is extreme.
    if let Some(((row, col), _)) = centered.indexed_iter().find(|(_, v)| !v.is_finite()) {
        return Err(format!(
            "sae_pca_seed: centered Z is non-finite at [{row}, {col}] \
(data span exceeds f64 range); rescale Z before seeding"
        ));
    }

    let (u_opt, s_vals, vt_opt) = centered
        .svd(true, true)
        .map_err(|err| format!("sae_pca_seed: SVD failed: {err:?}"))?;
    let u = u_opt.ok_or_else(|| "sae_pca_seed: SVD returned no U".to_string())?;
    let vt = vt_opt.ok_or_else(|| "sae_pca_seed: SVD returned no Vt".to_string())?;
    let vt_rows = vt.nrows();
    let u_cols = u.ncols();
    let two_pi = std::f64::consts::TAU;

    // Score of observation `row` on principal component `pc_idx`, accumulated
    // left to right over the columns.
    let project = |row: usize, pc_idx: usize| -> f64 {
        let mut acc = 0.0_f64;
        for (&x, &w) in centered.row(row).iter().zip(vt.row(pc_idx).iter()) {
            acc += x * w;
        }
        acc
    };
    // Angle of the (pc_a, pc_b) score pair, expressed as a fraction of a full
    // turn and wrapped by subtracting its floor.
    let wrapped_angle = |row: usize, pc_a: usize, pc_b: usize| -> f64 {
        let a = project(row, pc_a);
        let b = project(row, pc_b);
        let phase = b.atan2(a) / two_pi;
        phase - phase.floor()
    };

    for (atom_idx, (kind, &d)) in basis_kinds.iter().zip(atom_dim).enumerate() {
        if d == 0 {
            continue;
        }
        match kind {
            SaeAtomBasisKind::Periodic => {
                if vt_rows >= 2 {
                    // Cycle atoms through the available PC pairs; with
                    // vt_rows >= 2 there is at least one pair and both
                    // indices are in range.
                    let pair = atom_idx % (vt_rows / 2);
                    let (pc_a, pc_b) = (2 * pair, 2 * pair + 1);
                    for row in 0..n_obs {
                        out[[atom_idx, row, 0]] = wrapped_angle(row, pc_a, pc_b);
                    }
                }
                for axis in 1..d.min(vt_rows) {
                    let scores: Vec<f64> = (0..n_obs).map(|row| project(row, axis)).collect();
                    write_rescaled_axis(&mut out, atom_idx, axis, &scores);
                }
            }
            SaeAtomBasisKind::Sphere => {
                let n_pc = vt_rows.min(3);
                if n_pc == 0 {
                    continue;
                }
                for row in 0..n_obs {
                    // Ambient 3-vector; components without a PC stay zero.
                    let mut amb = [0.0_f64; 3];
                    for (pc_idx, slot) in amb.iter_mut().enumerate().take(n_pc) {
                        *slot = project(row, pc_idx);
                    }
                    let norm = (amb[0] * amb[0] + amb[1] * amb[1] + amb[2] * amb[2]).sqrt();
                    let [ux, uy, uz] = if norm > 0.0 {
                        [amb[0] / norm, amb[1] / norm, amb[2] / norm]
                    } else {
                        [1.0, 0.0, 0.0]
                    };
                    // d >= 1 is guaranteed here because d == 0 was skipped above.
                    out[[atom_idx, row, 0]] = uz.clamp(-1.0, 1.0).asin();
                    if d >= 2 {
                        out[[atom_idx, row, 1]] = uy.atan2(ux);
                    }
                }
            }
            SaeAtomBasisKind::Torus => {
                for axis in 0..d {
                    let (pc_a, pc_b) = (2 * axis, 2 * axis + 1);
                    if pc_b >= vt_rows {
                        break;
                    }
                    for row in 0..n_obs {
                        out[[atom_idx, row, axis]] = wrapped_angle(row, pc_a, pc_b);
                    }
                }
            }
            _ => {
                // Axes beyond `k_cols` have no score column and stay zero.
                let k_cols = d.min(u_cols).min(s_vals.len());
                for col in 0..k_cols {
                    let scale = s_vals[col];
                    let scores: Vec<f64> =
                        (0..n_obs).map(|row| u[[row, col]] * scale).collect();
                    write_rescaled_axis(&mut out, atom_idx, col, &scores);
                }
            }
        }
    }
    Ok(out)
}