use ndarray::{Array1, Array2, Array3, Array4, ArrayView1, ArrayView2, ArrayView3, ArrayView4, s};
use std::sync::Arc;
// Width in bytes of one `f64`; the streaming-plan estimates convert word counts to bytes with it.
const SAE_BYTES_PER_F64: usize = 8;
// 2 GiB. Substituted for the host's available memory when the probe reports 0, and also
// used as the lower bound of the host in-core budget.
const SAE_HOST_IN_CORE_FALLBACK_BYTES: usize = 2 * 1024 * 1024 * 1024;
// The host in-core budget is NUMERATOR / DENOMINATOR (3/5) of the available host memory.
const SAE_HOST_MEMORY_BUDGET_FRACTION_NUMERATOR: usize = 3;
const SAE_HOST_MEMORY_BUDGET_FRACTION_DENOMINATOR: usize = 5;
// 1 MiB. Multiplied by `SAE_CHUNK_CACHE_MULTIPLE` to size the streaming chunk window
// (the whole window on the CPU path, the minimum window on the GPU path).
const SAE_CPU_L2_CACHE_BYTES: usize = 1024 * 1024;
const SAE_CHUNK_CACHE_MULTIPLE: usize = 8;
// Lower bound on the number of rows placed in one streamed chunk.
const SAE_MIN_STREAMING_CHUNK_ROWS: usize = 256;
// Number of border-length `f64` work vectors charged to the matrix-free peak estimate.
const SAE_MATRIX_FREE_VECTOR_WORKSPACE_MULTIPLIER: usize = 32;
use crate::solver::arrow_schur::{
ArrowProximalCorrectionOptions, ArrowRowBlock, ArrowSchurError, ArrowSchurSystem,
ArrowSolveOptions, BetaPenaltyOp, CompositePenaltyOp, DensePenaltyOp, DeviceSaePcgData,
DeviceSaeSmoothBlock, FactoredFrameGBlock, FactoredFrameKroneckerOp, IbpCrossRowSource,
IdentityRightKroneckerPenaltyOp, SparseBlockKroneckerPenaltyOp, SparseGBlock,
StreamingArrowSchur, solve_arrow_newton_step_with_proximal_correction,
solve_streaming_reduced_beta, solve_with_lm_escalation_inner,
};
use crate::terms::analytic_penalties::{
AnalyticPenalty, AnalyticPenaltyKind, AnalyticPenaltyRegistry, DecoderIncoherencePenalty,
IbpHessianDiagThirdChannels, IsometryPenalty, MechanismSparsityPenalty, NuclearNormPenalty,
PenaltyTier, PsiSlice, WeightField, resolve_learnable_weight,
};
use crate::terms::latent_coord::{LatentCoordValues, LatentIdMode, LatentManifold};
use crate::terms::sae_criterion_atoms::SaeCriterion;
use crate::terms::sae_optimality_certificate::{
CriterionCertificate, DirectionalSamples, certificate_from_samples,
deterministic_probe_direction, probe_step,
};
use crate::linalg::faer_ndarray::{
FaerCholesky, FaerCholeskyFactor, FaerEigh, FaerSvd, fast_ab, fast_abt, fast_atb,
};
use crate::linalg::triangular::cholesky_solve_vector;
use crate::solver::arrow_schur::{
ArrowFactorCache, ArrowRowGaugeDeflation, arrow_factor_max_pivot, arrow_factor_min_pivot,
solve_arrow_newton_step_with_options,
};
use crate::solver::estimate::EstimationError;
use crate::solver::evidence::arrow_log_det_from_cache;
use crate::solver::outer_strategy::{
DeclaredHessianForm, Derivative, EfsEval, HessianResult, OuterCapability, OuterEval,
OuterEvalOrder, OuterObjective, SeedOutcome,
};
use crate::solver::structure_search::{CollapseAction, CollapseEvent};
use faer::Side;
// Numerical tolerances and iteration caps for the manifold / outer solvers.
// NOTE(review): apart from `SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF`, none of these is referenced
// in the part of the file reviewed here; the groupings below are inferred from the names
// only and should be confirmed against the call sites.

// Line-search parameters (presumably Armijo sufficient-decrease constant and halving cap).
const SAE_MANIFOLD_ARMIJO_C1: f64 = 1.0e-4;
const SAE_MANIFOLD_MAX_LINESEARCH_HALVINGS: usize = 12;
// Outer-gradient conditioning thresholds.
const SAE_OUTER_GRADIENT_PIVOT_RATIO_FLOOR: f64 = 1.0e-12;
const SAE_OUTER_GRADIENT_GAUGE_RAYLEIGH_FACTOR: f64 = 1.0e-8;
const SAE_DECODER_BETA_NULL_RELATIVE_FLOOR: f64 = 1.0e-9;
const SAE_OUTER_GRADIENT_BETA_NULL_PROBE_MAX_DIM: usize = 512;
// Curvature-walk step control.
const CURVATURE_WALK_INITIAL_ETA_STEP: f64 = 0.2;
const CURVATURE_WALK_MIN_ETA_STEP: f64 = 1.0 / 256.0;
const CURVATURE_WALK_MAX_CORRECTORS: usize = 32;
const SAE_MANIFOLD_DIRECTIONAL_DECREASE_REL_FLOOR: f64 = 1.0e-14;
const SAE_LOSS_PARALLEL_ROW_MIN: usize = 64;
// Inner-loop convergence / stall tolerances.
const SAE_MANIFOLD_INNER_STEP_REL_TOL: f64 = 1.0e-4;
const SAE_MANIFOLD_INNER_GRAD_REL_TOL: f64 = 1.0e-5;
const SAE_MANIFOLD_INNER_OBJECTIVE_STALL_REL_TOL: f64 = 1.0e-8;
const SAE_MANIFOLD_INNER_OBJECTIVE_STALL_FRACTION: f64 = 1.0e-4;
const SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS: usize = 3;
const SAE_DENSE_BETA_PENALTY_PROBE_MAX_DIM: usize = 4096;
// Relative eigenvalue cutoff: `smooth_penalty_nullity` counts eigenvalues at or below
// this fraction of the largest eigenvalue as null directions.
const SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF: f64 = 1.0e-9;
// Row-ridge escalation schedule (floor, growth factor, attempt cap).
const SAE_MANIFOLD_ROW_RIDGE_FLOOR: f64 = 1.0e-12;
const SAE_MANIFOLD_ROW_RIDGE_GROWTH: f64 = 10.0;
const SAE_MANIFOLD_ROW_RIDGE_MAX_ATTEMPTS: usize = 12;
/// Bookkeeping for how beta-penalty curvature has been recorded so far.
///
/// Both flags start out `false` (via `Default`) and are only ever switched on.
#[derive(Clone, Copy, Debug, Default)]
struct SaeBetaPenaltyAssembly {
    /// Set once curvature has been recorded with `dense_beta_curvature == true`.
    dense_written: bool,
    /// Set once curvature has been recorded with `dense_beta_curvature == false`.
    deferred_factored: bool,
}

impl SaeBetaPenaltyAssembly {
    /// Latches the flag matching the kind of curvature contribution just recorded.
    fn record_curvature(&mut self, dense_beta_curvature: bool) {
        let flag = if dense_beta_curvature {
            &mut self.dense_written
        } else {
            &mut self.deferred_factored
        };
        *flag = true;
    }
}
// Fit-quality guard rails.
// NOTE(review): none of these constants is referenced in the part of the file reviewed
// here; "EV" presumably stands for explained variance — confirm at the call sites.
const SAE_FIT_DATA_COLLAPSE_EV_FLOOR: f64 = 0.10;
const SAE_FIT_DATA_COLLAPSE_COST: f64 = 1.0e12;
const SAE_PRISTINE_SEED_EV_RETAIN_FLOOR: f64 = 0.95;
const SAE_FINAL_EV_DEGRADATION_TOL: f64 = 1.0e-3;
const SAE_SEED_DISPERSION_FLOOR: f64 = 1.0e-12;
/// Width of the band below the threshold, in units of `temperature`.
const JUMPRELU_REACTIVATION_MARGIN: f64 = 4.0;

/// Returns `true` when `logit` lies strictly above
/// `threshold - JUMPRELU_REACTIVATION_MARGIN * temperature`.
///
/// Any NaN input makes the comparison fail, so the result is `false`.
#[inline]
fn jumprelu_in_optimization_band(logit: f64, threshold: f64, temperature: f64) -> bool {
    let band_floor = threshold - JUMPRELU_REACTIVATION_MARGIN * temperature;
    band_floor < logit
}
/// Memory plan for one SAE fit: byte estimates of the working sets and the resulting
/// decision between the in-core (direct) path and the streamed path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SaeStreamingPlan {
/// `true` when the direct path was not admitted.
pub streaming: bool,
/// Rows per chunk: all rows (at least 1) on the direct path, otherwise the chunk-window
/// row count clamped to `1..=n_obs`.
pub chunk_size: usize,
/// Estimated bytes for the full batch of per-row data.
pub estimated_full_batch_bytes: usize,
/// Estimated bytes of a dense `border_dim x border_dim` `f64` matrix.
pub estimated_dense_schur_bytes: usize,
/// Estimated bytes of the per-row cross blocks over all rows.
pub estimated_row_cross_bytes: usize,
/// Sum of the full-batch, row-cross and dense-Schur estimates.
pub estimated_direct_peak_bytes: usize,
/// Chunk window (capped by the full batch) plus row-cross blocks plus vector workspace.
pub estimated_matrix_free_peak_bytes: usize,
/// Budget the two peak estimates are compared against.
pub in_core_budget_bytes: usize,
/// Host available memory as reported when the plan was built (recorded, not compared).
pub host_available_bytes: usize,
/// `estimated_direct_peak_bytes <= in_core_budget_bytes`.
pub direct_admitted: bool,
/// `estimated_matrix_free_peak_bytes <= in_core_budget_bytes`.
pub matrix_free_admitted: bool,
}
/// Builds a [`SaeStreamingPlan`] from the problem shape and an explicit memory budget.
///
/// All size arithmetic saturates, so an absurdly large shape yields `usize::MAX`
/// estimates (and therefore a rejected plan) instead of wrapping.
///
/// * `n_obs`, `total_basis`, `k_atoms`, `d_max`, `border_dim` — problem shape.
/// * `in_core_budget_bytes` — budget both peak estimates are compared against.
/// * `chunk_window_bytes` — target byte size of one streamed chunk.
/// * `host_available_bytes` — copied into the plan unchanged.
fn sae_streaming_plan_from_budget(
    n_obs: usize,
    total_basis: usize,
    k_atoms: usize,
    d_max: usize,
    border_dim: usize,
    in_core_budget_bytes: usize,
    chunk_window_bytes: usize,
    host_available_bytes: usize,
) -> SaeStreamingPlan {
    let bytes_of = |words: usize| words.saturating_mul(SAE_BYTES_PER_F64);

    // One row carries `total_basis * (1 + d_max) + k_atoms` words, never fewer than one.
    let per_row_bytes = bytes_of(
        total_basis
            .saturating_mul(1 + d_max)
            .saturating_add(k_atoms)
            .max(1),
    );
    let estimated_full_batch_bytes = n_obs.saturating_mul(per_row_bytes);
    let estimated_dense_schur_bytes = bytes_of(border_dim.saturating_mul(border_dim));

    let row_block_dim = k_atoms.saturating_mul(1usize.saturating_add(d_max));
    let estimated_row_cross_bytes = bytes_of(
        n_obs
            .saturating_mul(row_block_dim)
            .saturating_mul(border_dim),
    );

    let estimated_direct_peak_bytes = estimated_full_batch_bytes
        .saturating_add(estimated_row_cross_bytes)
        .saturating_add(estimated_dense_schur_bytes);

    // The streamed window never needs more than the whole batch (but at least one row).
    let resident_window = chunk_window_bytes.min(estimated_full_batch_bytes.max(per_row_bytes));
    let vector_workspace =
        bytes_of(border_dim).saturating_mul(SAE_MATRIX_FREE_VECTOR_WORKSPACE_MULTIPLIER);
    let estimated_matrix_free_peak_bytes = resident_window
        .saturating_add(estimated_row_cross_bytes)
        .saturating_add(vector_workspace);

    let direct_admitted = estimated_direct_peak_bytes <= in_core_budget_bytes;
    let matrix_free_admitted = estimated_matrix_free_peak_bytes <= in_core_budget_bytes;

    let chunk_size = if direct_admitted {
        n_obs.max(1)
    } else {
        let rows_per_chunk =
            (chunk_window_bytes / per_row_bytes).max(SAE_MIN_STREAMING_CHUNK_ROWS);
        rows_per_chunk.min(n_obs).max(1)
    };

    SaeStreamingPlan {
        streaming: !direct_admitted,
        chunk_size,
        estimated_full_batch_bytes,
        estimated_dense_schur_bytes,
        estimated_row_cross_bytes,
        estimated_direct_peak_bytes,
        estimated_matrix_free_peak_bytes,
        in_core_budget_bytes,
        host_available_bytes,
        direct_admitted,
        matrix_free_admitted,
    }
}
/// Builds the streaming plan for a problem shape using the budgets of the current machine.
///
/// With a global GPU runtime present, the in-core budget is a quarter of the summed
/// per-device budgets (capped by host available memory), and the chunk window is one
/// sixteenth of the mean per-device budget, never below the CPU window. Without a GPU
/// runtime, the host in-core budget and the CPU window are used.
pub fn sae_streaming_plan_for_shape(
    n_obs: usize,
    total_basis: usize,
    k_atoms: usize,
    d_max: usize,
    border_dim: usize,
) -> SaeStreamingPlan {
    let cpu_chunk_window = SAE_CPU_L2_CACHE_BYTES * SAE_CHUNK_CACHE_MULTIPLE;

    let (in_core_budget, chunk_window, host_available) =
        if let Some(runtime) = crate::gpu::runtime::GpuRuntime::global() {
            let aggregate_budget: usize = runtime
                .device_ordinals()
                .iter()
                .map(|&ordinal| runtime.memory_budget_for(ordinal))
                .sum();
            let per_device_budget = aggregate_budget / runtime.device_count().max(1);
            let window = (per_device_budget / 16).max(cpu_chunk_window);
            let host_available = sae_host_available_memory_bytes();
            let budget = (aggregate_budget / 4).min(host_available);
            (budget, window, host_available)
        } else {
            let (budget, host_available) = sae_host_in_core_budget_bytes();
            (budget, cpu_chunk_window, host_available)
        };

    sae_streaming_plan_from_budget(
        n_obs,
        total_basis,
        k_atoms,
        d_max,
        border_dim,
        in_core_budget,
        chunk_window,
        host_available,
    )
}
impl SaeStreamingPlan {
    /// Returns the plan unchanged when at least one solve path fits the budget.
    ///
    /// # Errors
    /// Returns a message naming the matrix-free peak estimate, the budget and the shape
    /// when neither the direct nor the matrix-free path was admitted.
    fn admitted_or_error(self, n: usize, p: usize, k_atoms: usize) -> Result<Self, String> {
        let rejected = !(self.direct_admitted || self.matrix_free_admitted);
        if rejected {
            return Err(format!(
                "SaeManifoldTerm::streaming_plan: predicted working set {} bytes exceeds budget {} bytes; shape n={n},p={p},K={k_atoms}",
                self.estimated_matrix_free_peak_bytes, self.in_core_budget_bytes
            ));
        }
        Ok(self)
    }

    /// Chooses solver options: automatic selection when the direct path is admitted,
    /// inexact PCG otherwise.
    fn solve_options_for_border_dim(self, border_dim: usize) -> ArrowSolveOptions {
        if !self.direct_admitted {
            return ArrowSolveOptions::inexact_pcg();
        }
        ArrowSolveOptions::automatic(border_dim)
    }

    /// Whether the direct log-determinant may be used; mirrors `direct_admitted`.
    fn direct_logdet_admitted(self) -> bool {
        let Self {
            direct_admitted, ..
        } = self;
        direct_admitted
    }
}
/// Returns the host's available memory in bytes, as reported by `sysinfo`.
///
/// When the probe reports 0, `SAE_HOST_IN_CORE_FALLBACK_BYTES` is returned instead so
/// callers never see a zero budget.
fn sae_host_available_memory_bytes() -> usize {
    let mut sys = sysinfo::System::new();
    sys.refresh_memory();
    let reported: u64 = sys.available_memory();
    // Saturate rather than `as usize`: on targets where `usize` is narrower than 64 bits
    // a plain cast truncates, which can wrap a large reading to a tiny value or to 0.
    let available = if reported > usize::MAX as u64 {
        usize::MAX
    } else {
        reported as usize
    };
    if available == 0 {
        SAE_HOST_IN_CORE_FALLBACK_BYTES
    } else {
        available
    }
}
/// Returns `(in_core_budget_bytes, host_available_bytes)` for the CPU-only path.
///
/// The budget is the configured fraction of available host memory, raised to
/// `SAE_HOST_IN_CORE_FALLBACK_BYTES` when the fraction is smaller than that.
// NOTE(review): because of that floor the budget can exceed the reported available
// memory on small hosts — confirm this is intended.
fn sae_host_in_core_budget_bytes() -> (usize, usize) {
    let host_available = sae_host_available_memory_bytes();
    let scaled = host_available.saturating_mul(SAE_HOST_MEMORY_BUDGET_FRACTION_NUMERATOR);
    let budget = (scaled / SAE_HOST_MEMORY_BUDGET_FRACTION_DENOMINATOR)
        .max(SAE_HOST_IN_CORE_FALLBACK_BYTES);
    (budget, host_available)
}
/// Decay law used by [`GumbelTemperatureSchedule`].
#[derive(Debug, Clone)]
pub enum ScheduleKind {
    /// `tau_start * rate^iter`; `rate` must lie strictly between 0 and 1.
    Geometric { rate: f64 },
    /// Linear interpolation from `tau_start` to `tau_min` over `steps` iterations.
    Linear { steps: usize },
    /// `tau_start / (1 + iter)`.
    ReciprocalIter,
}

/// Temperature schedule that decays from `tau_start` and never drops below `tau_min`.
#[derive(Debug, Clone)]
pub struct GumbelTemperatureSchedule {
    /// Temperature at iteration 0.
    pub tau_start: f64,
    /// Floor applied to every returned temperature.
    pub tau_min: f64,
    /// Decay law.
    pub decay: ScheduleKind,
    /// Number of times [`GumbelTemperatureSchedule::step`] has been called.
    pub iter_count: usize,
}

impl GumbelTemperatureSchedule {
    /// Creates a schedule with `iter_count == 0`.
    ///
    /// # Errors
    /// Returns the message produced by [`GumbelTemperatureSchedule::validate`] when the
    /// parameters are rejected.
    #[must_use = "build error must be handled"]
    pub fn new(tau_start: f64, tau_min: f64, decay: ScheduleKind) -> Result<Self, String> {
        let candidate = Self {
            tau_start,
            tau_min,
            decay,
            iter_count: 0,
        };
        candidate.validate().map(|()| candidate)
    }

    /// Checks the schedule parameters.
    ///
    /// # Errors
    /// * `tau_start` or `tau_min` is not finite and positive;
    /// * `tau_min` exceeds `tau_start`;
    /// * a geometric `rate` lies outside `(0, 1)`;
    /// * a linear schedule has `steps == 0`.
    pub fn validate(&self) -> Result<(), String> {
        let finite_positive = |value: f64| value.is_finite() && value > 0.0;

        if !finite_positive(self.tau_start) {
            return Err(format!(
                "GumbelTemperatureSchedule: tau_start must be finite and positive; got {}",
                self.tau_start
            ));
        }
        if !finite_positive(self.tau_min) {
            return Err(format!(
                "GumbelTemperatureSchedule: tau_min must be finite and positive; got {}",
                self.tau_min
            ));
        }
        if self.tau_min > self.tau_start {
            return Err(format!(
                "GumbelTemperatureSchedule: tau_min ({}) cannot exceed tau_start ({})",
                self.tau_min, self.tau_start
            ));
        }

        match self.decay {
            ScheduleKind::Geometric { rate } if !(rate.is_finite() && rate > 0.0 && rate < 1.0) => {
                Err(format!(
                    "GumbelTemperatureSchedule::Geometric: rate must be in (0, 1); got {rate}"
                ))
            }
            ScheduleKind::Linear { steps: 0 } => {
                Err("GumbelTemperatureSchedule::Linear: steps must be positive".into())
            }
            ScheduleKind::Geometric { .. }
            | ScheduleKind::Linear { .. }
            | ScheduleKind::ReciprocalIter => Ok(()),
        }
    }

    /// Temperature at iteration `iter`, floored at `tau_min`. Does not touch `iter_count`.
    pub fn current_tau(&self, iter: usize) -> f64 {
        let elapsed = iter as f64;
        let raw = match self.decay {
            ScheduleKind::Geometric { rate } => self.tau_start * rate.powf(elapsed),
            // Past the end of the ramp the schedule sits on the floor.
            ScheduleKind::Linear { steps } if iter >= steps => self.tau_min,
            ScheduleKind::Linear { steps } => {
                let frac = elapsed / steps as f64;
                self.tau_start + frac * (self.tau_min - self.tau_start)
            }
            ScheduleKind::ReciprocalIter => self.tau_start / (1.0 + elapsed),
        };
        f64::max(raw, self.tau_min)
    }

    /// Returns the temperature for the current `iter_count`, then advances the counter.
    pub fn step(&mut self) -> f64 {
        let served = self.iter_count;
        let tau = self.current_tau(served);
        self.iter_count += 1;
        tau
    }
}
/// How a hyper-parameter is explored: held fixed, or swept over an explicit value list.
#[derive(Debug, Clone, PartialEq)]
pub enum SearchStrategy {
    /// No search.
    Fixed,
    /// Sweep over the stored candidate values.
    ExponentialSweep { values: Vec<f64> },
}

impl SearchStrategy {
    /// `true` only for [`SearchStrategy::Fixed`].
    #[must_use]
    pub fn is_fixed(&self) -> bool {
        match self {
            Self::Fixed => true,
            Self::ExponentialSweep { .. } => false,
        }
    }

    /// The sweep candidates, or `None` for [`SearchStrategy::Fixed`].
    #[must_use]
    pub fn sweep_values(&self) -> Option<&[f64]> {
        if let Self::ExponentialSweep { values } = self {
            Some(values.as_slice())
        } else {
            None
        }
    }
}
/// Basis family of an SAE atom; determines the latent manifold and the seed grid.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SaeAtomBasisKind {
    Duchon,
    Periodic,
    Sphere,
    Torus,
    EuclideanPatch,
    Poincare,
    Precomputed(String),
}

impl SaeAtomBasisKind {
    /// Latent manifold associated with this basis kind.
    ///
    /// `Periodic` and `Torus` map to a unit-period circle (one axis) or a product of
    /// `latent_dim` such circles; `Sphere` maps to an interval `[-pi/2, pi/2]` times a
    /// circle of period `2*pi` regardless of `latent_dim`; every other kind is Euclidean.
    fn latent_manifold(&self, latent_dim: usize) -> LatentManifold {
        fn unit_circle() -> LatentManifold {
            LatentManifold::Circle { period: 1.0 }
        }

        match self {
            Self::Periodic | Self::Torus => match latent_dim {
                1 => unit_circle(),
                _ => LatentManifold::Product((0..latent_dim).map(|_| unit_circle()).collect()),
            },
            Self::Sphere => {
                let latitude = LatentManifold::Interval {
                    lo: -std::f64::consts::FRAC_PI_2,
                    hi: std::f64::consts::FRAC_PI_2,
                };
                let longitude = LatentManifold::Circle {
                    period: std::f64::consts::TAU,
                };
                LatentManifold::Product(vec![latitude, longitude])
            }
            Self::Duchon | Self::EuclideanPatch | Self::Poincare | Self::Precomputed(_) => {
                LatentManifold::Euclidean
            }
        }
    }

    /// Seed grid of latent coordinates for projection, when the kind has one.
    ///
    /// Only `Periodic`/`Torus` (any supported dimension) and a two-dimensional `Sphere`
    /// provide a grid; all other cases return `None`.
    fn projection_seed_grid(&self, latent_dim: usize, resolution: usize) -> Option<Array2<f64>> {
        match self {
            Self::Periodic | Self::Torus => torus_projection_seed_grid(latent_dim, resolution),
            Self::Sphere => {
                if latent_dim == 2 {
                    sphere_projection_seed_grid(resolution)
                } else {
                    None
                }
            }
            Self::Duchon | Self::EuclideanPatch | Self::Poincare | Self::Precomputed(_) => None,
        }
    }
}
/// Builds an `r*r x 2` grid of (latitude, longitude) seeds with `r = max(resolution, 2)`.
///
/// Latitudes are cell-centred in `(-pi/2, pi/2)`; longitudes start at `-pi` and advance
/// by `2*pi / r`. Rows are ordered latitude-major. Always returns `Some`.
fn sphere_projection_seed_grid(resolution: usize) -> Option<Array2<f64>> {
    use std::f64::consts::PI;
    let per_axis = resolution.max(2);
    let point_count = per_axis * per_axis;
    let mut grid = Array2::<f64>::zeros((point_count, 2));
    for flat in 0..point_count {
        let lat_idx = flat / per_axis;
        let lon_idx = flat % per_axis;
        grid[[flat, 0]] = -PI / 2.0 + PI * (lat_idx as f64 + 0.5) / per_axis as f64;
        grid[[flat, 1]] = -PI + 2.0 * PI * (lon_idx as f64) / per_axis as f64;
    }
    Some(grid)
}
/// Builds a regular seed grid on the unit torus `[0, 1)^latent_dim`.
///
/// The per-axis resolution starts at `max(resolution, 2)` and is lowered until the grid
/// holds at most 4096 points. Returns `None` when `latent_dim` is 0, is too large to
/// shift by, or when even two points per axis would exceed the cap. The last axis
/// varies fastest across rows.
fn torus_projection_seed_grid(latent_dim: usize, resolution: usize) -> Option<Array2<f64>> {
    const MAX_GRID_POINTS: usize = 4096;
    if latent_dim == 0 || latent_dim >= usize::BITS as usize {
        return None;
    }
    // The shift is in range because `latent_dim < usize::BITS` was checked above.
    if (1usize << latent_dim) > MAX_GRID_POINTS {
        return None;
    }
    let exponent = latent_dim as u32;
    let mut per_axis = resolution.max(2);
    while per_axis.saturating_pow(exponent) > MAX_GRID_POINTS {
        if per_axis == 2 {
            return None;
        }
        per_axis -= 1;
    }
    let total = per_axis.saturating_pow(exponent);
    let mut grid = Array2::<f64>::zeros((total, latent_dim));
    for flat in 0..total {
        // Decompose the flat index in base `per_axis`, least-significant digit last.
        let mut remainder = flat;
        for axis in (0..latent_dim).rev() {
            grid[[flat, axis]] = (remainder % per_axis) as f64 / per_axis as f64;
            remainder /= per_axis;
        }
    }
    Some(grid)
}
/// Value and derivatives of a one-axis ARD prior evaluated at a coordinate `t`.
#[derive(Clone, Copy, Debug)]
struct ArdAxisPrior {
    /// Prior energy.
    value: f64,
    /// First derivative of `value` with respect to `t`.
    grad: f64,
    /// Second derivative of `value` with respect to `t`.
    hess: f64,
    /// Quantity standing in for `t^2`; equals `2 * value / alpha`.
    sq_equiv: f64,
}

impl ArdAxisPrior {
    /// Evaluates the prior with precision `alpha` at `t`.
    ///
    /// With `period == None` the prior is the quadratic `0.5 * alpha * t^2`. With
    /// `Some(p)` it is the periodic form `(alpha / kappa^2) * (1 - cos(kappa * t))`,
    /// where `kappa = 2*pi / p`.
    fn eval(alpha: f64, t: f64, period: Option<f64>) -> Self {
        let Some(p) = period else {
            return Self {
                value: 0.5 * alpha * t * t,
                grad: alpha * t,
                hess: alpha,
                sq_equiv: t * t,
            };
        };

        let kappa = std::f64::consts::TAU / p;
        let (sin, cos) = (kappa * t).sin_cos();
        let one_minus_cos = 1.0 - cos;
        let value = (alpha / (kappa * kappa)) * one_minus_cos;
        let grad = (alpha / kappa) * sin;
        let hess = alpha * cos;
        let sq_equiv = (2.0 / (kappa * kappa)) * one_minus_cos;
        Self {
            value,
            grad,
            hess,
            sq_equiv,
        }
    }
}
/// Polynomial in `3.75 / ax` approximating `sqrt(ax) * exp(-ax) * I0(ax)`.
///
/// The coefficients match the Abramowitz & Stegun large-argument fit (formula 9.8.2);
/// callers in this file use it only for `ax >= 3.75`.
fn bessel_i0_scaled_poly(ax: f64) -> f64 {
    const COEFFS: [f64; 9] = [
        0.39894228,
        0.01328592,
        0.00225319,
        -0.00157565,
        0.00916281,
        -0.02057706,
        0.02635537,
        -0.01647633,
        0.00392377,
    ];
    let y = 3.75 / ax;
    // Horner evaluation, highest-order coefficient first.
    COEFFS[..8]
        .iter()
        .rev()
        .fold(COEFFS[8], |acc, &coeff| coeff + y * acc)
}
/// Polynomial in `3.75 / ax` approximating `sqrt(ax) * exp(-ax) * I1(ax)`.
///
/// The coefficients match the Abramowitz & Stegun large-argument fit (formula 9.8.4);
/// callers in this file use it only for `ax >= 3.75`.
fn bessel_i1_scaled_poly(ax: f64) -> f64 {
    const COEFFS: [f64; 9] = [
        0.39894228,
        -0.03988024,
        -0.00362018,
        0.00163801,
        -0.01031555,
        0.02282967,
        -0.02895312,
        0.01787654,
        -0.00420059,
    ];
    let y = 3.75 / ax;
    // Horner evaluation, highest-order coefficient first.
    COEFFS[..8]
        .iter()
        .rev()
        .fold(COEFFS[8], |acc, &coeff| coeff + y * acc)
}
/// Modified Bessel function of the first kind, order 0.
///
/// For `|x| < 3.75` an even polynomial in `x / 3.75` is used; otherwise the scaled
/// large-argument polynomial is multiplied by `exp(|x|) / sqrt(|x|)`.
fn bessel_i0(x: f64) -> f64 {
    const SMALL_COEFFS: [f64; 6] = [
        3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.0360768, 0.0045813,
    ];
    let ax = x.abs();
    if ax >= 3.75 {
        return (ax.exp() / ax.sqrt()) * bessel_i0_scaled_poly(ax);
    }
    let t = x / 3.75;
    let t2 = t * t;
    // Horner evaluation in `t2`, highest-order coefficient first.
    let series = SMALL_COEFFS[..5]
        .iter()
        .rev()
        .fold(SMALL_COEFFS[5], |acc, &coeff| coeff + t2 * acc);
    1.0 + t2 * series
}
/// Modified Bessel function of the first kind, order 1 (an odd function of `x`).
///
/// The magnitude is computed from `|x|` with a small-argument polynomial below 3.75 and
/// the scaled large-argument polynomial otherwise; the sign of `x` is restored at the end.
fn bessel_i1(x: f64) -> f64 {
    const SMALL_COEFFS: [f64; 7] = [
        0.5, 0.87890594, 0.51498869, 0.15084934, 0.02658733, 0.00301532, 0.00032411,
    ];
    let ax = x.abs();
    let magnitude = if ax >= 3.75 {
        (ax.exp() / ax.sqrt()) * bessel_i1_scaled_poly(ax)
    } else {
        let t = x / 3.75;
        let t2 = t * t;
        // Horner evaluation in `t2`, highest-order coefficient first.
        let series = SMALL_COEFFS[..6]
            .iter()
            .rev()
            .fold(SMALL_COEFFS[6], |acc, &coeff| coeff + t2 * acc);
        ax * series
    };
    if x < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
/// Returns `(ln I0(|eta|), I1(|eta|) / I0(|eta|))`.
///
/// For `|eta| >= 3.75` both quantities are formed from the scaled polynomials, so the
/// `exp(|eta|)` factor is never evaluated. Only `|eta|` is used, so the ratio is the one
/// for the absolute value of the argument.
fn bessel_i0_log_and_ratio(eta: f64) -> (f64, f64) {
    let ax = eta.abs();
    if ax >= 3.75 {
        let scaled_i0 = bessel_i0_scaled_poly(ax);
        let scaled_i1 = bessel_i1_scaled_poly(ax);
        // ln I0 = ax - 0.5 * ln(ax) + ln(scaled polynomial).
        let log_i0 = ax - 0.5 * ax.ln() + scaled_i0.ln();
        return (log_i0, scaled_i1 / scaled_i0);
    }
    let i0 = bessel_i0(ax);
    let i1 = bessel_i1(ax);
    (i0.ln(), i1 / i0)
}
pub use crate::terms::sae::assignment::*;
pub use crate::terms::sae::basis::*;
pub use crate::terms::sae::frames::*;
/// One SAE atom: a basis evaluated at per-observation latent coordinates together with
/// the decoder coefficients that map basis values to output space.
///
/// Shapes enforced by `SaeManifoldAtom::new`, with `n` observations, `m` basis functions
/// and `p` output dimensions: `basis_values` is `(n, m)`, `basis_jacobian` is
/// `(n, m, latent_dim)`, `decoder_coefficients` is `(m, p)`, and `smooth_penalty` is `(m, m)`.
#[derive(Debug, Clone)]
pub struct SaeManifoldAtom {
pub name: String,
pub basis_kind: SaeAtomBasisKind,
pub latent_dim: usize,
// Basis values per observation, `(n, m)`.
pub basis_values: Array2<f64>,
// Basis derivatives with respect to each latent axis, `(n, m, latent_dim)`.
pub basis_jacobian: Array3<f64>,
// Decoder matrix, `(m, p)`.
pub decoder_coefficients: Array2<f64>,
// Penalty currently in effect; rewritten by `refresh_intrinsic_smooth_penalty`.
pub smooth_penalty: Array2<f64>,
// Penalty exactly as supplied to `new`; source for every reweighting.
pub smooth_penalty_raw: Array2<f64>,
// Nullity of the supplied penalty, as counted by `smooth_penalty_nullity`.
pub smooth_penalty_order: usize,
// Optional evaluator used by `refresh_basis` to recompute values and Jacobian.
pub basis_evaluator: Option<Arc<dyn SaeBasisEvaluator>>,
// Optional second-jet evaluator; cleared when a plain evaluator is installed.
pub basis_second_jet: Option<Arc<dyn SaeBasisSecondJet>>,
// Optional low-rank output frame; `None` means the full output dimension is used.
pub decoder_frame: Option<GrassmannFrame>,
// Homotopy parameter; 1.0 selects the plain `evaluate` path in `refresh_basis`.
pub homotopy_eta: f64,
// Initialised to `false` by `new`; not otherwise touched in the code reviewed here.
pub chart_canonicalized: bool,
}
impl SaeManifoldAtom {
/// Builds an atom after validating that all array shapes agree.
///
/// The supplied `smooth_penalty` is kept verbatim as `smooth_penalty_raw`, and the
/// active penalty is then derived from it by `refresh_intrinsic_smooth_penalty`.
///
/// # Errors
/// Returns a message when the Jacobian, decoder or penalty shape is inconsistent with
/// `basis_values`, when the decoder has zero output columns, or when the penalty's
/// eigen-decomposition fails.
#[must_use = "build error must be handled"]
pub fn new(
    name: impl Into<String>,
    basis_kind: SaeAtomBasisKind,
    latent_dim: usize,
    basis_values: Array2<f64>,
    basis_jacobian: Array3<f64>,
    decoder_coefficients: Array2<f64>,
    smooth_penalty: Array2<f64>,
) -> Result<Self, String> {
    let (n, m) = basis_values.dim();
    let p = decoder_coefficients.ncols();

    let jacobian_shape = basis_jacobian.dim();
    if jacobian_shape != (n, m, latent_dim) {
        return Err(format!(
            "SaeManifoldAtom::new: basis_jacobian must be ({n}, {m}, {latent_dim}); got {:?}",
            jacobian_shape
        ));
    }
    let decoder_rows = decoder_coefficients.nrows();
    if decoder_rows != m {
        return Err(format!(
            "SaeManifoldAtom::new: decoder rows {} must equal basis size {m}",
            decoder_rows
        ));
    }
    let penalty_shape = smooth_penalty.dim();
    if penalty_shape != (m, m) {
        return Err(format!(
            "SaeManifoldAtom::new: smooth penalty must be ({m}, {m}); got {:?}",
            penalty_shape
        ));
    }
    if p == 0 {
        return Err("SaeManifoldAtom::new: decoder output dimension must be positive".into());
    }

    let smooth_penalty_order = smooth_penalty_nullity(&smooth_penalty)?;
    let smooth_penalty_raw = smooth_penalty.clone();
    let mut atom = Self {
        name: name.into(),
        basis_kind,
        latent_dim,
        basis_values,
        basis_jacobian,
        decoder_coefficients,
        smooth_penalty,
        smooth_penalty_raw,
        smooth_penalty_order,
        basis_evaluator: None,
        basis_second_jet: None,
        decoder_frame: None,
        homotopy_eta: 1.0,
        chart_canonicalized: false,
    };
    atom.refresh_intrinsic_smooth_penalty();
    Ok(atom)
}
pub fn with_basis_evaluator(mut self, evaluator: Arc<dyn SaeBasisEvaluator>) -> Self {
self.basis_evaluator = Some(evaluator);
self.basis_second_jet = None;
self
}
pub fn with_basis_second_jet(mut self, evaluator: Arc<dyn SaeBasisSecondJet>) -> Self {
let base: Arc<dyn SaeBasisEvaluator> = evaluator.clone();
self.basis_evaluator = Some(base);
self.basis_second_jet = Some(evaluator);
self
}
/// Re-evaluates the basis values and Jacobian at `coords` with the installed evaluator.
///
/// Does nothing when no evaluator is installed. The plain `evaluate` path is used when
/// `homotopy_eta` is exactly 1.0, the eta-aware path otherwise.
///
/// # Errors
/// Propagates evaluator errors, and rejects results whose shape differs from the
/// currently stored arrays (which are then left untouched).
pub fn refresh_basis(&mut self, coords: ArrayView2<'_, f64>) -> Result<(), String> {
    let evaluator = match self.basis_evaluator.as_ref() {
        Some(evaluator) => evaluator,
        None => return Ok(()),
    };
    let (phi, jet) = if self.homotopy_eta != 1.0 {
        let evaluated = evaluator.evaluate_phi_eta(coords, self.homotopy_eta)?;
        (evaluated.phi, evaluated.jet)
    } else {
        evaluator.evaluate(coords)?
    };

    let expected_phi = self.basis_values.dim();
    if phi.dim() != expected_phi {
        return Err(format!(
            "SaeManifoldAtom::refresh_basis: evaluator returned Phi {:?}, expected {:?}",
            phi.dim(),
            expected_phi
        ));
    }
    let expected_jet = self.basis_jacobian.dim();
    if jet.dim() != expected_jet {
        return Err(format!(
            "SaeManifoldAtom::refresh_basis: evaluator returned jet {:?}, expected {:?}",
            jet.dim(),
            expected_jet
        ));
    }

    self.basis_values = phi;
    self.basis_jacobian = jet;
    Ok(())
}
/// Number of observations (rows of `basis_values`).
pub fn n_obs(&self) -> usize {
    let (rows, _) = self.basis_values.dim();
    rows
}
/// Number of basis functions (columns of `basis_values`).
pub fn basis_size(&self) -> usize {
    let (_, cols) = self.basis_values.dim();
    cols
}
/// Output dimension (columns of `decoder_coefficients`).
pub fn output_dim(&self) -> usize {
    let (_, cols) = self.decoder_coefficients.dim();
    cols
}
/// Rank of the active decoder frame, or the full output dimension when none is active.
pub fn border_frame_rank(&self) -> usize {
    if let Some(frame) = &self.decoder_frame {
        frame.rank()
    } else {
        self.output_dim()
    }
}
/// Number of border coefficients: basis size times the border frame rank.
pub fn border_coeff_count(&self) -> usize {
    let basis = self.basis_size();
    let rank = self.border_frame_rank();
    basis * rank
}
/// Manifold dimension reported by the active decoder frame; 0 when no frame is active.
pub fn frame_manifold_dimension(&self) -> usize {
    if let Some(frame) = &self.decoder_frame {
        frame.manifold_dimension()
    } else {
        0
    }
}
/// Numerical rank of the decoder: the count of singular values strictly above
/// `SAE_FRAME_RANK_CUTOFF` times the largest one.
///
/// Returns 0 for an empty decoder or when no singular value is positive.
///
/// # Errors
/// Returns a message when the SVD fails.
pub fn decoder_numerical_rank(&self) -> Result<usize, String> {
    if self.output_dim() == 0 || self.basis_size() == 0 {
        return Ok(0);
    }
    let (_left, singular_values, _right) = self
        .decoder_coefficients
        .svd(false, false)
        .map_err(|e| format!("SaeManifoldAtom::decoder_numerical_rank: SVD failed: {e}"))?;
    let largest = singular_values.iter().copied().fold(0.0_f64, f64::max);
    if largest > 0.0 {
        let cutoff = SAE_FRAME_RANK_CUTOFF * largest;
        let rank = singular_values.iter().filter(|&&sv| sv > cutoff).count();
        Ok(rank)
    } else {
        Ok(0)
    }
}
/// Decides whether a low-rank decoder frame is worthwhile and, if so, at which rank.
///
/// Returns `None` when the decoder is empty, when the output dimension is below
/// `SAE_FRAME_MIN_AUTO_OUTPUT_DIM`, or when the numerical rank (clamped to `1..=p`)
/// does not undercut `p` by at least the `SAE_FRAME_ACTIVATION_MARGIN` fraction.
///
/// # Errors
/// Propagates SVD failures from `decoder_numerical_rank`.
pub fn decoder_frame_activation_rank(&self) -> Result<Option<usize>, String> {
    let p = self.output_dim();
    let too_small = p == 0 || self.basis_size() == 0 || p < SAE_FRAME_MIN_AUTO_OUTPUT_DIM;
    if too_small {
        return Ok(None);
    }
    // `p >= 1` here, so the clamp bounds are ordered.
    let rank = self.decoder_numerical_rank()?.clamp(1, p);
    let shrinks_enough = (rank as f64) <= (p as f64) * (1.0 - SAE_FRAME_ACTIVATION_MARGIN);
    if shrinks_enough && rank < p {
        Ok(Some(rank))
    } else {
        Ok(None)
    }
}
/// Activates a low-rank decoder frame when `decoder_frame_activation_rank` recommends one.
///
/// The frame columns are the leading right singular vectors of the decoder, and the
/// leading singular values are passed along as the gauge. After activation the decoder
/// is projected onto the frame. Returns the activated rank, or `None` (with any
/// existing frame cleared) when no frame is activated.
///
/// # Errors
/// Returns a message when an SVD fails or yields no right factor.
pub fn maybe_activate_decoder_frame(&mut self) -> Result<Option<usize>, String> {
    let target_rank = match self.decoder_frame_activation_rank()? {
        Some(rank) => rank,
        None => {
            self.decoder_frame = None;
            return Ok(None);
        }
    };
    let p = self.output_dim();
    let (_left, singular_values, right) =
        self.decoder_coefficients.svd(false, true).map_err(|e| {
            format!("SaeManifoldAtom::maybe_activate_decoder_frame: SVD failed: {e}")
        })?;
    let vt = right.ok_or_else(|| {
        "SaeManifoldAtom::maybe_activate_decoder_frame: SVD returned no right factor"
            .to_string()
    })?;

    let rank = target_rank.min(vt.nrows());
    if rank == 0 || rank >= p {
        self.decoder_frame = None;
        return Ok(None);
    }

    // Column `col` of the frame is row `col` of V^T.
    let mut frame = Array2::<f64>::zeros((p, rank));
    for row in 0..p {
        for col in 0..rank {
            frame[[row, col]] = vt[[col, row]];
        }
    }
    let mut gauge = Array1::<f64>::zeros(rank);
    for (i, slot) in gauge.iter_mut().enumerate() {
        *slot = singular_values.get(i).copied().unwrap_or(0.0);
    }

    let oriented = GrassmannFrame::from_oriented(frame, gauge);
    let u_proj = oriented.frame().to_owned();
    self.decoder_frame = Some(oriented);

    // Project the decoder onto the span of the frame: C <- (C U) U^T.
    let c_proj = self.decoder_coefficients.dot(&u_proj);
    self.decoder_coefficients = c_proj.dot(&u_proj.t());
    Ok(Some(rank))
}
/// Drops the active decoder frame, if any; the decoder coefficients are left as they are.
pub fn deactivate_decoder_frame(&mut self) {
    drop(self.decoder_frame.take());
}
/// Decoder coordinates in the active frame, or `None` when no frame is active.
///
/// # Errors
/// Propagates the frame's projection error.
pub fn factored_coordinates(&self) -> Result<Option<Array2<f64>>, String> {
    let Some(frame) = &self.decoder_frame else {
        return Ok(None);
    };
    let coords = frame.project_decoder(self.decoder_coefficients.view())?;
    Ok(Some(coords))
}
/// Maps frame coordinates back to full decoder coefficients using the active frame.
///
/// # Errors
/// Returns a message when no frame is active, or the frame's own reconstruction error.
pub fn reconstruct_decoder_coefficients(
    &self,
    coords: ArrayView2<'_, f64>,
) -> Result<Array2<f64>, String> {
    match self.decoder_frame.as_ref() {
        Some(frame) => frame.reconstruct_decoder(coords),
        None => Err(
            "SaeManifoldAtom::reconstruct_decoder_coefficients: no active frame".to_string(),
        ),
    }
}
/// Replaces the decoder with the reconstruction of `coords` through the active frame.
///
/// # Errors
/// Fails when no frame is active, when reconstruction fails, or when the reconstructed
/// shape differs from the current decoder shape (the decoder is then left untouched).
pub fn set_factored_coordinates(&mut self, coords: ArrayView2<'_, f64>) -> Result<(), String> {
    let rebuilt = self.reconstruct_decoder_coefficients(coords)?;
    let expected = self.decoder_coefficients.dim();
    if rebuilt.dim() == expected {
        self.decoder_coefficients = rebuilt;
        Ok(())
    } else {
        Err(format!(
            "SaeManifoldAtom::set_factored_coordinates: reconstructed decoder {:?} \
             must match {:?}",
            rebuilt.dim(),
            expected
        ))
    }
}
/// Replaces the active frame with the polar update of `cross_moment` and re-expresses
/// the decoder in the new frame (project, then reconstruct).
///
/// # Errors
/// Fails when no frame is active, when the polar update fails, when the new frame's
/// output dimension differs from the decoder's, or when projection/reconstruction fails.
/// The atom is modified only if every step succeeds.
pub fn refresh_frame_from_cross_moment(
    &mut self,
    cross_moment: ArrayView2<'_, f64>,
) -> Result<(), String> {
    if self.decoder_frame.is_none() {
        return Err("SaeManifoldAtom::refresh_frame_from_cross_moment: no active frame".into());
    }
    let updated = GrassmannFrame::polar_update(cross_moment)?;
    let frame_dim = updated.output_dim();
    let decoder_dim = self.output_dim();
    if frame_dim != decoder_dim {
        return Err(format!(
            "SaeManifoldAtom::refresh_frame_from_cross_moment: frame output dim {} \
             must equal decoder output dim {}",
            frame_dim, decoder_dim
        ));
    }
    let coords = updated.project_decoder(self.decoder_coefficients.view())?;
    let rebuilt = updated.reconstruct_decoder(coords.view())?;
    self.decoder_coefficients = rebuilt;
    self.decoder_frame = Some(updated);
    Ok(())
}
/// Decoded output for observation `row`, returned as a fresh length-`p` vector.
pub fn decoded_row(&self, row: usize) -> Array1<f64> {
    let mut decoded = Array1::<f64>::zeros(self.output_dim());
    {
        let buffer = decoded.as_slice_mut().expect("contiguous");
        self.fill_decoded_row(row, buffer);
    }
    decoded
}
/// Writes the decoded output for observation `row` into `out`.
///
/// `out` is overwritten with the sum over basis columns of
/// `basis_values[row, col] * decoder_coefficients[col, ..]`; basis columns whose value
/// is exactly zero are skipped.
///
/// # Panics
/// Panics when `out.len()` differs from the output dimension.
pub fn fill_decoded_row(&self, row: usize, out: &mut [f64]) {
    let p = self.output_dim();
    assert_eq!(out.len(), p);
    out.fill(0.0);
    for basis_col in 0..self.basis_size() {
        let weight = self.basis_values[[row, basis_col]];
        if weight != 0.0 {
            for (out_col, slot) in out.iter_mut().enumerate() {
                *slot += weight * self.decoder_coefficients[[basis_col, out_col]];
            }
        }
    }
}
/// Derivative of the decoded output for observation `row` along `latent_axis`, returned
/// as a fresh length-`p` vector.
pub fn decoded_derivative_row(&self, row: usize, latent_axis: usize) -> Array1<f64> {
    let mut derivative = Array1::<f64>::zeros(self.output_dim());
    {
        let buffer = derivative.as_slice_mut().expect("contiguous");
        self.fill_decoded_derivative_row(row, latent_axis, buffer);
    }
    derivative
}
/// Writes the derivative of the decoded output for observation `row` along
/// `latent_axis` into `out`.
///
/// `out` is overwritten with the sum over basis columns of
/// `basis_jacobian[row, col, latent_axis] * decoder_coefficients[col, ..]`; basis
/// columns whose derivative is exactly zero are skipped.
///
/// # Panics
/// Panics when `out.len()` differs from the output dimension.
pub fn fill_decoded_derivative_row(&self, row: usize, latent_axis: usize, out: &mut [f64]) {
    let p = self.output_dim();
    assert_eq!(out.len(), p);
    out.fill(0.0);
    for basis_col in 0..self.basis_size() {
        let slope = self.basis_jacobian[[row, basis_col, latent_axis]];
        if slope != 0.0 {
            for (out_col, slot) in out.iter_mut().enumerate() {
                *slot += slope * self.decoder_coefficients[[basis_col, out_col]];
            }
        }
    }
}
/// Recomputes `smooth_penalty` from `smooth_penalty_raw` by a symmetric diagonal
/// rescaling driven by the current decoded-curve speed.
///
/// The rescaling applies only to one-dimensional atoms with a non-empty basis and a
/// non-zero penalty order; otherwise the raw penalty is copied through unchanged. It is
/// also copied through when no positive, finite reference speed can be formed.
pub fn refresh_intrinsic_smooth_penalty(&mut self) {
let m = self.basis_size();
if m == 0 || self.smooth_penalty_order == 0 || self.latent_dim != 1 {
self.smooth_penalty.assign(&self.smooth_penalty_raw);
return;
}
let n = self.n_obs();
let p = self.output_dim();
// Exponent applied to the relative squared speed (halved again where `root_w` is built).
let beta = 0.5 - self.smooth_penalty_order as f64;
// Per basis column: accumulated weight (`act`) and weight-times-squared-speed (`num`).
let mut act = vec![0.0_f64; m];
let mut num = vec![0.0_f64; m];
let mut deriv = vec![0.0_f64; p];
let hyperbolic = matches!(self.basis_kind, SaeAtomBasisKind::Poincare);
// For Poincare atoms with at least two columns, column 1 is read as the coordinate `t`.
// NOTE(review): presumably column 1 is the linear basis column — confirm against the
// basis builder.
let linear_col = if hyperbolic && m >= 2 { Some(1usize) } else { None };
for row in 0..n {
// Squared norm of the decoded derivative along the single latent axis.
self.fill_decoded_derivative_row(row, 0, &mut deriv);
let mut speed_sq = 0.0_f64;
for &d in deriv.iter() {
speed_sq += d * d;
}
if let Some(col) = linear_col {
let t = self.basis_values[[row, col]];
// Rescale the squared speed by `lambda^2` with `lambda = 2 * cosh(t)^2`; skipped when
// `lambda` overflows or is not positive.
let lambda = 2.0 * t.cosh() * t.cosh();
if lambda.is_finite() && lambda > 0.0 {
speed_sq /= lambda * lambda;
}
}
for col in 0..m {
let phi = self.basis_values[[row, col]];
// Each row contributes to a column with weight `phi^2`.
let w = phi * phi;
if w == 0.0 {
continue;
}
act[col] += w;
num[col] += w * speed_sq;
}
}
// Weighted mean squared speed per column, plus the log-sum needed for the geometric mean.
let mut speeds = vec![0.0_f64; m];
let mut log_acc = 0.0_f64;
let mut log_cnt = 0usize;
for col in 0..m {
let s = if act[col] > 0.0 {
num[col] / act[col]
} else {
0.0
};
speeds[col] = s;
if s > 0.0 && s.is_finite() {
log_acc += s.ln();
log_cnt += 1;
}
}
// Reference speed: geometric mean over the columns with a positive, finite speed.
let center = if log_cnt > 0 {
(log_acc / log_cnt as f64).exp()
} else {
0.0
};
if !(center > 0.0 && center.is_finite()) {
self.smooth_penalty.assign(&self.smooth_penalty_raw);
return;
}
// Relative speeds are clamped to this range before exponentiation; a non-finite ratio
// is treated as the ceiling.
const RELATIVE_SPEED_FLOOR: f64 = 1.0e-6;
const RELATIVE_SPEED_CEIL: f64 = 1.0e6;
let mut root_w = vec![0.0_f64; m];
for col in 0..m {
let ratio = speeds[col] / center;
let ratio = if ratio.is_finite() {
ratio.clamp(RELATIVE_SPEED_FLOOR, RELATIVE_SPEED_CEIL)
} else {
RELATIVE_SPEED_CEIL
};
// Square root of the column weight `ratio^beta`, so the two-sided scaling below
// multiplies entry (i, j) by `sqrt(w_i * w_j)`.
root_w[col] = ratio.powf(0.5 * beta);
}
// smooth_penalty = diag(root_w) * smooth_penalty_raw * diag(root_w).
for i in 0..m {
let ri = root_w[i];
for j in 0..m {
self.smooth_penalty[[i, j]] = ri * self.smooth_penalty_raw[[i, j]] * root_w[j];
}
}
}
}
/// Counts the null directions of a penalty matrix.
///
/// The matrix is symmetrised, eigen-decomposed, and every eigenvalue at or below
/// `SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF` times the largest eigenvalue is counted.
/// Returns 0 for an empty matrix or one without a positive eigenvalue.
///
/// # Errors
/// Returns a message when the eigen-decomposition fails.
fn smooth_penalty_nullity(s: &Array2<f64>) -> Result<usize, String> {
    let m = s.ncols();
    if m == 0 {
        return Ok(0);
    }
    // Symmetrise: each unordered pair is averaged once and written to both triangles.
    let mut sym = Array2::<f64>::zeros((m, m));
    for i in 0..m {
        for j in 0..=i {
            let lower = 0.5 * (s[[i, j]] + s[[j, i]]);
            let upper = 0.5 * (s[[j, i]] + s[[i, j]]);
            sym[[i, j]] = lower;
            sym[[j, i]] = upper;
        }
    }
    let (evals, _evecs) = sym
        .eigh(Side::Lower)
        .map_err(|e| format!("smooth_penalty_nullity: eigh failed: {e}"))?;
    let largest = evals.iter().copied().fold(0.0_f64, f64::max);
    if largest > 0.0 {
        let cutoff = SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF * largest;
        Ok(evals.iter().filter(|&&ev| ev <= cutoff).count())
    } else {
        Ok(0)
    }
}
/// Log-scale smoothing parameters of the SAE manifold term.
#[derive(Debug, Clone)]
pub struct SaeManifoldRho {
    /// Log of the sparsity strength.
    pub log_lambda_sparse: f64,
    /// Log of the smoothness strength.
    pub log_lambda_smooth: f64,
    /// Log ARD strengths, one array per block.
    pub log_ard: Vec<Array1<f64>>,
}

impl SaeManifoldRho {
    /// Bundles the three groups of log-strengths.
    #[must_use]
    pub fn new(log_lambda_sparse: f64, log_lambda_smooth: f64, log_ard: Vec<Array1<f64>>) -> Self {
        Self {
            log_lambda_sparse,
            log_lambda_smooth,
            log_ard,
        }
    }

    /// Returns a copy with every log-strength shifted by `ln(dispersion)`.
    ///
    /// # Errors
    /// Fails when `dispersion` is not finite and positive.
    pub fn seed_scaled_by_dispersion(&self, dispersion: f64) -> Result<Self, String> {
        self.seed_scaled_by_dispersion_with_sparse_policy(dispersion, true)
    }

    /// Like [`SaeManifoldRho::seed_scaled_by_dispersion`], but leaves the sparsity
    /// strength unshifted when the assignment mode is `IBPMap` with a learnable alpha.
    ///
    /// # Errors
    /// Fails when `dispersion` is not finite and positive.
    pub fn seed_scaled_by_dispersion_for_assignment(
        &self,
        dispersion: f64,
        assignment_mode: AssignmentMode,
    ) -> Result<Self, String> {
        let learns_alpha = matches!(
            assignment_mode,
            AssignmentMode::IBPMap {
                learnable_alpha: true,
                ..
            }
        );
        self.seed_scaled_by_dispersion_with_sparse_policy(dispersion, !learns_alpha)
    }

    /// Shared implementation: shifts the smooth and ARD log-strengths by
    /// `ln(dispersion)`, and the sparse one only when `scale_sparse` is set.
    fn seed_scaled_by_dispersion_with_sparse_policy(
        &self,
        dispersion: f64,
        scale_sparse: bool,
    ) -> Result<Self, String> {
        let valid = dispersion.is_finite() && dispersion > 0.0;
        if !valid {
            return Err(format!(
                "SaeManifoldRho::seed_scaled_by_dispersion: dispersion must be finite and \
                 positive; got {dispersion}"
            ));
        }
        let shift = dispersion.ln();
        let mut shifted = self.clone();
        if scale_sparse {
            shifted.log_lambda_sparse += shift;
        }
        shifted.log_lambda_smooth += shift;
        shifted
            .log_ard
            .iter_mut()
            .flat_map(|block| block.iter_mut())
            .for_each(|value| *value += shift);
        Ok(shifted)
    }

    /// Sparsity strength, `exp(log_lambda_sparse)` with the exponent clamped.
    pub fn lambda_sparse(&self) -> f64 {
        Self::stable_exp_strength(self.log_lambda_sparse)
    }

    /// Smoothness strength, `exp(log_lambda_smooth)` with the exponent clamped.
    pub fn lambda_smooth(&self) -> f64 {
        Self::stable_exp_strength(self.log_lambda_smooth)
    }

    /// `exp` of a log-strength clamped to `[-700, 700]`, so the result neither
    /// overflows to infinity nor underflows to zero.
    pub(crate) fn stable_exp_strength(log_strength: f64) -> f64 {
        const LOG_STRENGTH_LIMIT: f64 = 700.0;
        let bounded = log_strength.clamp(-LOG_STRENGTH_LIMIT, LOG_STRENGTH_LIMIT);
        bounded.exp()
    }

    /// Flattens to `[log_lambda_sparse, log_lambda_smooth, ard blocks...]`.
    pub fn to_flat(&self) -> Array1<f64> {
        let ard_len: usize = self.log_ard.iter().map(|block| block.len()).sum();
        let mut flat = Array1::<f64>::zeros(2 + ard_len);
        flat[0] = self.log_lambda_sparse;
        flat[1] = self.log_lambda_smooth;
        let ard_values = self.log_ard.iter().flat_map(|block| block.iter().copied());
        for (slot, value) in flat.iter_mut().skip(2).zip(ard_values) {
            *slot = value;
        }
        flat
    }

    /// Rebuilds a `SaeManifoldRho` from a flat vector, using `self` only as the template
    /// for the ARD block sizes.
    ///
    /// # Panics
    /// Panics when `flat.len()` is not `2 +` the total ARD length of `self`.
    pub fn from_flat(&self, flat: ArrayView1<'_, f64>) -> SaeManifoldRho {
        let ard_len: usize = self.log_ard.iter().map(|block| block.len()).sum();
        assert_eq!(
            flat.len(),
            2 + ard_len,
            "SaeManifoldRho::from_flat: flat length {} != 2 + Σ d_k = {}",
            flat.len(),
            2 + ard_len
        );
        let mut offset = 2usize;
        let log_ard: Vec<Array1<f64>> = self
            .log_ard
            .iter()
            .map(|template| {
                let width = template.len();
                let mut block = Array1::<f64>::zeros(width);
                for j in 0..width {
                    block[j] = flat[offset + j];
                }
                offset += width;
                block
            })
            .collect();
        SaeManifoldRho {
            log_lambda_sparse: flat[0],
            log_lambda_smooth: flat[1],
            log_ard,
        }
    }
}
/// Row-wise operator interface for the Kronecker-structured SAE Jacobian.
pub trait SaeKroneckerRow {
    /// Overwrites `u_out` with the beta-Jacobian of `row` applied to `x_beta`.
    fn apply_jbeta(&self, row: usize, x_beta: &[f64], u_out: &mut [f64]);
    /// Accumulates the transposed beta-Jacobian of `row` applied to `u` into `y_beta`.
    fn scatter_jbeta_t(&self, row: usize, u: &[f64], y_beta: &mut [f64]);
    /// Writes the local Jacobian of `row` applied to `u` into `w_out`.
    fn apply_l(&self, row: usize, u: &[f64], w_out: &mut [f64]);
    /// Accumulates the transposed local Jacobian of `row` applied to `v` into `u_out`.
    fn apply_l_t(&self, row: usize, v: &[f64], u_out: &mut [f64]);
}

/// Explicit per-row storage implementing [`SaeKroneckerRow`].
#[derive(Debug, Clone)]
pub struct SaeKroneckerRows {
    /// Output dimension; every beta block and every local-Jacobian row has this length.
    p: usize,
    /// Per row: `(offset of a length-p block in the beta vector, weight)` pairs.
    a_phi: Vec<Vec<(usize, f64)>>,
    /// Per row: row-major `q_i x p` local Jacobian, where `q_i = len / p`.
    local_jac: Vec<Vec<f64>>,
}

impl SaeKroneckerRows {
    /// Bundles the per-row data.
    ///
    /// # Panics
    /// Panics when `a_phi` and `local_jac` describe a different number of rows.
    pub fn new(p: usize, a_phi: Vec<Vec<(usize, f64)>>, local_jac: Vec<Vec<f64>>) -> Self {
        assert_eq!(
            a_phi.len(),
            local_jac.len(),
            "SaeKroneckerRows: a_phi rows ({}) != local_jac rows ({})",
            a_phi.len(),
            local_jac.len(),
        );
        Self {
            p,
            a_phi,
            local_jac,
        }
    }
}

impl SaeKroneckerRow for SaeKroneckerRows {
    fn apply_jbeta(&self, row: usize, x_beta: &[f64], u_out: &mut [f64]) {
        u_out.fill(0.0);
        let p = self.p;
        if p == 0 {
            return;
        }
        for &(beta_base, phi) in &self.a_phi[row] {
            if phi == 0.0 {
                continue;
            }
            // One range check per block instead of one bounds check per element.
            let block = &x_beta[beta_base..beta_base + p];
            for (dst, &x) in u_out[..p].iter_mut().zip(block) {
                *dst += phi * x;
            }
        }
    }

    fn scatter_jbeta_t(&self, row: usize, u: &[f64], y_beta: &mut [f64]) {
        let p = self.p;
        if p == 0 {
            return;
        }
        for &(beta_base, phi) in &self.a_phi[row] {
            if phi == 0.0 {
                continue;
            }
            let block = &mut y_beta[beta_base..beta_base + p];
            for (dst, &uj) in block.iter_mut().zip(&u[..p]) {
                *dst += phi * uj;
            }
        }
    }

    fn apply_l(&self, row: usize, u: &[f64], w_out: &mut [f64]) {
        let jac = &self.local_jac[row];
        let p = self.p;
        // With a zero output dimension there are no Jacobian rows; returning here also
        // avoids the division by zero in `jac.len() / p`.
        if p == 0 {
            return;
        }
        let q_i = jac.len() / p;
        if q_i == 0 {
            return;
        }
        let u = &u[..p];
        for (dst, jac_row) in w_out[..q_i].iter_mut().zip(jac.chunks_exact(p)) {
            *dst = jac_row
                .iter()
                .zip(u)
                .fold(0.0_f64, |acc, (&a, &b)| acc + a * b);
        }
    }

    fn apply_l_t(&self, row: usize, v: &[f64], u_out: &mut [f64]) {
        let jac = &self.local_jac[row];
        let p = self.p;
        // Same zero-dimension guard as `apply_l`.
        if p == 0 {
            return;
        }
        let q_i = jac.len() / p;
        for (&vc, jac_row) in v[..q_i].iter().zip(jac.chunks_exact(p)) {
            if vc == 0.0 {
                continue;
            }
            for (dst, &a) in u_out[..p].iter_mut().zip(jac_row) {
                *dst += a * vc;
            }
        }
    }
}
/// Breakdown of the SAE manifold objective into its additive parts.
#[derive(Debug, Clone, Copy)]
pub struct SaeManifoldLoss {
    pub data_fit: f64,
    pub assignment_sparsity: f64,
    pub smoothness: f64,
    pub ard: f64,
    /// Count of evidence gauge-deflated directions; not part of `total`.
    pub evidence_gauge_deflated_directions: usize,
}

impl SaeManifoldLoss {
    /// Sum of the four loss components, added in declaration order.
    pub const fn total(&self) -> f64 {
        let fit_and_sparsity = self.data_fit + self.assignment_sparsity;
        let with_smoothness = fit_and_sparsity + self.smoothness;
        with_smoothness + self.ard
    }

    /// Negated total loss.
    pub const fn evidence_proxy(&self) -> f64 {
        let total = self.total();
        -total
    }
}
#[derive(Debug, Clone)]
pub struct SaeOuterRhoGradientComponents {
pub explicit: Array1<f64>,
pub logdet_trace: Array1<f64>,
pub occam: Array1<f64>,
pub third_order_correction: Array1<f64>,
pub third_order_correction_available: bool,
}
impl SaeOuterRhoGradientComponents {
    /// Returns `explicit + logdet_trace + occam`, leaving the third-order channel out.
    #[must_use]
    pub fn gradient_excluding_unavailable_correction(&self) -> Array1<f64> {
        let explicit_plus_logdet = &self.explicit + &self.logdet_trace;
        &explicit_plus_logdet + &self.occam
    }

    /// Returns the three base channels plus `third_order_correction`.
    ///
    /// # Panics
    /// Panics when `third_order_correction_available` is `false`.
    #[must_use]
    pub fn gradient_with_available_correction(&self) -> Array1<f64> {
        assert!(
            self.third_order_correction_available,
            "gradient_with_available_correction: third-order correction channel \
             is not populated for this fit; use \
             gradient_excluding_unavailable_correction() and account for the \
             missing term explicitly"
        );
        let base = self.gradient_excluding_unavailable_correction();
        &base + &self.third_order_correction
    }
}
/// A vector split into the two parts of the arrow system.
///
/// Where this chunk constructs it, `t` has length `cache.delta_t_len()` and `beta` has
/// length `cache.k`.
#[derive(Debug, Clone)]
pub struct SaeArrowVector {
/// Latent (`t`) part.
pub t: Array1<f64>,
/// Shared (`beta`) part.
pub beta: Array1<f64>,
}
/// Solver built on an `ArrowFactorCache`, optionally with a low-rank gauge correction.
///
/// When `woodbury_factor` is `None` (see `plain`), `solve` returns the cache's
/// `full_inverse_apply` result unchanged and the remaining fields are unused.
pub(crate) struct DeflatedArrowSolver<'a> {
/// Borrowed factor cache used for every back-solve.
cache: &'a ArrowFactorCache,
/// Gauge directions, each of length `cache.delta_t_len() + cache.k`.
/// The constructor name says they are orthonormal; this is not checked.
gauge_basis: Vec<Array1<f64>>,
/// Cached inverse responses to each gauge direction, with their components along
/// `gauge_basis` subtracted.
gauge_response_physical: Vec<Array1<f64>>,
/// Cholesky factor of `stiffness⁻¹·I + [gauge_i · response_j]`, if any gauges exist.
woodbury_factor: Option<FaerCholeskyFactor>,
/// Reciprocal of the gauge stiffness (`0.0` for the plain solver).
gauge_stiffness_recip: f64,
}
impl<'a> DeflatedArrowSolver<'a> {
/// Solver with no gauge directions: `solve` is exactly `cache.full_inverse_apply`.
fn plain(cache: &'a ArrowFactorCache) -> Self {
Self {
cache,
gauge_basis: Vec::new(),
gauge_response_physical: Vec::new(),
woodbury_factor: None,
gauge_stiffness_recip: 0.0,
}
}
/// Builds a solver that applies a rank-`gauge_basis.len()` correction on top of the
/// cached inverse.
///
/// Returns the plain solver when `gauge_basis` is empty (the stiffness is not inspected
/// in that case).
///
/// # Errors
/// * `stiffness` is not finite and strictly positive;
/// * a gauge vector's length differs from `cache.delta_t_len() + cache.k`;
/// * a cached back-solve fails;
/// * the Cholesky factorization of the small Woodbury matrix fails.
///
/// NOTE(review): the name implies the gauge vectors are orthonormal; nothing here
/// verifies that — confirm at the call sites.
fn from_orthonormal_gauges(
cache: &'a ArrowFactorCache,
gauge_basis: Vec<Array1<f64>>,
stiffness: f64,
) -> Result<Self, String> {
if gauge_basis.is_empty() {
return Ok(Self::plain(cache));
}
if !(stiffness.is_finite() && stiffness > 0.0) {
return Err(format!(
"DeflatedArrowSolver: gauge stiffness must be finite and positive; got {stiffness}"
));
}
let full_len = cache.delta_t_len() + cache.k;
// One cached inverse application per gauge direction.
let mut gauge_responses = Vec::with_capacity(gauge_basis.len());
for gauge in &gauge_basis {
if gauge.len() != full_len {
return Err(format!(
"DeflatedArrowSolver: gauge length {} != cache full length {full_len}",
gauge.len()
));
}
let (sol_t, sol_beta) = cache
.full_inverse_apply(
gauge.slice(s![..cache.delta_t_len()]),
gauge.slice(s![cache.delta_t_len()..]),
)
.map_err(|err| format!("DeflatedArrowSolver: gauge back-solve: {err}"))?;
gauge_responses.push(flatten_arrow_parts(sol_t.view(), sol_beta.view()));
}
let rank = gauge_basis.len();
let stiffness_recip = stiffness.recip();
// gauge_metric[i, j] = gauge_i · response_j;
// woodbury = stiffness⁻¹ · I + gauge_metric.
let mut gauge_metric = Array2::<f64>::zeros((rank, rank));
let mut woodbury = Array2::<f64>::eye(rank);
for i in 0..rank {
woodbury[[i, i]] *= stiffness_recip;
for j in 0..rank {
let value = gauge_basis[i].dot(&gauge_responses[j]);
gauge_metric[[i, j]] = value;
woodbury[[i, j]] += value;
}
}
let woodbury_factor = woodbury
.cholesky(Side::Lower)
.map_err(|err| format!("DeflatedArrowSolver: gauge Woodbury factor failed: {err}"))?;
// Subtract from each response the combination Σ_i gauge_metric[i, j] · gauge_i,
// i.e. its coefficients against the gauge basis.
let mut gauge_response_physical = gauge_responses;
for j in 0..rank {
for i in 0..rank {
let coeff = gauge_metric[[i, j]];
for row in 0..full_len {
gauge_response_physical[j][row] -= coeff * gauge_basis[i][row];
}
}
}
Ok(Self {
cache,
gauge_basis,
gauge_response_physical,
woodbury_factor: Some(woodbury_factor),
gauge_stiffness_recip: stiffness_recip,
})
}
/// Applies the (optionally gauge-corrected) inverse to `(rhs_t, rhs_beta)`.
///
/// Without a Woodbury factor the cached inverse result is returned unchanged. With one,
/// the result is `x − G·c − R_phys·w + stiffness⁻¹·G·w`, where `x` is the cached inverse
/// of the right-hand side, `c_i = gauge_i · x`, and `w` solves `woodbury · w = c`.
/// Substituting `R_phys = R − G·gauge_metric` and `woodbury = stiffness⁻¹·I + gauge_metric`,
/// the three corrections sum algebraically to `x − R·w`, with `R` the unprojected
/// gauge responses.
///
/// # Errors
/// Returns an error if the cached inverse fails or its flattened length differs from
/// `cache.delta_t_len() + cache.k`.
fn solve(
&self,
rhs_t: ArrayView1<'_, f64>,
rhs_beta: ArrayView1<'_, f64>,
) -> Result<SaeArrowVector, String> {
let (sol_t, sol_beta) = self
.cache
.full_inverse_apply(rhs_t, rhs_beta)
.map_err(|err| format!("DeflatedArrowSolver: full inverse: {err}"))?;
let Some(factor) = self.woodbury_factor.as_ref() else {
return Ok(SaeArrowVector {
t: sol_t,
beta: sol_beta,
});
};
let full_len = self.cache.delta_t_len() + self.cache.k;
let mut flat = flatten_arrow_parts(sol_t.view(), sol_beta.view());
if flat.len() != full_len {
return Err(format!(
"DeflatedArrowSolver: solution length {} != cache full length {full_len}",
flat.len()
));
}
// Coefficients of the raw solution along each gauge direction.
let mut gauge_coeffs = Array1::<f64>::zeros(self.gauge_basis.len());
for (idx, gauge) in self.gauge_basis.iter().enumerate() {
gauge_coeffs[idx] = gauge.dot(&flat);
}
let weights = factor.solvevec(&gauge_coeffs);
// Step 1: remove the gauge components of the raw solution.
for (gauge, &coeff) in self.gauge_basis.iter().zip(gauge_coeffs.iter()) {
for i in 0..flat.len() {
flat[i] -= gauge[i] * coeff;
}
}
// Step 2: subtract the gauge-projected responses weighted by the Woodbury solution.
for (response, &weight) in self.gauge_response_physical.iter().zip(weights.iter()) {
for i in 0..flat.len() {
flat[i] -= response[i] * weight;
}
}
// Step 3: add back the gauge component scaled by stiffness⁻¹ · weight.
for (gauge, &weight) in self.gauge_basis.iter().zip(weights.iter()) {
let coeff = self.gauge_stiffness_recip * weight;
for i in 0..flat.len() {
flat[i] += gauge[i] * coeff;
}
}
Ok(SaeArrowVector {
t: flat.slice(s![..self.cache.delta_t_len()]).to_owned(),
beta: flat.slice(s![self.cache.delta_t_len()..]).to_owned(),
})
}
/// Diagonal of the latent (`t`,`t`) block of the inverse used by `solve`.
///
/// Without a Woodbury factor this delegates to
/// `cache.latent_block_inverse_diagonal()`. Otherwise it performs one full `solve` per
/// latent coordinate with a unit right-hand side and reads the matching output entry.
fn latent_inverse_diagonal(&self) -> Result<Array1<f64>, String> {
if self.woodbury_factor.is_none() {
return self
.cache
.latent_block_inverse_diagonal()
.map_err(|err| format!("DeflatedArrowSolver: latent inverse diagonal: {err}"));
}
let total_t = self.cache.delta_t_len();
let mut out = Array1::<f64>::zeros(total_t);
let rhs_beta = Array1::<f64>::zeros(self.cache.k);
for idx in 0..total_t {
// Unit vector e_idx in the latent block, zero in the beta block.
let mut rhs_t = Array1::<f64>::zeros(total_t);
rhs_t[idx] = 1.0;
let solved = self.solve(rhs_t.view(), rhs_beta.view())?;
out[idx] = solved.t[idx];
}
Ok(out)
}
}
fn flatten_arrow_parts(t: ArrayView1<'_, f64>, beta: ArrayView1<'_, f64>) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(t.len() + beta.len());
for i in 0..t.len() {
out[i] = t[i];
}
for i in 0..beta.len() {
out[t.len() + i] = beta[i];
}
out
}
/// Applies the Hessian implied by the cached factors to `(v_t, v_beta)`.
///
/// Per row, the latent block is applied through its undamped Cholesky factor and the
/// cross blocks through `apply_htbeta_row` / `apply_htbeta_row_transpose`. When
/// `cache.k > 0`, the beta block is rebuilt as the Schur factor applied to `v_beta`
/// plus the per-row correction `H_βt · A⁻¹ · H_tβ · v_beta`.
///
/// # Errors
/// Returns an error on a shape mismatch, when any cross-block apply reports failure, or
/// when `cache.k > 0` but no dense Schur factor is cached.
fn apply_cached_arrow_hessian(
    cache: &ArrowFactorCache,
    v_t: ArrayView1<'_, f64>,
    v_beta: ArrayView1<'_, f64>,
) -> Result<SaeArrowVector, String> {
    let total_t = cache.delta_t_len();
    let shapes_match = v_t.len() == total_t && v_beta.len() == cache.k;
    if !shapes_match {
        return Err(format!(
            "apply_cached_arrow_hessian: vector shapes (t={}, beta={}) != cache shapes \
             (t={total_t}, beta={})",
            v_t.len(),
            v_beta.len(),
            cache.k
        ));
    }

    let has_beta = cache.k > 0;
    let mut out_t = Array1::<f64>::zeros(total_t);
    let mut out_beta = Array1::<f64>::zeros(cache.k);

    // Pass 1: latent diagonal blocks plus both cross-block products.
    for row in 0..cache.n_rows() {
        let di = cache.row_dims[row];
        let base = cache.row_offsets[row];
        let row_v = v_t.slice(s![base..base + di]);
        let diag_part = cholesky_factor_apply(cache.undamped_factor(row), row_v);
        (0..di).for_each(|j| out_t[base + j] += diag_part[j]);
        if !has_beta {
            continue;
        }
        let mut cross = Array1::<f64>::zeros(di);
        if !cache.apply_htbeta_row(row, v_beta, &mut cross) {
            return Err(format!(
                "apply_cached_arrow_hessian: H_tβ^({row}) apply failed"
            ));
        }
        (0..di).for_each(|j| out_t[base + j] += cross[j]);
        if !cache.apply_htbeta_row_transpose(row, row_v, &mut out_beta, None) {
            return Err(format!(
                "apply_cached_arrow_hessian: H_βt^({row}) apply failed"
            ));
        }
    }

    if !has_beta {
        return Ok(SaeArrowVector {
            t: out_t,
            beta: out_beta,
        });
    }

    // Pass 2: beta block = Schur factor product + Σ_rows H_βt A⁻¹ H_tβ v_beta.
    let schur_factor = cache.schur_factor.as_ref().ok_or_else(|| {
        "apply_cached_arrow_hessian: dense Schur factor is required for gauge probing"
            .to_string()
    })?;
    let schur_part = cholesky_factor_apply(schur_factor.view(), v_beta);
    (0..cache.k).for_each(|i| out_beta[i] += schur_part[i]);
    for row in 0..cache.n_rows() {
        let di = cache.row_dims[row];
        let mut cross = Array1::<f64>::zeros(di);
        if !cache.apply_htbeta_row(row, v_beta, &mut cross) {
            return Err(format!(
                "apply_cached_arrow_hessian: H_tβ^({row}) Schur correction apply failed"
            ));
        }
        let solved = cholesky_solve_vector(cache.undamped_factor(row), cross.view());
        if !cache.apply_htbeta_row_transpose(row, solved.view(), &mut out_beta, None) {
            return Err(format!(
                "apply_cached_arrow_hessian: H_βt^({row}) Schur correction apply failed"
            ));
        }
    }

    Ok(SaeArrowVector {
        t: out_t,
        beta: out_beta,
    })
}
fn cholesky_factor_apply(factor: ArrayView2<'_, f64>, vector: ArrayView1<'_, f64>) -> Array1<f64> {
let n = factor.nrows();
let mut lt_v = Array1::<f64>::zeros(n);
for row in 0..n {
let mut acc = 0.0_f64;
for col in row..n {
acc += factor[[col, row]] * vector[col];
}
lt_v[row] = acc;
}
let mut out = Array1::<f64>::zeros(n);
for row in 0..n {
let mut acc = 0.0_f64;
for col in 0..=row {
acc += factor[[row, col]] * lt_v[col];
}
out[row] = acc;
}
out
}
/// Identifies one local per-row variable: either an atom's logit or one coordinate axis
/// of an atom.
/// NOTE(review): how these are ordered within a row is decided by code outside this
/// chunk — confirm there.
#[derive(Debug, Clone, Copy)]
enum SaeLocalRowVar {
/// The logit variable of atom `atom`.
Logit { atom: usize },
/// Coordinate `axis` of atom `atom`.
Coord { atom: usize, axis: usize },
}
/// One border channel record: an `(atom, basis_col)` pair, an `index`, and an `output`
/// vector.
/// NOTE(review): the producers and consumers of this type are outside this chunk;
/// presumably `index` addresses the shared (β) block and `output` has one entry per
/// output dimension — verify before relying on it.
#[derive(Debug, Clone)]
struct SaeBorderChannel {
atom: usize,
basis_col: usize,
index: usize,
output: Vec<f64>,
}
/// Per-row derivative tables ("jets") over the local variables listed in `vars`.
/// NOTE(review): the index order of the nested vectors is not established in this
/// chunk. Presumably `first`/`second` are first and second derivatives with respect to
/// `vars`, and the `beta*` tables hold β-channel values and their derivatives — confirm
/// against the code that fills them.
#[derive(Debug, Clone)]
struct SaeRowJets {
vars: Vec<SaeLocalRowVar>,
first: Vec<Vec<f64>>,
second: Vec<Vec<Vec<f64>>>,
beta: Vec<Vec<f64>>,
beta_deriv: Vec<Vec<Vec<f64>>>,
beta_l_deriv: Vec<Vec<Vec<f64>>>,
}
/// Dot product of two slices. When the lengths differ, only the common prefix
/// contributes (the longer tail is ignored).
fn sae_dot(a: &[f64], b: &[f64]) -> f64 {
    std::iter::zip(a, b).map(|(lhs, rhs)| lhs * rhs).sum()
}
/// Limit of 512 points for a shape band.
/// NOTE(review): the code that enforces this limit is outside this chunk; presumably it
/// caps the rows of the `band_*` arrays in `SaeAtomShapeUncertainty` — confirm.
pub const SHAPE_BAND_MAX_POINTS: usize = 512;
/// Limit of `1 << 24` (16,777,216) entries for a decoder covariance payload.
/// NOTE(review): presumably covariances larger than this are omitted
/// (`decoder_covariance: None`) — verify against the producer.
pub const SAE_DECODER_COV_PAYLOAD_MAX_ENTRIES: usize = 1 << 24;
/// Shape-uncertainty summary for a single atom.
/// NOTE(review): array shapes are not established in this chunk; presumably the three
/// `band_*` arrays share their leading (point) dimension — confirm against the producer.
#[derive(Debug, Clone)]
pub struct SaeAtomShapeUncertainty {
/// Decoder covariance, when one was retained.
pub decoder_covariance: Option<Array2<f64>>,
/// Coordinates at which the band is evaluated.
pub band_coords: Array2<f64>,
/// Band mean values.
pub band_mean: Array2<f64>,
/// Band standard deviations.
pub band_sd: Array2<f64>,
}
/// Shape-uncertainty summary for a whole fit: one dispersion value plus a per-atom list.
/// NOTE(review): presumably `atoms` is indexed in the same order as the term's atoms —
/// confirm against the producer.
#[derive(Debug, Clone)]
pub struct SaeShapeUncertainty {
pub dispersion: f64,
pub atoms: Vec<SaeAtomShapeUncertainty>,
}
/// Compact per-row variable layout restricted to each row's active atoms.
///
/// A compact row vector stores one logit slot per active atom first (slot `j` belongs to
/// `active_atoms[row][j]`), followed by each active atom's coordinate block.
#[derive(Debug, Clone)]
pub struct SaeRowLayout {
/// For every row, the indices of the atoms that are active in that row.
pub active_atoms: Vec<Vec<usize>>,
/// For every row and every active atom (same order as `active_atoms`), the position in
/// the compact row vector where that atom's coordinates begin.
pub coord_starts: Vec<Vec<usize>>,
/// For every atom, the position of its coordinate block in the full (uncompacted) row
/// vector written by `expand_row`.
pub coord_offsets_full: Vec<usize>,
/// For every atom, the number of coordinate entries it owns.
pub coord_dims: Vec<usize>,
}
impl SaeRowLayout {
    /// Builds a layout in which atom `k` is active in `row` whenever
    /// `jumprelu_in_optimization_band(logits[[row, k]], threshold, temperature)` holds.
    /// Active atoms are listed in ascending index order.
    fn from_jumprelu(
        n: usize,
        k_atoms: usize,
        threshold: f64,
        temperature: f64,
        logits: &Array2<f64>,
        coord_dims: Vec<usize>,
        coord_offsets_full: Vec<usize>,
    ) -> Self {
        let per_row: Vec<Vec<usize>> = (0..n)
            .map(|row| {
                let row_logits = logits.row(row);
                (0..k_atoms)
                    .filter(|&k| {
                        jumprelu_in_optimization_band(row_logits[k], threshold, temperature)
                    })
                    .collect()
            })
            .collect();
        Self::from_active_atoms(per_row, coord_dims, coord_offsets_full)
    }

    /// Builds a layout from dense assignment weights.
    ///
    /// For each row, the atoms are ranked by descending `|weight|`. The top
    /// `max(k_active_cap, 1)` are considered, and those whose magnitude exceeds `cutoff`
    /// are kept. A row that would end up empty keeps its single top-ranked atom (if it
    /// has any atoms at all). Active atoms are stored in ascending index order.
    ///
    /// NaN weights rank last and never pass the cutoff. The ranking uses
    /// `f64::total_cmp`, so the comparator is a genuine total order and the sort cannot
    /// misorder or panic on NaN input. For NaN-free rows the ranking is unchanged.
    fn from_dense_weights(
        assignments: &[Array1<f64>],
        k_active_cap: usize,
        cutoff: f64,
        coord_dims: Vec<usize>,
        coord_offsets_full: Vec<usize>,
    ) -> Self {
        let cap = k_active_cap.max(1);
        let per_row: Vec<Vec<usize>> = assignments
            .iter()
            .map(|a| {
                // Sort key: |weight|, with NaN demoted below every real magnitude.
                let magnitude = |i: usize| {
                    let m = a[i].abs();
                    if m.is_nan() { f64::NEG_INFINITY } else { m }
                };
                let mut idx: Vec<usize> = (0..a.len()).collect();
                // Stable sort, descending magnitude; ties keep ascending index order.
                idx.sort_by(|&i, &j| magnitude(j).total_cmp(&magnitude(i)));
                let mut active: Vec<usize> = idx
                    .iter()
                    .copied()
                    .take(cap)
                    .filter(|&k_idx| a[k_idx].abs() > cutoff)
                    .collect();
                if active.is_empty() {
                    if let Some(&top) = idx.first() {
                        active.push(top);
                    }
                }
                active.sort_unstable();
                active
            })
            .collect();
        Self::from_active_atoms(per_row, coord_dims, coord_offsets_full)
    }

    /// Derives the compact coordinate start offsets for the given active sets.
    ///
    /// In each row the first `active.len()` compact slots are the logit slots, so the
    /// first coordinate block starts at `active.len()` and subsequent blocks follow
    /// contiguously in active-atom order.
    ///
    /// # Panics
    /// Panics if an active atom index is out of range for `coord_dims`.
    fn from_active_atoms(
        active_atoms: Vec<Vec<usize>>,
        coord_dims: Vec<usize>,
        coord_offsets_full: Vec<usize>,
    ) -> Self {
        let coord_starts: Vec<Vec<usize>> = active_atoms
            .iter()
            .map(|active| {
                let mut cursor = active.len();
                active
                    .iter()
                    .map(|&k| {
                        let start = cursor;
                        cursor += coord_dims[k];
                        start
                    })
                    .collect()
            })
            .collect();
        Self {
            active_atoms,
            coord_starts,
            coord_offsets_full,
            coord_dims,
        }
    }

    /// Length of the compact vector for `row`: one logit slot per active atom plus the
    /// coordinate dimensions of those atoms.
    pub fn row_q_active(&self, row: usize) -> usize {
        let active = &self.active_atoms[row];
        let coord_sum: usize = active.iter().map(|&k| self.coord_dims[k]).sum();
        active.len() + coord_sum
    }

    /// Scatters a compact row vector into the full layout.
    ///
    /// `out` is zeroed first. For each active atom `k`, `out[k]` receives its logit slot
    /// and `out[coord_offsets_full[k]..][..coord_dims[k]]` receives its coordinate block.
    ///
    /// # Panics
    /// Panics if `delta_t_row` or `out` is too short for the layout of `row`.
    pub fn expand_row(&self, row: usize, delta_t_row: &[f64], out: &mut [f64]) {
        out.fill(0.0);
        let active = &self.active_atoms[row];
        let starts = &self.coord_starts[row];
        for (j, &k) in active.iter().enumerate() {
            out[k] = delta_t_row[j];
            let d = self.coord_dims[k];
            let full_off = self.coord_offsets_full[k];
            let start = starts[j];
            out[full_off..full_off + d].copy_from_slice(&delta_t_row[start..start + d]);
        }
    }
}
/// Outcome of the curved-dictionary global-optimality check.
///
/// Both variants carry the `margin` (`budget − mu_hat`); it is strictly positive only
/// for `CertifiedGlobal`. Inputs that cannot be evaluated yield
/// `Uncertified { margin: -∞ }`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum GlobalOptimalityVerdict {
    CertifiedGlobal { margin: f64 },
    Uncertified { margin: f64 },
}

impl GlobalOptimalityVerdict {
    /// The margin carried by either variant.
    pub fn margin(&self) -> f64 {
        match self {
            Self::CertifiedGlobal { margin } | Self::Uncertified { margin } => *margin,
        }
    }

    /// `true` only for [`GlobalOptimalityVerdict::CertifiedGlobal`].
    pub fn is_certified(&self) -> bool {
        matches!(self, Self::CertifiedGlobal { .. })
    }
}

/// Multiplier on the (non-negative) curvature bound inside the curvature factor
/// `1 − C·κ_max`.
pub const SAE_CERT_CURVATURE_CONSTANT: f64 = 1.0;
/// Leading constant of the incoherence budget.
pub const SAE_CERT_INCOHERENCE_BUDGET: f64 = 0.125;

/// Decides whether the measured incoherence `mu_hat` fits inside the budget
///
/// `BUDGET · a² · (1 − 1/SNR) · (1 − C·max(κ_max, 0)) / k_atoms`,
///
/// with `a = max(activity_floor, 0)`. The verdict is `CertifiedGlobal` exactly when
/// `budget − mu_hat > 0`.
///
/// Returns `Uncertified { margin: -∞ }` when any float input is non-finite, when
/// `k_atoms == 0`, when `snr_proxy` is not strictly positive, or when either the
/// curvature factor or the SNR factor is non-positive.
pub fn curved_dictionary_global_optimality_verdict(
    mu_hat: f64,
    kappa_max: f64,
    activity_floor: f64,
    snr_proxy: f64,
    k_atoms: usize,
) -> GlobalOptimalityVerdict {
    let undecidable = GlobalOptimalityVerdict::Uncertified {
        margin: f64::NEG_INFINITY,
    };
    let inputs_finite = [mu_hat, kappa_max, activity_floor, snr_proxy]
        .iter()
        .all(|value| value.is_finite());
    // A negative SNR is meaningless, yet `1 − 1/snr` would exceed 1 for it and inflate
    // the budget; reject it up front rather than risk a spurious certificate.
    if !inputs_finite || k_atoms == 0 || snr_proxy <= 0.0 {
        return undecidable;
    }
    let curvature_factor = 1.0 - SAE_CERT_CURVATURE_CONSTANT * kappa_max.max(0.0);
    let snr_factor = 1.0 - 1.0 / snr_proxy;
    if curvature_factor <= 0.0 || snr_factor <= 0.0 {
        return undecidable;
    }
    let a = activity_floor.max(0.0);
    let budget =
        SAE_CERT_INCOHERENCE_BUDGET * a * a * snr_factor * curvature_factor / k_atoms as f64;
    let margin = budget - mu_hat;
    if margin > 0.0 {
        GlobalOptimalityVerdict::CertifiedGlobal { margin }
    } else {
        GlobalOptimalityVerdict::Uncertified { margin }
    }
}
/// Inputs and outcome of the dictionary incoherence / global-optimality certificate, as
/// assembled by `dictionary_incoherence_report_with_dispersion`.
#[derive(Clone, Debug)]
pub struct CertificateInputs {
/// Largest singular value of `frame_jᵀ · frame_k` over all pairs of distinct atoms.
pub mu_hat: f64,
/// Per-atom curvature bound from `atom_curvature_bound`.
pub per_atom_kappa_hat: Vec<f64>,
/// Per-atom mean of the assignment column (0 when there are no rows).
pub per_atom_mean_activity: Vec<f64>,
/// Per-atom maximum of the assignment column, floored at 0.
pub per_atom_peak_activity: Vec<f64>,
/// Minimum of `per_atom_mean_activity` (0 when that minimum is not finite).
pub mean_activity_floor: f64,
/// Minimum of `per_atom_peak_activity` (0 when that minimum is not finite); this is the
/// activity floor passed to the verdict.
pub peak_activity_floor: f64,
/// Mean squared fitted value divided by `dispersion`.
pub snr_proxy: f64,
/// Dispersion used for the report (validated finite and positive).
pub dispersion: f64,
/// Verdict computed from the quantities above.
pub global_optimality: GlobalOptimalityVerdict,
/// Human-readable summary of the verdict and its inputs.
pub note: String,
}
/// Bundle of post-fit diagnostics reports.
/// NOTE(review): the producers of these reports are outside this chunk; the field notes
/// below restate the types only.
#[derive(Clone, Debug)]
pub struct SaeManifoldFitDiagnostics {
/// Two-lens atom report.
pub atom_two_lens: crate::inference::atom_lens::AtomTwoLensReport,
/// Residual gauge report.
pub residual_gauge: crate::sae_identifiability::ResidualGaugeReport,
/// Incoherence certificate inputs, when available.
pub incoherence_report: Option<CertificateInputs>,
/// Atom inference reports (presumably one per atom — confirm).
pub atom_inference: Vec<crate::sae_identifiability::AtomInferenceReport>,
}
/// Trust diagnostics: a list of scores plus a list of detailed per-atom records.
/// NOTE(review): presumably `atom_trust[i]` equals `atoms[i].trust_score` and both are
/// indexed by atom — confirm against the producer.
#[derive(Clone, Debug)]
pub struct SaeTrustDiagnostics {
pub atom_trust: Vec<f64>,
pub atoms: Vec<SaeAtomTrustDiagnostics>,
}
/// Trust diagnostics for one atom.
/// NOTE(review): the formulas behind these values are outside this chunk. By name,
/// `sigma_min_tangent` / `sigma_max_tangent` look like extreme singular values of a
/// tangent map and `tangent_condition_score` a score derived from them — confirm against
/// the producer before interpreting ranges or units.
#[derive(Clone, Debug)]
pub struct SaeAtomTrustDiagnostics {
pub trust_score: f64,
pub sigma_min_tangent: f64,
pub sigma_max_tangent: f64,
pub tangent_condition_score: f64,
pub coverage: f64,
pub activation_frequency: f64,
pub untyped: bool,
pub active_token_count: usize,
}
/// Builds the incoherence certificate using the dispersion stored on `term`.
///
/// # Errors
/// Fails when `term.certificate_dispersion` is unset, or when
/// [`dictionary_incoherence_report_with_dispersion`] fails.
pub fn dictionary_incoherence_report(term: &SaeManifoldTerm) -> Result<CertificateInputs, String> {
    match term.certificate_dispersion {
        Some(dispersion) => dictionary_incoherence_report_with_dispersion(term, dispersion),
        None => Err(
            "dictionary_incoherence_report: fitted reconstruction dispersion is unavailable"
                .to_string(),
        ),
    }
}
/// Assembles every certificate input for `term` and evaluates the global-optimality
/// verdict with an explicitly supplied `dispersion`.
///
/// Collected quantities: pairwise frame incoherence, per-atom curvature bounds, per-atom
/// mean and peak assignment activity with their minima, and an SNR proxy (mean squared
/// fitted value over `dispersion`). The verdict uses the peak-activity floor.
///
/// # Errors
/// Fails when `dispersion` is not finite and strictly positive, or when the incoherence
/// or curvature computations fail.
pub fn dictionary_incoherence_report_with_dispersion(
    term: &SaeManifoldTerm,
    dispersion: f64,
) -> Result<CertificateInputs, String> {
    if !(dispersion.is_finite() && dispersion > 0.0) {
        return Err(format!(
            "dictionary_incoherence_report: dispersion must be finite and positive, got {dispersion}"
        ));
    }

    let mu_hat = dictionary_frame_incoherence(term)?;
    let per_atom_kappa_hat = (0..term.atoms.len())
        .map(|atom_idx| atom_curvature_bound(term, atom_idx))
        .collect::<Result<Vec<_>, _>>()?;

    // Column-wise activity statistics of the assignment matrix.
    let assignments = term.assignment.assignments();
    let n = assignments.nrows();
    let k_atoms = assignments.ncols();
    let (per_atom_mean_activity, per_atom_peak_activity): (Vec<f64>, Vec<f64>) = (0..k_atoms)
        .map(|atom_idx| {
            let (sum, peak) = (0..n).fold((0.0_f64, 0.0_f64), |(sum, peak), row| {
                let value = assignments[[row, atom_idx]];
                (sum + value, peak.max(value))
            });
            let mean = if n > 0 { sum / n as f64 } else { 0.0 };
            (mean, peak)
        })
        .unzip();

    // Minimum over atoms; an empty atom list leaves +∞, which is reported as 0.
    let finite_or_zero = |value: f64| if value.is_finite() { value } else { 0.0 };
    let mean_activity_floor = finite_or_zero(
        per_atom_mean_activity
            .iter()
            .copied()
            .fold(f64::INFINITY, f64::min),
    );
    let peak_activity_floor = finite_or_zero(
        per_atom_peak_activity
            .iter()
            .copied()
            .fold(f64::INFINITY, f64::min),
    );

    let fitted = term.fitted();
    let signal_power = if fitted.is_empty() {
        0.0
    } else {
        fitted.iter().map(|v| v * v).sum::<f64>() / fitted.len() as f64
    };
    let snr_proxy = signal_power / dispersion;
    let kappa_max = per_atom_kappa_hat.iter().copied().fold(0.0_f64, f64::max);

    let global_optimality = curved_dictionary_global_optimality_verdict(
        mu_hat,
        kappa_max,
        peak_activity_floor,
        snr_proxy,
        k_atoms,
    );
    let margin = global_optimality.margin();
    let note = if global_optimality.is_certified() {
        format!(
            "global optimality CERTIFIED up to the residual gauge group \
             (margin {margin:.3e}); μ̂={mu_hat:.3e}, κ̂_max={kappa_max:.3e}, \
             a_floor={peak_activity_floor:.3e}, SNR={snr_proxy:.3e}"
        )
    } else {
        format!(
            "global optimality UNCERTIFIED (margin {margin:.3e}; cannot decide — \
             multistart/homotopy genuinely needed); μ̂={mu_hat:.3e}, \
             κ̂_max={kappa_max:.3e}, a_floor={peak_activity_floor:.3e}, \
             SNR={snr_proxy:.3e}"
        )
    };

    Ok(CertificateInputs {
        mu_hat,
        per_atom_kappa_hat,
        per_atom_mean_activity,
        per_atom_peak_activity,
        mean_activity_floor,
        peak_activity_floor,
        snr_proxy,
        dispersion,
        global_optimality,
        note,
    })
}
/// Largest singular value of `frame_jᵀ · frame_k` over all pairs of distinct atoms.
///
/// Pairs in which either frame has no columns are skipped; the result is `0.0` when no
/// pair contributes.
///
/// # Errors
/// Fails when an output frame cannot be built or an SVD fails.
fn dictionary_frame_incoherence(term: &SaeManifoldTerm) -> Result<f64, String> {
    let frames = (0..term.k_atoms())
        .map(|atom_idx| certificate_output_frame(term, atom_idx))
        .collect::<Result<Vec<_>, _>>()?;

    let mut mu = 0.0_f64;
    for (j, left) in frames.iter().enumerate() {
        if left.ncols() == 0 {
            // Every pair involving an empty left frame is skipped.
            continue;
        }
        for (k, right) in frames.iter().enumerate().skip(j + 1) {
            if right.ncols() == 0 {
                continue;
            }
            let overlap = fast_atb(left, right);
            let (_u, s, _vt) = overlap.svd(false, false).map_err(|e| {
                format!("dictionary_frame_incoherence: SVD failed for atom pair ({j}, {k}): {e}")
            })?;
            let pair = s.iter().copied().fold(0.0_f64, f64::max);
            mu = mu.max(pair);
        }
    }
    Ok(mu)
}
/// Output-space frame for one atom, used by the incoherence certificate.
///
/// When the atom carries a `decoder_frame`, the term's `frame_output_matrix` is returned
/// directly. Otherwise the frame is read from the right singular vectors of the decoder
/// coefficients: the leading rows of `Vᵀ` whose singular value exceeds
/// `SAE_FRAME_RANK_CUTOFF · σ_max` become the columns of a `p × rank` matrix. A decoder
/// with no positive singular value yields a `p × 0` frame.
///
/// # Errors
/// Fails when the SVD fails or returns no right factor.
fn certificate_output_frame(
    term: &SaeManifoldTerm,
    atom_idx: usize,
) -> Result<Array2<f64>, String> {
    let atom = &term.atoms[atom_idx];
    if atom.decoder_frame.is_some() {
        return Ok(term.frame_output_matrix(atom_idx));
    }

    let p = atom.output_dim();
    let (_u, s, vt_opt) = atom
        .decoder_coefficients
        .svd(false, true)
        .map_err(|e| format!("certificate_output_frame: SVD failed for atom {atom_idx}: {e}"))?;
    let max_sv = s.iter().copied().fold(0.0_f64, f64::max);
    if !(max_sv > 0.0) {
        return Ok(Array2::<f64>::zeros((p, 0)));
    }

    let tol = SAE_FRAME_RANK_CUTOFF * max_sv;
    let numerical_rank = s.iter().filter(|&&value| value > tol).count();
    let vt = vt_opt.ok_or_else(|| {
        format!("certificate_output_frame: SVD returned no right factor for atom {atom_idx}")
    })?;
    let rank = numerical_rank.min(vt.nrows());
    // Transpose the leading `rank` rows of Vᵀ into columns.
    Ok(Array2::from_shape_fn((p, rank), |(row, col)| vt[[col, row]]))
}
/// Curvature bound for one atom, evaluated at the atom's current latent coordinates with
/// its own decoder coefficients.
///
/// # Errors
/// Fails when the atom has no basis evaluator or no analytic second jet, when the second
/// jet evaluation fails, or when [`atom_curvature_bound_with_decoder`] fails.
fn atom_curvature_bound(term: &SaeManifoldTerm, atom_idx: usize) -> Result<f64, String> {
    let atom = &term.atoms[atom_idx];
    let coords = term.assignment.coords[atom_idx].as_matrix();
    let jet = atom
        .basis_evaluator
        .as_ref()
        .and_then(|evaluator| evaluator.second_jet_dyn(coords.view()));
    let second = match jet {
        None => {
            return Err(format!(
                "atom_curvature_bound: atom {atom_idx} has no analytic second jet; cannot compute kappa_hat"
            ));
        }
        Some(Err(e)) => {
            return Err(format!(
                "atom_curvature_bound: atom {atom_idx} second jet failed: {e}"
            ));
        }
        Some(Ok(second)) => second,
    };
    atom_curvature_bound_with_decoder(
        atom,
        atom_idx,
        second.view(),
        atom.decoder_coefficients.view(),
    )
}
/// Maximum, over observations and axis pairs, of the normal component of the decoded
/// second derivative divided by the squared smallest retained tangent singular value.
///
/// For each observation the decoded tangent (`p × d`) is rebuilt from
/// `atom.basis_jacobian` and `decoder`, and its frame and scale come from
/// [`tangent_frame_rank`]. Each decoded second-derivative vector is then projected off
/// that frame. If the tangent scale is not positive while a non-zero normal component
/// exists, the bound is `+∞`.
///
/// # Errors
/// Fails on a shape mismatch of `second` (`(n, m, d, d)`) or `decoder` (`(m, p)`), or
/// when the tangent SVD fails.
fn atom_curvature_bound_with_decoder(
    atom: &SaeManifoldAtom,
    atom_idx: usize,
    second: ArrayView4<'_, f64>,
    decoder: ArrayView2<'_, f64>,
) -> Result<f64, String> {
    let n = atom.n_obs();
    let m = atom.basis_size();
    let d = atom.latent_dim;
    let p = atom.output_dim();
    if second.dim() != (n, m, d, d) {
        return Err(format!(
            "atom_curvature_bound: atom {atom_idx} second jet shape {:?} must be ({n}, {m}, {d}, {d})",
            second.dim()
        ));
    }
    if decoder.dim() != (m, p) {
        return Err(format!(
            "atom_curvature_bound: atom {atom_idx} decoder shape {:?} must be ({m}, {p})",
            decoder.dim()
        ));
    }

    let mut max_kappa = 0.0_f64;
    let mut tangent = Array2::<f64>::zeros((p, d));
    let mut second_vec = vec![0.0_f64; p];
    for row in 0..n {
        // Decoded first derivative: tangent[:, axis] = Σ_basis dΦ/dt_axis · decoder[basis, :].
        tangent.fill(0.0);
        for basis_col in 0..m {
            for axis in 0..d {
                let dphi = atom.basis_jacobian[[row, basis_col, axis]];
                if dphi != 0.0 {
                    for out in 0..p {
                        tangent[[out, axis]] += dphi * decoder[[basis_col, out]];
                    }
                }
            }
        }
        let (tangent_scale, q) = tangent_frame_rank(tangent.view())?;

        for (axis_a, axis_b) in (0..d).flat_map(|a| (0..d).map(move |b| (a, b))) {
            // Decoded second derivative for this axis pair.
            second_vec.fill(0.0);
            for basis_col in 0..m {
                let h = second[[row, basis_col, axis_a, axis_b]];
                if h != 0.0 {
                    for (out, slot) in second_vec.iter_mut().enumerate() {
                        *slot += h * decoder[[basis_col, out]];
                    }
                }
            }
            let perp_norm = projected_perp_norm(&second_vec, q.view());
            if tangent_scale > 0.0 {
                max_kappa = max_kappa.max(perp_norm / tangent_scale);
            } else if perp_norm > 0.0 {
                return Ok(f64::INFINITY);
            }
        }
    }
    Ok(max_kappa)
}
/// Returns `(σ_min², Q)` for a `p × d` tangent matrix.
///
/// `Q` holds the leading left singular vectors whose singular value exceeds
/// `SAE_FRAME_RANK_CUTOFF · σ_max`, and `σ_min` is the smallest such singular value.
/// An empty matrix, or one without a positive singular value, yields `(0.0, p × 0)`.
///
/// # Errors
/// Fails when the SVD fails or returns no `U` factor.
fn tangent_frame_rank(tangent: ArrayView2<'_, f64>) -> Result<(f64, Array2<f64>), String> {
    let p = tangent.nrows();
    let d = tangent.ncols();
    if p == 0 || d == 0 {
        return Ok((0.0, Array2::<f64>::zeros((p, 0))));
    }

    let (u_opt, s, _vt) = tangent
        .to_owned()
        .svd(true, false)
        .map_err(|e| format!("tangent_frame_rank: SVD failed: {e}"))?;
    let max_sv = s.iter().copied().fold(0.0_f64, f64::max);
    if !(max_sv > 0.0) {
        return Ok((0.0, Array2::<f64>::zeros((p, 0))));
    }

    let tol = SAE_FRAME_RANK_CUTOFF * max_sv;
    let retained: Vec<f64> = s.iter().copied().filter(|&value| value > tol).collect();
    let min_positive = retained.iter().copied().fold(f64::INFINITY, f64::min);
    let u = u_opt.ok_or_else(|| "tangent_frame_rank: SVD returned no U".to_string())?;
    let rank = retained.len().min(u.ncols());
    let q = Array2::from_shape_fn((p, rank), |(row, col)| u[[row, col]]);
    Ok((min_positive * min_positive, q))
}
/// Euclidean norm of `vector` after subtracting its components along every column of
/// `tangent_frame`.
///
/// Each coefficient is computed against the original `vector` rather than the running
/// residual, so the result is an orthogonal projection only if the columns are
/// orthonormal.
fn projected_perp_norm(vector: &[f64], tangent_frame: ArrayView2<'_, f64>) -> f64 {
    let rows = tangent_frame.nrows();
    let mut residual = vector.to_vec();
    for axis in 0..tangent_frame.ncols() {
        let coeff = (0..rows).fold(0.0_f64, |acc, out| {
            acc + tangent_frame[[out, axis]] * vector[out]
        });
        if coeff != 0.0 {
            for out in 0..rows {
                residual[out] -= coeff * tangent_frame[[out, axis]];
            }
        }
    }
    residual.iter().map(|v| v * v).sum::<f64>().sqrt()
}
/// State of an SAE manifold term: the atoms, their assignment, and cached fit artefacts.
///
/// `Clone` is implemented by hand; it copies every field except `border_hbb_workspace`,
/// which the clone starts with as an empty `0 × 0` array.
/// NOTE(review): most private fields are written and read outside this chunk; the notes
/// below state only what is visible here and hedge the rest.
#[derive(Debug)]
pub struct SaeManifoldTerm {
/// The dictionary atoms.
pub atoms: Vec<SaeManifoldAtom>,
/// Assignment state; provides the assignment matrix and the per-atom coordinates used
/// by the certificate code.
pub assignment: SaeAssignment,
/// Optional temperature schedule.
temperature_schedule: Option<GumbelTemperatureSchedule>,
/// Most recent compact row layout, if any.
last_row_layout: Option<SaeRowLayout>,
/// Optional row metric.
row_metric: Option<crate::inference::row_metric::RowMetric>,
/// Recorded collapse events.
collapse_events: Vec<CollapseEvent>,
/// Optional per-row loss weights.
row_loss_weights: Option<Vec<f64>>,
/// Presumably whether decoder frames were active in the last evaluation — confirm.
last_frames_active: bool,
/// Scratch array; never copied by `Clone`.
border_hbb_workspace: Array2<f64>,
/// Fitted dispersion consumed by `dictionary_incoherence_report`; `None` makes that
/// report fail.
certificate_dispersion: Option<f64>,
/// Optional curvature-walk report.
curvature_walk_report: Option<CurvatureWalkReport>,
/// Optional expected count of evidence gauge-deflated directions.
expected_evidence_gauge_deflated_directions: Option<usize>,
/// Optional hybrid-split report.
hybrid_split_report: Option<crate::terms::sae::hybrid_split::SaeHybridSplitReport>,
/// Optional inner-fit results (presumably one slot per atom — confirm).
atom_inner_fits: Option<Vec<Option<crate::sae_identifiability::AtomInnerFit>>>,
/// Optional projector matrices (presumably one slot per atom — confirm).
decoder_data_null_projectors: Vec<Option<Array2<f64>>>,
}
impl Clone for SaeManifoldTerm {
    /// Clones every field except `border_hbb_workspace`, which is scratch space and is
    /// replaced by an empty `0 × 0` array in the copy.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to the struct forces a decision here.
        let Self {
            atoms,
            assignment,
            temperature_schedule,
            last_row_layout,
            row_metric,
            collapse_events,
            row_loss_weights,
            last_frames_active,
            border_hbb_workspace: _,
            certificate_dispersion,
            curvature_walk_report,
            expected_evidence_gauge_deflated_directions,
            hybrid_split_report,
            atom_inner_fits,
            decoder_data_null_projectors,
        } = self;
        Self {
            atoms: atoms.clone(),
            assignment: assignment.clone(),
            temperature_schedule: temperature_schedule.clone(),
            last_row_layout: last_row_layout.clone(),
            row_metric: row_metric.clone(),
            collapse_events: collapse_events.clone(),
            row_loss_weights: row_loss_weights.clone(),
            last_frames_active: *last_frames_active,
            border_hbb_workspace: Array2::<f64>::zeros((0, 0)),
            certificate_dispersion: *certificate_dispersion,
            curvature_walk_report: curvature_walk_report.clone(),
            expected_evidence_gauge_deflated_directions:
                *expected_evidence_gauge_deflated_directions,
            hybrid_split_report: hybrid_split_report.clone(),
            atom_inner_fits: atom_inner_fits.clone(),
            decoder_data_null_projectors: decoder_data_null_projectors.clone(),
        }
    }
}
/// Snapshot of mutable term state.
/// NOTE(review): the code that captures and restores this snapshot is outside this
/// chunk. Presumably `atoms` holds one tuple of per-atom arrays per atom and
/// `logits`/`coords` mirror the assignment state — confirm which array each tuple slot
/// corresponds to before relying on the order.
#[derive(Debug)]
struct SaeManifoldMutableState {
atoms: Vec<(Array2<f64>, Array3<f64>, Array2<f64>, Array2<f64>)>,
logits: Array2<f64>,
coords: Vec<LatentCoordValues>,
last_row_layout: Option<SaeRowLayout>,
}