use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::linalg::faer_ndarray::FaerEigh;
use crate::terms::sae_candidate_index::{
AtomFrameSketch, SaeCandidateIndex, auto_candidate_budget,
};
use crate::terms::sae_manifold::{
AffineCoordinateEvaluator, CylinderHarmonicEvaluator, DuchonCoordinateEvaluator,
EuclideanPatchEvaluator, PeriodicHarmonicEvaluator, SaeBasisEvaluator, SaeManifoldAtom,
SphereChartEvaluator, TorusHarmonicEvaluator,
};
use faer::Side;
/// Largest value of `RowCertificate::h` for which `RowCertificate::certified` reports `true`.
pub const KANTOROVICH_THRESHOLD: f64 = 0.5;
/// Minimum number of target rows at which the batch encoders take the rayon-parallel path.
pub(crate) const ENCODE_BATCH_PARALLEL_ROW_MIN: usize = 256;
/// Region of latent coordinates described by a center point and a radius, with
/// optional radial bounds that the Duchon jet bounds consult.
#[derive(Debug, Clone)]
pub struct ChartRegion {
/// Center of the region in latent coordinates.
pub center: Array1<f64>,
/// Radius of the region; `assert_valid` requires it to be finite and non-negative.
pub radius: f64,
/// Optional lower radial bound; the Duchon bounds fall back to `radius` when it is `None`.
pub exclusion_r_min: Option<f64>,
/// Optional upper radial bound; the Duchon bounds fall back to `radius` when it is `None`.
pub radial_r_max: Option<f64>,
}
impl ChartRegion {
pub fn new(center: Array1<f64>, radius: f64) -> Self {
Self {
center,
radius,
exclusion_r_min: None,
radial_r_max: None,
}
}
pub fn with_radial_bounds(mut self, r_min: f64, r_max: f64) -> Self {
self.exclusion_r_min = Some(r_min);
self.radial_r_max = Some(r_max);
self
}
pub(crate) fn assert_valid(&self) {
assert!(
self.radius.is_finite()
&& self.radius >= 0.0
&& self.center.iter().all(|c| c.is_finite()),
"ChartRegion must have a finite center and a finite non-negative radius"
);
}
}
/// Bounds that a basis family reports for a given `ChartRegion`.
///
/// `JetSups::from_family` collects the four values; the implementations in this
/// file use derivative orders 0 through 3 for the four methods respectively.
pub trait BasisHessianLipschitz {
/// Order-0 bound (basis values) over `chart`.
fn value_sup(&self, chart: &ChartRegion) -> f64;
/// Order-1 bound (first derivatives) over `chart`.
fn jacobian_sup(&self, chart: &ChartRegion) -> f64;
/// Order-2 bound (second derivatives) over `chart`.
fn hessian_sup(&self, chart: &ChartRegion) -> f64;
/// Order-3 bound (third derivatives) over `chart`.
fn third_sup(&self, chart: &ChartRegion) -> f64;
}
/// Returns `(TAU * H)^order`, where `H = (num_basis - 1) / 2` (integer division,
/// saturating at zero) is the highest harmonic index implied by `num_basis`.
pub(crate) fn harmonic_jet_sup(num_basis: usize, order: u32) -> f64 {
    let highest_harmonic = match num_basis {
        0 => 0,
        count => (count - 1) / 2,
    };
    let angular_frequency = std::f64::consts::TAU * highest_harmonic as f64;
    angular_frequency.powi(order as i32)
}
/// Jet bounds for the periodic harmonic basis: the chart is only validated, the
/// bounds themselves depend on `num_basis` alone.
impl BasisHessianLipschitz for PeriodicHarmonicEvaluator {
    /// Order-0 bound: the constant `1.0`.
    fn value_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        1.0_f64
    }

    /// Order-1 bound from `harmonic_jet_sup`.
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
        const ORDER: u32 = 1;
        chart.assert_valid();
        harmonic_jet_sup(self.num_basis, ORDER)
    }

    /// Order-2 bound from `harmonic_jet_sup`.
    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
        const ORDER: u32 = 2;
        chart.assert_valid();
        harmonic_jet_sup(self.num_basis, ORDER)
    }

    /// Order-3 bound from `harmonic_jet_sup`.
    fn third_sup(&self, chart: &ChartRegion) -> f64 {
        const ORDER: u32 = 3;
        chart.assert_valid();
        harmonic_jet_sup(self.num_basis, ORDER)
    }
}
/// Jet bounds for the torus harmonic basis: the chart is only validated, the
/// bounds depend on `num_harmonics` and `latent_dim`.
impl BasisHessianLipschitz for TorusHarmonicEvaluator {
    /// Order-0 bound: the constant `1.0`.
    fn value_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        1.0_f64
    }

    /// Order-1 bound from `torus_jet_sup`.
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
        const ORDER: u32 = 1;
        chart.assert_valid();
        torus_jet_sup(self.num_harmonics, self.latent_dim, ORDER)
    }

    /// Order-2 bound from `torus_jet_sup`.
    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
        const ORDER: u32 = 2;
        chart.assert_valid();
        torus_jet_sup(self.num_harmonics, self.latent_dim, ORDER)
    }

    /// Order-3 bound from `torus_jet_sup`.
    fn third_sup(&self, chart: &ChartRegion) -> f64 {
        const ORDER: u32 = 3;
        chart.assert_valid();
        torus_jet_sup(self.num_harmonics, self.latent_dim, ORDER)
    }
}
/// Returns `(TAU * num_harmonics)^order * latent_dim^order`.
pub(crate) fn torus_jet_sup(num_harmonics: usize, latent_dim: usize, order: u32) -> f64 {
    let exponent = order as i32;
    let frequency_term = (std::f64::consts::TAU * num_harmonics as f64).powi(exponent);
    let dimension_term = (latent_dim as f64).powi(exponent);
    frequency_term * dimension_term
}
/// Jet bounds for the sphere chart basis: fixed constants, returned after
/// validating the chart.
impl BasisHessianLipschitz for SphereChartEvaluator {
    /// Order-0 bound: `1.0`.
    fn value_sup(&self, chart: &ChartRegion) -> f64 {
        const BOUND: f64 = 1.0;
        chart.assert_valid();
        BOUND
    }

    /// Order-1 bound: `4.0`.
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
        const BOUND: f64 = 4.0;
        chart.assert_valid();
        BOUND
    }

    /// Order-2 bound: `16.0`.
    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
        const BOUND: f64 = 16.0;
        chart.assert_valid();
        BOUND
    }

    /// Order-3 bound: `64.0`.
    fn third_sup(&self, chart: &ChartRegion) -> f64 {
        const BOUND: f64 = 64.0;
        chart.assert_valid();
        BOUND
    }
}
/// Jet bounds for the affine coordinate basis.
impl BasisHessianLipschitz for AffineCoordinateEvaluator {
    /// Order-0 bound: `1 + ||center||_2 + radius`. The chart is not validated here.
    fn value_sup(&self, chart: &ChartRegion) -> f64 {
        let center = &chart.center;
        let euclidean_norm = center.dot(center).sqrt();
        1.0 + euclidean_norm + chart.radius
    }

    /// Order-1 bound: `1.0`.
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
        const BOUND: f64 = 1.0;
        chart.assert_valid();
        BOUND
    }

    /// Order-2 bound: `0.0`.
    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
        const BOUND: f64 = 0.0;
        chart.assert_valid();
        BOUND
    }

    /// Order-3 bound: `0.0`.
    fn third_sup(&self, chart: &ChartRegion) -> f64 {
        const BOUND: f64 = 0.0;
        chart.assert_valid();
        BOUND
    }
}
/// Jet bounds for the Euclidean patch basis, expressed through `patch_rho` and
/// `patch_jet_sup`.
impl BasisHessianLipschitz for EuclideanPatchEvaluator {
    /// Order-0 bound: `max(rho^max_degree, 1)` with `rho = patch_rho(chart)`.
    fn value_sup(&self, chart: &ChartRegion) -> f64 {
        let exponent = self.max_degree as i32;
        let top_power = patch_rho(chart).powi(exponent);
        f64::max(top_power, 1.0)
    }

    /// Order-1 bound from `patch_jet_sup`.
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
        const ORDER: u32 = 1;
        patch_jet_sup(self.latent_dim, self.max_degree, chart, ORDER)
    }

    /// Order-2 bound from `patch_jet_sup`.
    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
        const ORDER: u32 = 2;
        patch_jet_sup(self.latent_dim, self.max_degree, chart, ORDER)
    }

    /// Order-3 bound from `patch_jet_sup`.
    fn third_sup(&self, chart: &ChartRegion) -> f64 {
        const ORDER: u32 = 3;
        patch_jet_sup(self.latent_dim, self.max_degree, chart, ORDER)
    }
}
/// Jet bounds for the cylinder harmonic basis; every order delegates to
/// `cylinder_jet_sup`.
impl BasisHessianLipschitz for CylinderHarmonicEvaluator {
    /// Order-0 bound.
    fn value_sup(&self, chart: &ChartRegion) -> f64 {
        const ORDER: u32 = 0;
        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, ORDER)
    }

    /// Order-1 bound.
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
        const ORDER: u32 = 1;
        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, ORDER)
    }

    /// Order-2 bound.
    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
        const ORDER: u32 = 2;
        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, ORDER)
    }

    /// Order-3 bound.
    fn third_sup(&self, chart: &ChartRegion) -> f64 {
        const ORDER: u32 = 3;
        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, ORDER)
    }
}
/// Largest product of a circle factor and a line factor over every split
/// `order = k0 + k1` of the derivative order between the two axes.
///
/// The circle factor is `(TAU * circle_harmonics)^k0` (1 for `k0 == 0`). The line
/// factor is `max(rho^line_degree, 1)` for `k1 == 0`, and otherwise
/// `line_degree^k1 * rho^(line_degree - k1)` with the exponent saturating at zero,
/// where `rho = patch_rho(chart)`.
pub(crate) fn cylinder_jet_sup(
    circle_harmonics: usize,
    line_degree: usize,
    chart: &ChartRegion,
    order: u32,
) -> f64 {
    let omega = std::f64::consts::TAU * circle_harmonics as f64;
    let degree = line_degree as f64;
    let rho = patch_rho(chart);
    (0..=order)
        .map(|circle_order| {
            let line_order = order - circle_order;
            let circle_factor = match circle_order {
                0 => 1.0,
                k => omega.powi(k as i32),
            };
            let line_factor = match line_order {
                0 => rho.powi(line_degree as i32).max(1.0),
                k => {
                    let residual = line_degree.saturating_sub(k as usize) as i32;
                    degree.powi(k as i32) * rho.powi(residual)
                }
            };
            circle_factor * line_factor
        })
        .fold(0.0_f64, f64::max)
}
/// Returns the largest absolute center coordinate (0 for an empty center) plus
/// the chart radius.
pub(crate) fn patch_rho(chart: &ChartRegion) -> f64 {
    let largest_abs_coord = chart
        .center
        .iter()
        .map(|coord| coord.abs())
        .fold(0.0_f64, f64::max);
    largest_abs_coord + chart.radius
}
/// Returns `latent_dim^order * max_degree^order * rho^(max_degree - order)`, with
/// the last exponent saturating at zero and `rho = patch_rho(chart)`.
pub(crate) fn patch_jet_sup(
    latent_dim: usize,
    max_degree: usize,
    chart: &ChartRegion,
    order: u32,
) -> f64 {
    let exponent = order as i32;
    let dim_factor = (latent_dim as f64).powi(exponent);
    let degree_factor = (max_degree as f64).powi(exponent);
    let rho = patch_rho(chart);
    let leftover_degree = max_degree.saturating_sub(order as usize) as i32;
    dim_factor * degree_factor * rho.powi(leftover_degree)
}
/// Jet bounds for the Duchon coordinate basis: each order is the larger of a
/// radial-kernel term and the polynomial term from `duchon_poly_jet_sup`.
///
/// `r_max` is `chart.radial_r_max` (falling back to `chart.radius`); `r_min` is
/// `chart.exclusion_r_min` (same fallback), floored at `f64::MIN_POSITIVE` so the
/// divisions below are never by zero.
impl BasisHessianLipschitz for DuchonCoordinateEvaluator {
    /// Order 0: `max(r_max^3, poly)`.
    fn value_sup(&self, chart: &ChartRegion) -> f64 {
        let r_max = match chart.radial_r_max {
            Some(bound) => bound,
            None => chart.radius,
        };
        let poly_bound = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 0);
        f64::max(r_max.powi(3), poly_bound)
    }

    /// Order 1: `max(3 r_max^2, poly)`.
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
        let r_max = match chart.radial_r_max {
            Some(bound) => bound,
            None => chart.radius,
        };
        let kernel_bound = 3.0 * r_max * r_max;
        let poly_bound = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 1);
        f64::max(kernel_bound, poly_bound)
    }

    /// Order 2: `max(6 r_max + 3 r_max^2 / r_min, poly)`.
    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
        let r_max = match chart.radial_r_max {
            Some(bound) => bound,
            None => chart.radius,
        };
        let r_min = f64::max(
            chart.exclusion_r_min.unwrap_or(chart.radius),
            f64::MIN_POSITIVE,
        );
        let kernel_bound = 6.0 * r_max + 3.0 * r_max * r_max / r_min;
        let poly_bound = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 2);
        f64::max(kernel_bound, poly_bound)
    }

    /// Order 3: `max(6 + 18 r_max / r_min + 9 r_max^2 / r_min^2, poly)`.
    fn third_sup(&self, chart: &ChartRegion) -> f64 {
        let r_max = match chart.radial_r_max {
            Some(bound) => bound,
            None => chart.radius,
        };
        let r_min = f64::max(
            chart.exclusion_r_min.unwrap_or(chart.radius),
            f64::MIN_POSITIVE,
        );
        let kernel_bound = 6.0 + 18.0 * r_max / r_min + 9.0 * r_max * r_max / (r_min * r_min);
        let poly_bound = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 3);
        f64::max(kernel_bound, poly_bound)
    }
}
/// Private helper trait exposing a nullspace order as a plain polynomial degree.
trait DuchonOrderDegree {
/// Polynomial degree that corresponds to the implementor's nullspace order.
fn order_degree(&self) -> usize;
}
impl DuchonOrderDegree for DuchonCoordinateEvaluator {
fn order_degree(&self) -> usize {
match self.order {
crate::basis::DuchonNullspaceOrder::Zero => 0,
crate::basis::DuchonNullspaceOrder::Linear => 1,
crate::basis::DuchonNullspaceOrder::Degree(d) => d,
}
}
}
/// Polynomial-part jet bound for the Duchon basis.
///
/// A degree-0 polynomial part gives `1.0` at order 0 and `0.0` at every higher
/// order; any other degree delegates to `patch_jet_sup`.
pub(crate) fn duchon_poly_jet_sup(
    latent_dim: usize,
    order_degree: usize,
    chart: &ChartRegion,
    order: u32,
) -> f64 {
    match (order_degree, order) {
        (0, 0) => 1.0,
        (0, _) => 0.0,
        _ => patch_jet_sup(latent_dim, order_degree, chart, order),
    }
}
/// Sums the Euclidean norms of the rows of `decoder`.
pub(crate) fn decoder_row_norm_sum(decoder: ArrayView2<'_, f64>) -> f64 {
    // Explicit fold from +0.0 keeps the accumulation order of a plain loop.
    decoder
        .rows()
        .into_iter()
        .fold(0.0, |total, row| total + row.dot(&row).sqrt())
}
/// Order-0 through order-3 bounds for a decoded reconstruction, as produced by
/// `reconstruction_jet_sups` and consumed by `hessian_lipschitz_constant`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ReconstructionJetSups {
/// Order-0 bound (reconstruction value).
pub(crate) value: f64,
/// Order-1 bound.
pub(crate) jacobian: f64,
/// Order-2 bound.
pub(crate) hessian: f64,
/// Order-3 bound.
pub(crate) third: f64,
}
/// Returns `sqrt((trace + disc) / 2)` for the 2x2 Gram matrix of the two decoder
/// rows, i.e. the square root of the larger root of its characteristic polynomial.
pub(crate) fn pair_trig_decoder_sup(
    sin_row: ArrayView1<'_, f64>,
    cos_row: ArrayView1<'_, f64>,
) -> f64 {
    let sin_sq = sin_row.dot(&sin_row);
    let cos_sq = cos_row.dot(&cos_row);
    let cross = sin_row.dot(&cos_row);
    let gap = sin_sq - cos_sq;
    let discriminant = (gap * gap + 4.0 * cross * cross).sqrt();
    let larger_root = 0.5 * ((sin_sq + cos_sq) + discriminant);
    larger_root.sqrt()
}
/// Accumulates reconstruction jet bounds for a periodic decoder.
///
/// Row 0 contributes its norm to the value bound only. Rows `2h - 1` and `2h` are
/// treated as the pair for harmonic `h` and contribute
/// `pair_trig_decoder_sup * (TAU * h)^k` to the order-`k` bound. Any rows left
/// after the last complete pair contribute their norm at the top frequency
/// (at least harmonic 1).
pub(crate) fn periodic_reconstruction_jet_sups(
    decoder: ArrayView2<'_, f64>,
) -> ReconstructionJetSups {
    let n_rows = decoder.nrows();
    let row_norm = |idx: usize| {
        let row = decoder.row(idx);
        row.dot(&row).sqrt()
    };

    let mut sups = ReconstructionJetSups {
        value: 0.0,
        jacobian: 0.0,
        hessian: 0.0,
        third: 0.0,
    };
    if n_rows > 0 {
        sups.value += row_norm(0);
    }

    let harmonics = n_rows.saturating_sub(1) / 2;
    for harmonic in 1..=harmonics {
        let amp = pair_trig_decoder_sup(decoder.row(2 * harmonic - 1), decoder.row(2 * harmonic));
        let omega = std::f64::consts::TAU * harmonic as f64;
        sups.value += amp;
        sups.jacobian += omega * amp;
        sups.hessian += omega.powi(2) * amp;
        sups.third += omega.powi(3) * amp;
    }

    // Rows without a partner are all weighted with the same top frequency.
    let tail_omega = std::f64::consts::TAU * harmonics.max(1) as f64;
    for idx in (1 + 2 * harmonics)..n_rows {
        let amp = row_norm(idx);
        sups.value += amp;
        sups.jacobian += tail_omega * amp;
        sups.hessian += tail_omega.powi(2) * amp;
        sups.third += tail_omega.powi(3) * amp;
    }
    sups
}
/// Converts basis jet bounds into reconstruction jet bounds for `atom`.
///
/// Periodic atoms use `periodic_reconstruction_jet_sups` on the decoder and
/// ignore `sups`; every other kind scales each entry of `sups` by the sum of
/// decoder row norms.
pub(crate) fn reconstruction_jet_sups(
    atom: &SaeManifoldAtom,
    sups: JetSups,
) -> ReconstructionJetSups {
    let decoder = atom.decoder_coefficients.view();
    match atom.basis_kind {
        crate::terms::sae_manifold::SaeAtomBasisKind::Periodic => {
            periodic_reconstruction_jet_sups(decoder)
        }
        _ => {
            let scale = decoder_row_norm_sum(decoder);
            ReconstructionJetSups {
                value: scale * sups.value,
                jacobian: scale * sups.jacobian,
                hessian: scale * sups.hessian,
                third: scale * sups.third,
            }
        }
    }
}
/// Combines reconstruction jet bounds into a single Lipschitz-type constant:
///
/// `3 * (z J) * (z H) + (target_norm + z V) * (z T) + prior_lipschitz`,
///
/// where `z = |amplitude|` and `V, J, H, T` are the value, Jacobian, Hessian and
/// third-order entries of `recon_sups`.
pub(crate) fn hessian_lipschitz_constant(
    recon_sups: ReconstructionJetSups,
    amplitude: f64,
    target_norm: f64,
    prior_lipschitz: f64,
) -> f64 {
    let scale = amplitude.abs();
    let jac_bound = scale * recon_sups.jacobian;
    let hess_bound = scale * recon_sups.hessian;
    let third_bound = scale * recon_sups.third;
    let residual_bound = target_norm + scale * recon_sups.value;
    3.0 * jac_bound * hess_bound + residual_bound * third_bound + prior_lipschitz
}
/// One chart of an atom's encode atlas together with the quantities precomputed
/// at its center by `EncodeAtlas::build_atom_atlas`.
#[derive(Debug, Clone)]
pub struct CertifiedChart {
/// Region the chart covers.
pub region: ChartRegion,
/// Value returned by `hessian_lipschitz_constant` for this region.
pub lipschitz: f64,
/// Result of `center_beta` at the chart center; `f64::INFINITY` when it returned `None`.
pub beta_center: f64,
/// `min(0.5 / (beta_center * lipschitz), region.radius)`; `0.0` when `center_beta` failed.
pub certified_radius: f64,
/// Matrix from `center_amortized_jacobian`, when available.
pub amortized_jacobian: Option<Array2<f64>>,
/// Reconstruction from `center_amortized_jacobian`; zeros when it is unavailable.
pub recon_center: Array1<f64>,
}
/// All certified charts built for a single atom.
#[derive(Debug, Clone)]
pub struct AtomEncodeAtlas {
/// Index of the atom this atlas was built for.
pub atom_index: usize,
/// The atom's `latent_dim` at build time.
pub latent_dim: usize,
/// Sum of the Euclidean norms of the atom's decoder rows.
pub decoder_norm_sum: f64,
/// One entry per chart center, in the order produced by `chart_center_grid`.
pub charts: Vec<CertifiedChart>,
}
/// Output of the batch encoders: one coordinate row and one certification flag
/// per target row.
#[derive(Debug, Clone)]
pub struct EncodeResult {
/// Encoded latent coordinates, one row per target.
pub coords: Array2<f64>,
/// Per-row certification flag.
pub certified: Vec<bool>,
/// Number of `false` entries in `certified` (set by `from_rows`).
pub encode_uncertified_count: usize,
}
impl EncodeResult {
pub(crate) fn from_rows(coords: Array2<f64>, certified: Vec<bool>) -> Self {
let encode_uncertified_count = certified.iter().filter(|c| !**c).count();
Self {
coords,
certified,
encode_uncertified_count,
}
}
}
/// Quantities behind one row's certification decision, as filled in by
/// `row_certificate`. All of `beta`, `eta` and `h` are `f64::INFINITY` when no
/// certificate could be computed.
#[derive(Debug, Clone, Copy)]
pub struct RowCertificate {
/// Reciprocal of the smallest eigenvalue of the Hessian at the evaluation point.
pub beta: f64,
/// Euclidean norm of the Newton step at the evaluation point.
pub eta: f64,
/// Lipschitz constant supplied by the caller.
pub lipschitz: f64,
/// Product `beta * eta * lipschitz`, compared against `KANTOROVICH_THRESHOLD`.
pub h: f64,
}
impl RowCertificate {
    /// `true` when `h` is finite and does not exceed `KANTOROVICH_THRESHOLD`.
    pub fn certified(&self) -> bool {
        let within_threshold = self.h <= KANTOROVICH_THRESHOLD;
        within_threshold && self.h.is_finite()
    }
}
/// Line-axis polynomial degree that `family_jet_sups` assumes for cylinder atoms.
pub(crate) const SAE_CYLINDER_LINE_DEGREE: usize = 2;
pub(crate) fn family_jet_sups(
atom: &SaeManifoldAtom,
chart: &ChartRegion,
) -> Result<JetSups, String> {
use crate::terms::sae_manifold::SaeAtomBasisKind::*;
let m = atom.basis_size();
let d = atom.latent_dim;
let sups = match &atom.basis_kind {
Periodic => {
let ev = PeriodicHarmonicEvaluator::new(m)?;
JetSups::from_family(&ev, chart)
}
Torus => {
let axis_m = integer_root(m, d.max(1));
let num_harmonics = axis_m.saturating_sub(1) / 2;
let ev = TorusHarmonicEvaluator::new(d, num_harmonics.max(1))?;
JetSups::from_family(&ev, chart)
}
Sphere => {
let ev = SphereChartEvaluator;
JetSups::from_family(&ev, chart)
}
Cylinder => {
let ml = SAE_CYLINDER_LINE_DEGREE + 1;
if d != 2 || ml == 0 || m % ml != 0 {
return Err(format!(
"EncodeAtlas: Cylinder atom requires latent_dim == 2 and width divisible by {ml}; got dim={d}, m={m}"
));
}
let axis_mc = m / ml;
let h = axis_mc.saturating_sub(1) / 2;
let ev = CylinderHarmonicEvaluator::new(h.max(1), SAE_CYLINDER_LINE_DEGREE)?;
JetSups::from_family(&ev, chart)
}
EuclideanPatch | Poincare => {
let degree = euclidean_patch_degree(d, m);
let ev = EuclideanPatchEvaluator::new(d, degree)?;
JetSups::from_family(&ev, chart)
}
Duchon => {
let centers = duchon_centers_from_atom(atom);
let conservative_m = m.max(1);
let ev = DuchonCoordinateEvaluator::new(centers, conservative_m)?;
JetSups::from_family(&ev, chart)
}
Precomputed(name) => {
return Err(format!(
"EncodeAtlas: precomputed basis '{name}' has no closed-form jet sup; route to exact encode"
));
}
};
Ok(sups)
}
/// Returns the smallest degree whose `patch_column_count` reaches `m`, giving up
/// and returning `m` once the degree itself reaches `m`.
pub(crate) fn euclidean_patch_degree(latent_dim: usize, m: usize) -> usize {
    let mut degree = 0usize;
    loop {
        let enough_columns = patch_column_count(latent_dim, degree) >= m;
        if enough_columns || degree >= m {
            return degree;
        }
        degree += 1;
    }
}
/// Largest `a >= 1` such that `(a + 1)^k` exceeds `n`, found by counting up from 1.
///
/// For `n >= 1` this is the floor of the `k`-th root of `n`. Special cases:
/// `k == 0` yields 1, `k == 1` yields `n`, and `n == 0` with `k >= 2` yields 1
/// because the search starts at 1.
pub(crate) fn integer_root(n: usize, k: usize) -> usize {
    match k {
        0 => return 1,
        1 => return n,
        _ => {}
    }
    let limit = n as u128;
    // True when base^k > n; the product saturates so it can never wrap below `limit`.
    let power_exceeds = |base: usize| -> bool {
        let mut acc: u128 = 1;
        for _ in 0..k {
            acc = acc.saturating_mul(base as u128);
            if acc > limit {
                return true;
            }
        }
        false
    };
    let mut root = 1usize;
    while !power_exceeds(root + 1) {
        root += 1;
    }
    root
}
/// Returns the binomial coefficient `C(latent_dim + degree, degree)`.
///
/// The coefficient is built incrementally (`C(d + i, i) = C(d + i - 1, i - 1) *
/// (d + i) / i`), so intermediate values never exceed `degree` times the final
/// result. Accumulating numerator and denominator as separate products would
/// overflow `u128` for modest degrees even when the coefficient itself is tiny
/// (e.g. `latent_dim = 1`, `degree = 34`).
///
/// Saturates at `usize::MAX` when the result (or an intermediate product) does
/// not fit.
pub(crate) fn patch_column_count(latent_dim: usize, degree: usize) -> usize {
    let dim = latent_dim as u128;
    let mut count: u128 = 1;
    for i in 1..=degree as u128 {
        // count == C(dim + i - 1, i - 1), so count * (dim + i) == i * C(dim + i, i)
        // and the division below is exact.
        count = match count.checked_mul(dim + i) {
            Some(scaled) => scaled / i,
            None => return usize::MAX,
        };
    }
    if count > usize::MAX as u128 {
        usize::MAX
    } else {
        count as usize
    }
}
/// Returns a single all-zero center row with `max(atom.latent_dim, 1)` columns.
pub(crate) fn duchon_centers_from_atom(atom: &SaeManifoldAtom) -> Array2<f64> {
    let n_cols = atom.latent_dim.max(1);
    Array2::zeros((1, n_cols))
}
/// The four bounds a `BasisHessianLipschitz` family reports for one chart.
#[derive(Debug, Clone, Copy)]
pub(crate) struct JetSups {
/// Result of `value_sup`.
pub(crate) value: f64,
/// Result of `jacobian_sup`.
pub(crate) jacobian: f64,
/// Result of `hessian_sup`.
pub(crate) hessian: f64,
/// Result of `third_sup`.
pub(crate) third: f64,
}
impl JetSups {
pub(crate) fn from_family<B: BasisHessianLipschitz>(family: &B, chart: &ChartRegion) -> Self {
Self {
value: family.value_sup(chart),
jacobian: family.jacobian_sup(chart),
hessian: family.hessian_sup(chart),
third: family.third_sup(chart),
}
}
}
/// Gradient and ridge-regularised Hessian of the encode objective at `t`.
///
/// With `recon = amplitude * phi(t) * decoder`, `residual = recon - x` and
/// `J = amplitude * dphi(t) * decoder` (one row per latent axis), this returns
///
/// * `g[a] = J[a] . residual`
/// * `H[a][b] = J[a] . J[b] + amplitude * sum_k d2phi_k[a][b] * (residual . decoder[k])`,
///   plus `ridge` on the diagonal.
///
/// Returns `Ok(None)` when the evaluator provides no second jet.
///
/// # Errors
/// Fails when `t` cannot be reshaped to `(1, latent_dim)`, when the evaluator
/// fails, or when it returns a `phi` whose shape is not `(1, basis_size)`.
pub(crate) fn encode_grad_hess(
    atom: &SaeManifoldAtom,
    evaluator: &dyn SaeBasisEvaluator,
    t: ArrayView1<'_, f64>,
    x: ArrayView1<'_, f64>,
    amplitude: f64,
    ridge: f64,
) -> Result<Option<(Array1<f64>, Array2<f64>)>, String> {
    let d = atom.latent_dim;
    let p = atom.output_dim();
    let m = atom.basis_size();
    let coords = t.to_shape((1, d)).map_err(|e| e.to_string())?.to_owned();
    let (phi, jet) = evaluator.evaluate(coords.view())?;
    if phi.dim() != (1, m) {
        return Err(format!(
            "encode_grad_hess: evaluator returned phi {:?}, expected (1, {m})",
            phi.dim()
        ));
    }
    let decoder = &atom.decoder_coefficients;

    // Reconstruction at t; zero basis values are skipped.
    let mut recon = Array1::<f64>::zeros(p);
    for basis_col in 0..m {
        let phi_v = phi[[0, basis_col]];
        if phi_v == 0.0 {
            continue;
        }
        for out in 0..p {
            recon[out] += amplitude * phi_v * decoder[[basis_col, out]];
        }
    }
    let residual = &recon - &x;

    // Jacobian of the reconstruction, one row per latent axis.
    let mut jm = Array2::<f64>::zeros((d, p));
    for axis in 0..d {
        for basis_col in 0..m {
            let dphi = jet[[0, basis_col, axis]];
            if dphi == 0.0 {
                continue;
            }
            for out in 0..p {
                jm[[axis, out]] += amplitude * dphi * decoder[[basis_col, out]];
            }
        }
    }

    let second = match evaluator.second_jet_dyn(coords.view()) {
        Some(result) => result?,
        None => return Ok(None),
    };

    // residual . decoder[k] does not depend on (a, b); compute it at most once
    // per basis column instead of once per Hessian entry. It is filled lazily so
    // columns whose second jet is identically zero are never touched.
    let mut residual_proj: Vec<Option<f64>> = vec![None; m];

    let mut g = Array1::<f64>::zeros(d);
    let mut h = Array2::<f64>::zeros((d, d));
    for a in 0..d {
        let ja = jm.row(a);
        g[a] = ja.dot(&residual);
        for b in 0..d {
            let mut curv = 0.0;
            for basis_col in 0..m {
                let d2phi = second[[0, basis_col, a, b]];
                if d2phi == 0.0 {
                    continue;
                }
                let proj = *residual_proj[basis_col].get_or_insert_with(|| {
                    let mut dot = 0.0;
                    for out in 0..p {
                        dot += residual[out] * decoder[[basis_col, out]];
                    }
                    dot
                });
                curv += amplitude * d2phi * proj;
            }
            h[[a, b]] = ja.dot(&jm.row(b)) + curv;
        }
    }
    for a in 0..d {
        h[[a, a]] += ridge;
    }
    Ok(Some((g, h)))
}
/// Solves the Newton system `H * delta = -g` through an eigendecomposition of `h`.
///
/// Returns `(beta, eta, delta)` with `beta = 1 / lambda_min` and
/// `eta = ||delta||_2`, or `Ok(None)` when the smallest eigenvalue is not a
/// finite positive number.
///
/// # Errors
/// Fails when the eigendecomposition itself fails.
pub(crate) fn beta_eta_newton(
    h: ArrayView2<'_, f64>,
    g: ArrayView1<'_, f64>,
) -> Result<Option<(f64, f64, Array1<f64>)>, String> {
    let (eigenvalues, eigenvectors) = h
        .eigh(Side::Lower)
        .map_err(|e| format!("beta_eta_newton: eigh failed: {e:?}"))?;
    let lambda_min = eigenvalues.iter().copied().fold(f64::INFINITY, f64::min);
    if !lambda_min.is_finite() || lambda_min <= 0.0 {
        return Ok(None);
    }
    let beta = 1.0 / lambda_min;
    let dim = h.nrows();
    let mut delta = Array1::<f64>::zeros(dim);
    for (col, &lambda) in eigenvalues.iter().enumerate() {
        if lambda <= 0.0 {
            return Ok(None);
        }
        let direction = eigenvectors.column(col);
        let weight = direction.dot(&g) / lambda;
        for row in 0..dim {
            delta[row] -= weight * direction[row];
        }
    }
    let eta = delta.dot(&delta).sqrt();
    Ok(Some((beta, eta, delta)))
}
/// Evaluates the certificate at `t0` and returns it together with the Newton
/// step from `t0`.
///
/// When the gradient/Hessian or the Newton solve is unavailable, the result is
/// an all-infinite certificate (keeping the supplied `lipschitz`) and a zero step.
///
/// # Errors
/// Propagates errors from `encode_grad_hess` and `beta_eta_newton`.
pub fn row_certificate(
    atom: &SaeManifoldAtom,
    evaluator: &dyn SaeBasisEvaluator,
    t0: ArrayView1<'_, f64>,
    x: ArrayView1<'_, f64>,
    amplitude: f64,
    lipschitz: f64,
    ridge: f64,
) -> Result<(RowCertificate, Array1<f64>), String> {
    let fallback = || {
        let cert = RowCertificate {
            beta: f64::INFINITY,
            eta: f64::INFINITY,
            lipschitz,
            h: f64::INFINITY,
        };
        (cert, Array1::<f64>::zeros(atom.latent_dim))
    };

    let grad_hess = encode_grad_hess(atom, evaluator, t0, x, amplitude, ridge)?;
    let Some((gradient, hessian)) = grad_hess else {
        return Ok(fallback());
    };
    let Some((beta, eta, delta)) = beta_eta_newton(hessian.view(), gradient.view())? else {
        return Ok(fallback());
    };
    let cert = RowCertificate {
        beta,
        eta,
        lipschitz,
        h: beta * eta * lipschitz,
    };
    Ok((cert, delta))
}
/// Tuning knobs for building and using an `EncodeAtlas`.
#[derive(Debug, Clone, Copy)]
pub struct AtlasConfig {
/// Resolution passed to `chart_center_grid` and `chart_nominal_radius`.
pub grid_resolution: usize,
/// Value added to the diagonal of every Hessian / Gram matrix before it is factored.
pub ridge: f64,
/// Number of Newton updates applied to a row once its starting point is certified.
pub newton_steps: usize,
}
impl Default for AtlasConfig {
fn default() -> Self {
Self {
grid_resolution: 16,
ridge: 1.0e-9,
newton_steps: 2,
}
}
}
/// Per-atom encode atlases plus the configuration they were built with.
#[derive(Debug, Clone)]
pub struct EncodeAtlas {
/// One atlas per atom, in the order the atoms were passed to `build`.
pub atoms: Vec<AtomEncodeAtlas>,
/// Configuration used at build time and reused by the encode methods.
pub config: AtlasConfig,
}
impl EncodeAtlas {
/// Builds one `AtomEncodeAtlas` per atom.
///
/// # Errors
/// Fails when `amplitude_bound` and `atoms` differ in length, or when building
/// any single atom atlas fails (the first failure is returned).
pub fn build(
    atoms: &[SaeManifoldAtom],
    amplitude_bound: &[f64],
    target_norm_bound: f64,
    config: AtlasConfig,
) -> Result<Self, String> {
    if amplitude_bound.len() != atoms.len() {
        return Err(format!(
            "EncodeAtlas::build: amplitude_bound length {} != atom count {}",
            amplitude_bound.len(),
            atoms.len()
        ));
    }
    let per_atom = atoms
        .iter()
        .zip(amplitude_bound)
        .enumerate()
        .map(|(index, (atom, &bound))| {
            Self::build_atom_atlas(index, atom, bound, target_norm_bound, &config)
        })
        .collect::<Result<Vec<_>, String>>()?;
    Ok(Self {
        atoms: per_atom,
        config,
    })
}
/// Builds the atlas of certified charts for a single atom.
///
/// For every center produced by `chart_center_grid` this computes the chart
/// region, the Lipschitz constant from the family and reconstruction jet bounds,
/// and the center quantities (`center_beta`, `center_amortized_jacobian`). A
/// chart whose `center_beta` is unavailable is still recorded, with an infinite
/// `beta_center` and a zero certified radius.
///
/// # Errors
/// Propagates errors from `family_jet_sups`.
pub(crate) fn build_atom_atlas(
    atom_index: usize,
    atom: &SaeManifoldAtom,
    amplitude_bound: f64,
    target_norm_bound: f64,
    config: &AtlasConfig,
) -> Result<AtomEncodeAtlas, String> {
    let d = atom.latent_dim;
    let decoder_norm_sum = decoder_row_norm_sum(atom.decoder_coefficients.view());
    let centers = chart_center_grid(atom, config.grid_resolution);
    let nominal_radius = chart_nominal_radius(atom, config.grid_resolution);
    let mut charts = Vec::with_capacity(centers.nrows());
    for c in 0..centers.nrows() {
        let center = centers.row(c).to_owned();
        let region = chart_region(atom, center.clone(), nominal_radius);
        let sups = family_jet_sups(atom, &region)?;
        let recon_sups = reconstruction_jet_sups(atom, sups);
        let lipschitz =
            hessian_lipschitz_constant(recon_sups, amplitude_bound, target_norm_bound, 0.0);
        let beta_center = match center_beta(atom, &center, config.ridge) {
            Some(b) => b,
            None => {
                // No usable curvature at the center: keep the chart for bookkeeping
                // but give it a zero certified radius.
                charts.push(CertifiedChart {
                    region,
                    lipschitz,
                    beta_center: f64::INFINITY,
                    certified_radius: 0.0,
                    amortized_jacobian: None,
                    recon_center: Array1::<f64>::zeros(atom.output_dim()),
                });
                continue;
            }
        };
        let (amortized_jacobian, recon_center) =
            match center_amortized_jacobian(atom, &center, config.ridge) {
                Some((a1, m1)) => (Some(a1), m1),
                None => (None, Array1::<f64>::zeros(atom.output_dim())),
            };
        // Radius at which beta * L * r stays at 0.5, capped by the nominal radius.
        let certified_radius = if lipschitz > 0.0 && beta_center.is_finite() {
            (0.5 / (beta_center * lipschitz)).min(region.radius)
        } else {
            region.radius
        };
        charts.push(CertifiedChart {
            region,
            lipschitz,
            beta_center,
            certified_radius,
            amortized_jacobian,
            recon_center,
        });
    }
    Ok(AtomEncodeAtlas {
        atom_index,
        latent_dim: d,
        decoder_norm_sum,
        charts,
    })
}
pub fn certified_encode_row(
&self,
atom: &SaeManifoldAtom,
atom_index: usize,
x: ArrayView1<'_, f64>,
amplitude: f64,
) -> Result<(Array1<f64>, RowCertificate), String> {
let atom_atlas = self
.atoms
.get(atom_index)
.ok_or_else(|| format!("certified_encode_row: atom {atom_index} not in atlas"))?;
let evaluator = atom
.basis_evaluator
.as_ref()
.ok_or_else(|| format!("certified_encode_row: atom {atom_index} has no evaluator"))?
.clone();
let d = atom.latent_dim;
let Some((chart_idx, _)) = nearest_chart(atom_atlas, x, atom, evaluator.as_ref()) else {
return Ok((
Array1::<f64>::zeros(d),
RowCertificate {
beta: f64::INFINITY,
eta: f64::INFINITY,
lipschitz: f64::INFINITY,
h: f64::INFINITY,
},
));
};
let chart = &atom_atlas.charts[chart_idx];
let mut t = chart.region.center.clone();
let (cert, mut delta) = row_certificate(
atom,
evaluator.as_ref(),
t.view(),
x,
amplitude,
chart.lipschitz,
self.config.ridge,
)?;
if !cert.certified() {
return Ok((t, cert));
}
for step in 0..self.config.newton_steps {
t = &t + δ
if step + 1 < self.config.newton_steps {
let (_c, next_delta) = row_certificate(
atom,
evaluator.as_ref(),
t.view(),
x,
amplitude,
chart.lipschitz,
self.config.ridge,
)?;
delta = next_delta;
}
}
Ok((t, cert))
}
pub fn amortized_encode_row(
&self,
atom: &SaeManifoldAtom,
atom_index: usize,
x: ArrayView1<'_, f64>,
amplitude: f64,
) -> Result<(Array1<f64>, RowCertificate), String> {
let atom_atlas = self
.atoms
.get(atom_index)
.ok_or_else(|| format!("amortized_encode_row: atom {atom_index} not in atlas"))?;
let evaluator = atom
.basis_evaluator
.as_ref()
.ok_or_else(|| format!("amortized_encode_row: atom {atom_index} has no evaluator"))?
.clone();
let d = atom.latent_dim;
let uncertified = || {
(
Array1::<f64>::zeros(d),
RowCertificate {
beta: f64::INFINITY,
eta: f64::INFINITY,
lipschitz: f64::INFINITY,
h: f64::INFINITY,
},
)
};
let Some((chart_idx, _)) = nearest_chart(atom_atlas, x, atom, evaluator.as_ref()) else {
return Ok(uncertified());
};
let chart = &atom_atlas.charts[chart_idx];
let Some(a1) = chart.amortized_jacobian.as_ref() else {
return Ok(uncertified());
};
if !(amplitude.is_finite() && amplitude.abs() > 0.0) {
return Ok(uncertified());
}
let p = atom.output_dim();
let mut t_hat = chart.region.center.clone();
for (out_idx, &m1_out) in chart.recon_center.iter().enumerate().take(p) {
let resid = x[out_idx] - amplitude * m1_out;
for axis in 0..d {
t_hat[axis] += a1[[axis, out_idx]] * resid / amplitude;
}
}
let (cert, mut delta) = row_certificate(
atom,
evaluator.as_ref(),
t_hat.view(),
x,
amplitude,
chart.lipschitz,
self.config.ridge,
)?;
if !cert.certified() {
return Ok((t_hat, cert));
}
for step in 0..self.config.newton_steps {
t_hat = &t_hat + δ
if step + 1 < self.config.newton_steps {
let (_c, next_delta) = row_certificate(
atom,
evaluator.as_ref(),
t_hat.view(),
x,
amplitude,
chart.lipschitz,
self.config.ridge,
)?;
delta = next_delta;
}
}
Ok((t_hat, cert))
}
/// Runs `amortized_encode_row` on every row of `targets` for a single atom.
///
/// Rows are processed in 256-row chunks on the rayon pool when there are at
/// least `ENCODE_BATCH_PARALLEL_ROW_MIN` rows and the caller is not already a
/// rayon worker; otherwise they are processed sequentially.
///
/// # Errors
/// Fails when `amplitudes` and `targets` disagree on the row count, or when any
/// row encode fails.
pub fn amortized_encode_batch(
    &self,
    atom: &SaeManifoldAtom,
    atom_index: usize,
    targets: ArrayView2<'_, f64>,
    amplitudes: ArrayView1<'_, f64>,
) -> Result<EncodeResult, String> {
    let n = targets.nrows();
    if amplitudes.len() != n {
        return Err(format!(
            "amortized_encode_batch: amplitudes len {} != rows {n}",
            amplitudes.len()
        ));
    }
    let d = atom.latent_dim;
    let encode_one = |row: usize| -> Result<(Array1<f64>, bool), String> {
        self.amortized_encode_row(atom, atom_index, targets.row(row), amplitudes[row])
            .map(|(t, cert)| (t, cert.certified()))
    };
    let encode_range =
        |range: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
            range.map(|row| encode_one(row)).collect()
        };
    let go_parallel =
        n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
    let rows: Vec<(Array1<f64>, bool)> = if go_parallel {
        use rayon::prelude::*;
        const CHUNK: usize = 256;
        let per_chunk: Vec<Vec<(Array1<f64>, bool)>> = (0..n.div_ceil(CHUNK))
            .into_par_iter()
            .map(|chunk| {
                let start = chunk * CHUNK;
                encode_range(start..(start + CHUNK).min(n))
            })
            .collect::<Result<_, _>>()?;
        per_chunk.into_iter().flatten().collect()
    } else {
        encode_range(0..n)?
    };
    let (points, certified): (Vec<Array1<f64>>, Vec<bool>) = rows.into_iter().unzip();
    let mut coords = Array2::<f64>::zeros((n, d));
    for (row, point) in points.iter().enumerate() {
        coords.row_mut(row).assign(point);
    }
    Ok(EncodeResult::from_rows(coords, certified))
}
/// Runs `certified_encode_row` on every row of `targets` for a single atom.
///
/// Rows are processed in 256-row chunks on the rayon pool when there are at
/// least `ENCODE_BATCH_PARALLEL_ROW_MIN` rows and the caller is not already a
/// rayon worker; otherwise they are processed sequentially.
///
/// # Errors
/// Fails when `amplitudes` and `targets` disagree on the row count, or when any
/// row encode fails.
pub fn certified_encode_batch(
    &self,
    atom: &SaeManifoldAtom,
    atom_index: usize,
    targets: ArrayView2<'_, f64>,
    amplitudes: ArrayView1<'_, f64>,
) -> Result<EncodeResult, String> {
    let n = targets.nrows();
    if amplitudes.len() != n {
        return Err(format!(
            "certified_encode_batch: amplitudes len {} != rows {n}",
            amplitudes.len()
        ));
    }
    let d = atom.latent_dim;
    let encode_one = |row: usize| -> Result<(Array1<f64>, bool), String> {
        self.certified_encode_row(atom, atom_index, targets.row(row), amplitudes[row])
            .map(|(t, cert)| (t, cert.certified()))
    };
    let encode_range =
        |range: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
            range.map(|row| encode_one(row)).collect()
        };
    let go_parallel =
        n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
    let rows: Vec<(Array1<f64>, bool)> = if go_parallel {
        use rayon::prelude::*;
        const CHUNK: usize = 256;
        let per_chunk: Vec<Vec<(Array1<f64>, bool)>> = (0..n.div_ceil(CHUNK))
            .into_par_iter()
            .map(|chunk| {
                let start = chunk * CHUNK;
                encode_range(start..(start + CHUNK).min(n))
            })
            .collect::<Result<_, _>>()?;
        per_chunk.into_iter().flatten().collect()
    } else {
        encode_range(0..n)?
    };
    let (points, certified): (Vec<Array1<f64>>, Vec<bool>) = rows.into_iter().unzip();
    let mut coords = Array2::<f64>::zeros((n, d));
    for (row, point) in points.iter().enumerate() {
        coords.row_mut(row).assign(point);
    }
    Ok(EncodeResult::from_rows(coords, certified))
}
/// Encodes every target row against the first atom proposed by `index`, using
/// `certified_encode_row`.
///
/// Rows for which the index proposes nothing keep zero coordinates and are
/// marked uncertified. Rows are processed in 256-row chunks on the rayon pool
/// when there are at least `ENCODE_BATCH_PARALLEL_ROW_MIN` rows and the caller is
/// not already a rayon worker.
///
/// # Errors
/// Fails when `amplitudes` and `targets` disagree on the row count, when a
/// proposed atom index is out of range, when an encoded row's length differs
/// from `latent_dim`, or when a row encode fails.
pub fn certified_encode_with_index<S: AtomFrameSketch + Sync>(
    &self,
    atoms: &[SaeManifoldAtom],
    index: &SaeCandidateIndex,
    sketch: &S,
    targets: ArrayView2<'_, f64>,
    amplitudes: ArrayView1<'_, f64>,
    latent_dim: usize,
) -> Result<EncodeResult, String> {
    let n = targets.nrows();
    if amplitudes.len() != n {
        return Err(format!(
            "certified_encode_with_index: amplitudes len {} != rows {n}",
            amplitudes.len()
        ));
    }
    let budget = auto_candidate_budget(atoms.len().max(1));
    let encode_one = |row: usize| -> Result<Option<(Array1<f64>, bool)>, String> {
        let proposal = index.propose(sketch, targets.row(row), budget, true);
        let Some(&best_atom) = proposal.proposed.first() else {
            return Ok(None);
        };
        let atom = atoms.get(best_atom).ok_or_else(|| {
            format!("certified_encode_with_index: proposed atom {best_atom} out of range")
        })?;
        let (t, cert) =
            self.certified_encode_row(atom, best_atom, targets.row(row), amplitudes[row])?;
        if t.len() != latent_dim {
            return Err(format!(
                "certified_encode_with_index: atom {best_atom} returned t.len()={} \
                 but declared latent_dim={latent_dim}; heterogeneous-dim \
                 dictionaries are not supported by this batched encode path",
                t.len()
            ));
        }
        Ok(Some((t, cert.certified())))
    };
    let encode_range =
        |range: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
            range.map(|row| encode_one(row)).collect()
        };
    let go_parallel =
        n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
    let rows: Vec<Option<(Array1<f64>, bool)>> = if go_parallel {
        use rayon::prelude::*;
        const CHUNK: usize = 256;
        let per_chunk: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n.div_ceil(CHUNK))
            .into_par_iter()
            .map(|chunk| {
                let start = chunk * CHUNK;
                encode_range(start..(start + CHUNK).min(n))
            })
            .collect::<Result<_, _>>()?;
        per_chunk.into_iter().flatten().collect()
    } else {
        encode_range(0..n)?
    };
    let mut coords = Array2::<f64>::zeros((n, latent_dim));
    let mut certified = Vec::with_capacity(n);
    for (row, slot) in rows.into_iter().enumerate() {
        let row_certified = if let Some((t, ok)) = slot {
            coords.row_mut(row).assign(&t);
            ok
        } else {
            false
        };
        certified.push(row_certified);
    }
    Ok(EncodeResult::from_rows(coords, certified))
}
/// Encodes every target row against the first atom proposed by `index`, using
/// `amortized_encode_row`.
///
/// Rows for which the index proposes nothing keep zero coordinates and are
/// marked uncertified. Rows are processed in 256-row chunks on the rayon pool
/// when there are at least `ENCODE_BATCH_PARALLEL_ROW_MIN` rows and the caller is
/// not already a rayon worker.
///
/// # Errors
/// Fails when `amplitudes` and `targets` disagree on the row count, when a
/// proposed atom index is out of range, when an encoded row's length differs
/// from `latent_dim`, or when a row encode fails.
pub fn amortized_encode_with_index<S: AtomFrameSketch + Sync>(
    &self,
    atoms: &[SaeManifoldAtom],
    index: &SaeCandidateIndex,
    sketch: &S,
    targets: ArrayView2<'_, f64>,
    amplitudes: ArrayView1<'_, f64>,
    latent_dim: usize,
) -> Result<EncodeResult, String> {
    let n = targets.nrows();
    if amplitudes.len() != n {
        return Err(format!(
            "amortized_encode_with_index: amplitudes len {} != rows {n}",
            amplitudes.len()
        ));
    }
    let budget = auto_candidate_budget(atoms.len().max(1));
    let encode_one = |row: usize| -> Result<Option<(Array1<f64>, bool)>, String> {
        let proposal = index.propose(sketch, targets.row(row), budget, true);
        let Some(&best_atom) = proposal.proposed.first() else {
            return Ok(None);
        };
        let atom = atoms.get(best_atom).ok_or_else(|| {
            format!("amortized_encode_with_index: proposed atom {best_atom} out of range")
        })?;
        let (t, cert) =
            self.amortized_encode_row(atom, best_atom, targets.row(row), amplitudes[row])?;
        if t.len() != latent_dim {
            return Err(format!(
                "amortized_encode_with_index: atom {best_atom} returned t.len()={} \
                 but declared latent_dim={latent_dim}; heterogeneous-dim \
                 dictionaries are not supported by this batched encode path",
                t.len()
            ));
        }
        Ok(Some((t, cert.certified())))
    };
    let encode_range =
        |range: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
            range.map(|row| encode_one(row)).collect()
        };
    let go_parallel =
        n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
    let rows: Vec<Option<(Array1<f64>, bool)>> = if go_parallel {
        use rayon::prelude::*;
        const CHUNK: usize = 256;
        let per_chunk: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n.div_ceil(CHUNK))
            .into_par_iter()
            .map(|chunk| {
                let start = chunk * CHUNK;
                encode_range(start..(start + CHUNK).min(n))
            })
            .collect::<Result<_, _>>()?;
        per_chunk.into_iter().flatten().collect()
    } else {
        encode_range(0..n)?
    };
    let mut coords = Array2::<f64>::zeros((n, latent_dim));
    let mut certified = Vec::with_capacity(n);
    for (row, slot) in rows.into_iter().enumerate() {
        let row_certified = if let Some((t, ok)) = slot {
            coords.row_mut(row).assign(&t);
            ok
        } else {
            false
        };
        certified.push(row_certified);
    }
    Ok(EncodeResult::from_rows(coords, certified))
}
}
/// Reciprocal of the smallest eigenvalue of `J J^T + ridge * I` at `center`,
/// where `J` is the unit-amplitude reconstruction Jacobian (one row per latent
/// axis).
///
/// Returns `None` when the atom has no evaluator, when reshaping, evaluation or
/// the eigendecomposition fails, or when the smallest eigenvalue is not a finite
/// positive number.
pub(crate) fn center_beta(atom: &SaeManifoldAtom, center: &Array1<f64>, ridge: f64) -> Option<f64> {
    let evaluator = atom.basis_evaluator.as_ref()?.clone();
    let latent = atom.latent_dim;
    let outputs = atom.output_dim();
    let width = atom.basis_size();
    let coords = center.view().to_shape((1, latent)).ok()?.to_owned();
    let (_phi, jet) = evaluator.evaluate(coords.view()).ok()?;
    let decoder = &atom.decoder_coefficients;

    let mut jacobian = Array2::<f64>::zeros((latent, outputs));
    for axis in 0..latent {
        for col in 0..width {
            let slope = jet[[0, col, axis]];
            if slope == 0.0 {
                continue;
            }
            for out in 0..outputs {
                jacobian[[axis, out]] += slope * decoder[[col, out]];
            }
        }
    }

    let mut gram = Array2::<f64>::zeros((latent, latent));
    for a in 0..latent {
        let row_a = jacobian.row(a);
        for b in 0..latent {
            gram[[a, b]] = row_a.dot(&jacobian.row(b));
        }
        gram[[a, a]] += ridge;
    }

    let (eigenvalues, _eigenvectors) = gram.eigh(Side::Lower).ok()?;
    let lambda_min = eigenvalues.iter().copied().fold(f64::INFINITY, f64::min);
    if !lambda_min.is_finite() || lambda_min <= 0.0 {
        return None;
    }
    Some(1.0 / lambda_min)
}
/// Computes, at `center`, the matrix `(J J^T + ridge * I)^{-1} J` and the
/// unit-amplitude reconstruction, where `J` is the unit-amplitude reconstruction
/// Jacobian (one row per latent axis).
///
/// The inverse is applied through the eigendecomposition of the regularised Gram
/// matrix. Returns `None` when the atom has no evaluator, when reshaping,
/// evaluation or the eigendecomposition fails, or when an eigenvalue is not a
/// finite positive number.
pub(crate) fn center_amortized_jacobian(
    atom: &SaeManifoldAtom,
    center: &Array1<f64>,
    ridge: f64,
) -> Option<(Array2<f64>, Array1<f64>)> {
    let evaluator = atom.basis_evaluator.as_ref()?.clone();
    let latent = atom.latent_dim;
    let outputs = atom.output_dim();
    let width = atom.basis_size();
    let coords = center.view().to_shape((1, latent)).ok()?.to_owned();
    let (phi, jet) = evaluator.evaluate(coords.view()).ok()?;
    let decoder = &atom.decoder_coefficients;

    // Unit-amplitude reconstruction at the center.
    let mut recon = Array1::<f64>::zeros(outputs);
    for col in 0..width {
        let weight = phi[[0, col]];
        if weight == 0.0 {
            continue;
        }
        for out in 0..outputs {
            recon[out] += weight * decoder[[col, out]];
        }
    }

    let mut jacobian = Array2::<f64>::zeros((latent, outputs));
    for axis in 0..latent {
        for col in 0..width {
            let slope = jet[[0, col, axis]];
            if slope == 0.0 {
                continue;
            }
            for out in 0..outputs {
                jacobian[[axis, out]] += slope * decoder[[col, out]];
            }
        }
    }

    let mut gram = Array2::<f64>::zeros((latent, latent));
    for a in 0..latent {
        let row_a = jacobian.row(a);
        for b in 0..latent {
            gram[[a, b]] = row_a.dot(&jacobian.row(b));
        }
        gram[[a, a]] += ridge;
    }

    let (eigenvalues, eigenvectors) = gram.eigh(Side::Lower).ok()?;
    let lambda_min = eigenvalues.iter().copied().fold(f64::INFINITY, f64::min);
    if !lambda_min.is_finite() || lambda_min <= 0.0 {
        return None;
    }

    let mut amortized = Array2::<f64>::zeros((latent, outputs));
    for out in 0..outputs {
        let jacobian_col = jacobian.column(out);
        for (idx, &lambda) in eigenvalues.iter().enumerate() {
            if !lambda.is_finite() || lambda <= 0.0 {
                return None;
            }
            let direction = eigenvectors.column(idx);
            let weight = direction.dot(&jacobian_col) / lambda;
            for row in 0..latent {
                amortized[[row, out]] += weight * direction[row];
            }
        }
    }
    Some((amortized, recon))
}
/// Finds the chart of `atom_atlas` whose decoded centre lies closest to `x`.
///
/// For every chart with a strictly positive `certified_radius`, the basis is
/// evaluated at the chart centre, decoded through `atom.decoder_coefficients`,
/// and compared with `x` by squared Euclidean distance.
///
/// Returns `Some((chart_index, squared_distance))` for the minimiser (the
/// first one on ties), or `None` when no chart yields a usable distance.
/// Charts are skipped when their radius is not positive, their centre cannot
/// be viewed as a `1 × d` row, the evaluator fails, or the resulting distance
/// is NaN.
pub(crate) fn nearest_chart(
    atom_atlas: &AtomEncodeAtlas,
    x: ArrayView1<'_, f64>,
    atom: &SaeManifoldAtom,
    evaluator: &dyn SaeBasisEvaluator,
) -> Option<(usize, f64)> {
    if atom_atlas.charts.is_empty() {
        return None;
    }
    let d = atom.latent_dim;
    let p = atom.output_dim();
    let m = atom.basis_size();
    let decoder = &atom.decoder_coefficients;
    let mut best: Option<(usize, f64)> = None;
    for (idx, chart) in atom_atlas.charts.iter().enumerate() {
        if chart.certified_radius <= 0.0 {
            continue;
        }
        let Ok(coords) = chart.region.center.view().to_shape((1, d)) else {
            continue;
        };
        let coords = coords.to_owned();
        let Ok((phi, _jet)) = evaluator.evaluate(coords.view()) else {
            continue;
        };

        // Decode the chart centre; exact-zero basis values are skipped.
        let mut recon = Array1::<f64>::zeros(p);
        for basis_col in 0..m {
            let phi_v = phi[[0, basis_col]];
            if phi_v == 0.0 {
                continue;
            }
            for out in 0..p {
                recon[out] += phi_v * decoder[[basis_col, out]];
            }
        }

        let diff = &recon - &x;
        let dist = diff.dot(&diff);
        // A NaN distance must never become the incumbent: every later
        // `dist < NaN` comparison is false, so it would shadow all finite
        // candidates that follow.
        if dist.is_nan() {
            continue;
        }
        if best.map_or(true, |(_, incumbent)| dist < incumbent) {
            best = Some((idx, dist));
        }
    }
    best
}
/// Point budget for generated grids: `regular_product_grid` and
/// `cylinder_chart_center_grid` lower their per-axis resolution (never below
/// 2) while the total point count exceeds this value.
pub(crate) const SHAPE_BAND_MAX_POINTS: usize = 512;
/// Returns the grid of chart-centre coordinates for `atom`, one point per row.
///
/// The layout depends on the atom's basis kind:
/// * `Periodic` / `Torus`: half-open product grid on `[0, 1)` per axis;
/// * `Cylinder` with two latent dimensions: the dedicated cylinder grid;
/// * `Sphere` with two latent dimensions: the latitude/longitude grid;
/// * everything else (including `Cylinder` / `Sphere` in other dimensions):
///   closed product grid on `[-0.5, 0.5]` per axis.
pub(crate) fn chart_center_grid(atom: &SaeManifoldAtom, resolution: usize) -> Array2<f64> {
    use crate::terms::sae_manifold::SaeAtomBasisKind::*;
    let d = atom.latent_dim;
    // Shared fallback: endpoint-inclusive box centred on the origin.
    let centered_box = || regular_product_grid(d, resolution, -0.5, 0.5, true);
    match &atom.basis_kind {
        Periodic | Torus => regular_product_grid(d, resolution, 0.0, 1.0, false),
        Cylinder => {
            if d == 2 {
                cylinder_chart_center_grid(resolution)
            } else {
                centered_box()
            }
        }
        Sphere => {
            if d == 2 {
                sphere_latlon_grid(resolution)
            } else {
                centered_box()
            }
        }
        Duchon | EuclideanPatch | Poincare | Precomputed(_) => centered_box(),
    }
}
/// Builds a regular Cartesian grid over `[lo, hi]^d`, one point per row.
///
/// The per-axis count starts at `max(resolution, 2)` and is lowered, but never
/// below 2, while `per_axis^d` exceeds `SHAPE_BAND_MAX_POINTS`. Each axis is
/// sampled at `lo + (hi - lo) * k / denom`, where `denom` is `per_axis - 1`
/// when `include_endpoint` is set (so `hi` is hit) and `per_axis` otherwise
/// (so `hi` is left out). Rows are ordered with the last axis varying fastest.
///
/// For `d == 0` a single empty row (shape `1 × 0`) is returned.
pub(crate) fn regular_product_grid(
    d: usize,
    resolution: usize,
    lo: f64,
    hi: f64,
    include_endpoint: bool,
) -> Array2<f64> {
    if d == 0 {
        return Array2::<f64>::zeros((1, 0));
    }
    let exponent = d as u32;
    let mut per_axis = resolution.max(2);
    while per_axis > 2 && per_axis.saturating_pow(exponent) > SHAPE_BAND_MAX_POINTS {
        per_axis -= 1;
    }
    let total = per_axis.saturating_pow(exponent).max(1);

    // `per_axis` is at least 2 here, so the endpoint-inclusive divisor is >= 1.
    let divisions = if include_endpoint {
        per_axis - 1
    } else {
        per_axis
    };
    let denom = divisions as f64;

    // Every axis shares the same tick positions; tabulate them once.
    let ticks: Vec<f64> = (0..per_axis)
        .map(|k| lo + (hi - lo) * (k as f64 / denom))
        .collect();

    let mut grid = Array2::<f64>::zeros((total, d));
    for flat in 0..total {
        // Decode the row number as a base-`per_axis` numeral whose least
        // significant digit belongs to the last axis.
        let mut rest = flat;
        for axis in (0..d).rev() {
            grid[[flat, axis]] = ticks[rest % per_axis];
            rest /= per_axis;
        }
    }
    grid
}
/// Builds an `r × r` latitude/longitude grid with columns `(lat, lon)`.
///
/// `r` is `resolution` clamped to `2..=22`, so the grid holds at most
/// `22 * 22 = 484` points. Latitudes sit at cell midpoints of `[-π/2, π/2]`
/// (the poles are never sampled); longitudes start at `-π` and advance in
/// steps of `2π / r`, leaving out `+π`. Rows are ordered latitude-major.
pub(crate) fn sphere_latlon_grid(resolution: usize) -> Array2<f64> {
    use std::f64::consts::PI;
    let side = resolution.clamp(2, 22);
    let count = side * side;
    let mut grid = Array2::<f64>::zeros((count, 2));
    for flat in 0..count {
        let lat_idx = flat / side;
        let lon_idx = flat % side;
        grid[[flat, 0]] = -PI / 2.0 + PI * (lat_idx as f64 + 0.5) / side as f64;
        grid[[flat, 1]] = -PI + 2.0 * PI * (lon_idx as f64) / side as f64;
    }
    grid
}
/// Builds the two-column grid `(circle, line)` used for 2-D cylinder atoms.
///
/// The per-axis count starts at `max(resolution, 2)` and is lowered, but never
/// below 2, while its square exceeds `SHAPE_BAND_MAX_POINTS`. The `circle`
/// coordinate takes the values `i / per_axis` (half-open `[0, 1)`), and the
/// `line` coordinate takes `-0.5 + j / (per_axis - 1)` (closed `[-0.5, 0.5]`).
/// Rows are ordered with the `line` coordinate varying fastest.
pub(crate) fn cylinder_chart_center_grid(resolution: usize) -> Array2<f64> {
    let mut side = resolution.max(2);
    while side > 2 && side * side > SHAPE_BAND_MAX_POINTS {
        side -= 1;
    }
    let circle_steps = side as f64;
    // `side` is at least 2, so the closed-interval divisor is at least 1.
    let line_steps = (side - 1) as f64;
    let mut grid = Array2::<f64>::zeros((side * side, 2));
    let mut row = 0usize;
    for i in 0..side {
        let circle = i as f64 / circle_steps;
        for j in 0..side {
            grid[[row, 0]] = circle;
            grid[[row, 1]] = -0.5 + (j as f64) / line_steps;
            row += 1;
        }
    }
    grid
}
pub(crate) fn chart_nominal_radius(atom: &SaeManifoldAtom, resolution: usize) -> f64 {
use crate::terms::sae_manifold::SaeAtomBasisKind::*;
match &atom.basis_kind {
Periodic | Torus => 0.5 / (resolution.max(2) as f64),
Sphere => std::f64::consts::PI / (resolution.max(2) as f64),
Cylinder => 0.5 / (resolution.max(2) as f64),
Duchon | EuclideanPatch | Poincare | Precomputed(_) => 1.0 / (resolution.max(2) as f64),
}
}
pub(crate) fn chart_region(
atom: &SaeManifoldAtom,
center: Array1<f64>,
radius: f64,
) -> ChartRegion {
use crate::terms::sae_manifold::SaeAtomBasisKind::*;
let region = ChartRegion::new(center.clone(), radius);
match &atom.basis_kind {
Duchon => {
let center_norm = center.dot(¢er).sqrt();
let r_min = (center_norm - radius).max(f64::MIN_POSITIVE);
let r_max = center_norm + radius;
region.with_radial_bounds(r_min, r_max)
}
Periodic | Sphere | Torus | Cylinder | EuclideanPatch | Poincare | Precomputed(_) => region,
}
}