use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use gam_linalg::faer_ndarray::FaerEigh;
use crate::candidate_index::{
AtomFrameSketch, SaeCandidateIndex, auto_candidate_budget,
};
use crate::manifold::{
AffineCoordinateEvaluator, CylinderHarmonicEvaluator, DuchonCoordinateEvaluator,
EuclideanPatchEvaluator, PeriodicHarmonicEvaluator, SaeBasisEvaluator, SaeManifoldAtom,
SphereChartEvaluator, TorusHarmonicEvaluator,
};
use faer::Side;
/// Largest value of the certificate quantity `h` that `RowCertificate::certified` accepts
/// (it additionally requires `h` to be finite).
pub const KANTOROVICH_THRESHOLD: f64 = 0.5;
/// Row count at or above which the batch encoders take the chunked rayon path
/// (only when not already running on a rayon worker thread).
pub(crate) const ENCODE_BATCH_PARALLEL_ROW_MIN: usize = 256;
// NOTE(review): not referenced in the visible part of this file; presumably a minimum
// alignment score used by candidate routing — confirm at the use site.
pub(crate) const CANDIDATE_ROUTING_MIN_ALIGNMENT: f64 = 0.5;
/// Number of candidate charts `certified_encode_row` requests from `nearest_charts_topk`.
pub(crate) const CERTIFIED_ROUTING_TOPK: usize = 4;
/// A region of latent space described by a center point and a radius, with optional
/// radial bounds consumed by the Duchon jet-sup estimates.
#[derive(Debug, Clone)]
pub struct ChartRegion {
/// Center of the region in latent coordinates.
pub center: Array1<f64>,
/// Radius of the region; `assert_valid` requires it to be finite and non-negative.
pub radius: f64,
/// Optional lower radial bound; the Duchon Hessian/third-order bounds divide by it and
/// fall back to `radius` when it is `None`.
pub exclusion_r_min: Option<f64>,
/// Optional upper radial bound; the Duchon bounds fall back to `radius` when it is `None`.
pub radial_r_max: Option<f64>,
}
impl ChartRegion {
pub fn new(center: Array1<f64>, radius: f64) -> Self {
Self {
center,
radius,
exclusion_r_min: None,
radial_r_max: None,
}
}
pub fn with_radial_bounds(mut self, r_min: f64, r_max: f64) -> Self {
self.exclusion_r_min = Some(r_min);
self.radial_r_max = Some(r_max);
self
}
pub(crate) fn assert_valid(&self) {
assert!(
self.radius.is_finite()
&& self.radius >= 0.0
&& self.center.iter().all(|c| c.is_finite()),
"ChartRegion must have a finite center and a finite non-negative radius"
);
}
}
/// Per-family upper bounds on a basis and its derivatives over a chart.
///
/// `JetSups::from_family` collects the four values in the order value, Jacobian,
/// Hessian, third derivative. The exact norm each bound refers to is defined by the
/// implementations, not by this trait.
pub trait BasisHessianLipschitz {
/// Bound on the basis values over `chart`.
fn value_sup(&self, chart: &ChartRegion) -> f64;
/// Bound on the first derivatives over `chart`.
fn jacobian_sup(&self, chart: &ChartRegion) -> f64;
/// Bound on the second derivatives over `chart`.
fn hessian_sup(&self, chart: &ChartRegion) -> f64;
/// Bound on the third derivatives over `chart`.
fn third_sup(&self, chart: &ChartRegion) -> f64;
}
/// Returns `(TAU * H)^order`, where `H = (num_basis - 1) / 2` (integer division,
/// saturating at zero) is the highest harmonic index implied by `num_basis`.
pub(crate) fn harmonic_jet_sup(num_basis: usize, order: u32) -> f64 {
    let highest_mode = match num_basis {
        0 => 0,
        count => (count - 1) / 2,
    };
    let angular_frequency = highest_mode as f64 * std::f64::consts::TAU;
    angular_frequency.powi(order as i32)
}
impl BasisHessianLipschitz for PeriodicHarmonicEvaluator {
    /// Constant bound `1.0`; the chart is only validated.
    fn value_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        1.0_f64
    }

    /// First-order bound from `harmonic_jet_sup`; the chart is only validated.
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        let order = 1;
        harmonic_jet_sup(self.num_basis, order)
    }

    /// Second-order bound from `harmonic_jet_sup`; the chart is only validated.
    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        let order = 2;
        harmonic_jet_sup(self.num_basis, order)
    }

    /// Third-order bound from `harmonic_jet_sup`; the chart is only validated.
    fn third_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        let order = 3;
        harmonic_jet_sup(self.num_basis, order)
    }
}
impl BasisHessianLipschitz for TorusHarmonicEvaluator {
    /// Constant bound `1.0`; the chart is only validated.
    fn value_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        1.0_f64
    }

    /// First-order bound from `torus_jet_sup`; the chart is only validated.
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        let order = 1;
        torus_jet_sup(self.num_harmonics, self.latent_dim, order)
    }

    /// Second-order bound from `torus_jet_sup`; the chart is only validated.
    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        let order = 2;
        torus_jet_sup(self.num_harmonics, self.latent_dim, order)
    }

    /// Third-order bound from `torus_jet_sup`; the chart is only validated.
    fn third_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        let order = 3;
        torus_jet_sup(self.num_harmonics, self.latent_dim, order)
    }
}
/// Returns `(TAU * num_harmonics)^order * latent_dim^order`.
pub(crate) fn torus_jet_sup(num_harmonics: usize, latent_dim: usize, order: u32) -> f64 {
    let exponent = order as i32;
    let angular_frequency = num_harmonics as f64 * std::f64::consts::TAU;
    let frequency_term = angular_frequency.powi(exponent);
    let dimension_term = (latent_dim as f64).powi(exponent);
    frequency_term * dimension_term
}
impl BasisHessianLipschitz for SphereChartEvaluator {
    /// Constant bound `1.0`; the chart is only validated.
    fn value_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        1.0_f64
    }

    /// Constant bound `4.0`; the chart is only validated.
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        4.0_f64
    }

    /// Constant bound `16.0`; the chart is only validated.
    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        16.0_f64
    }

    /// Constant bound `64.0`; the chart is only validated.
    fn third_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        64.0_f64
    }
}
impl BasisHessianLipschitz for AffineCoordinateEvaluator {
    /// Returns `1 + ||center||_2 + radius`. Unlike the other three methods this one
    /// does not call `assert_valid`.
    fn value_sup(&self, chart: &ChartRegion) -> f64 {
        let center = &chart.center;
        let euclidean_norm = center.dot(center).sqrt();
        1.0 + euclidean_norm + chart.radius
    }

    /// Constant bound `1.0`; the chart is only validated.
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        1.0_f64
    }

    /// Constant bound `0.0`; the chart is only validated.
    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        0.0_f64
    }

    /// Constant bound `0.0`; the chart is only validated.
    fn third_sup(&self, chart: &ChartRegion) -> f64 {
        chart.assert_valid();
        0.0_f64
    }
}
impl BasisHessianLipschitz for EuclideanPatchEvaluator {
    /// Returns `max(rho^max_degree, 1)` with `rho = patch_rho(chart)`.
    fn value_sup(&self, chart: &ChartRegion) -> f64 {
        let exponent = self.max_degree as i32;
        let extent = patch_rho(chart);
        extent.powi(exponent).max(1.0)
    }

    /// First-order bound delegated to `patch_jet_sup`.
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
        let order = 1;
        patch_jet_sup(self.latent_dim, self.max_degree, chart, order)
    }

    /// Second-order bound delegated to `patch_jet_sup`.
    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
        let order = 2;
        patch_jet_sup(self.latent_dim, self.max_degree, chart, order)
    }

    /// Third-order bound delegated to `patch_jet_sup`.
    fn third_sup(&self, chart: &ChartRegion) -> f64 {
        let order = 3;
        patch_jet_sup(self.latent_dim, self.max_degree, chart, order)
    }
}
impl BasisHessianLipschitz for CylinderHarmonicEvaluator {
    /// Order-0 bound delegated to `cylinder_jet_sup`.
    fn value_sup(&self, chart: &ChartRegion) -> f64 {
        let order = 0;
        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, order)
    }

    /// Order-1 bound delegated to `cylinder_jet_sup`.
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
        let order = 1;
        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, order)
    }

    /// Order-2 bound delegated to `cylinder_jet_sup`.
    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
        let order = 2;
        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, order)
    }

    /// Order-3 bound delegated to `cylinder_jet_sup`.
    fn third_sup(&self, chart: &ChartRegion) -> f64 {
        let order = 3;
        cylinder_jet_sup(self.circle_harmonics, self.line_degree, chart, order)
    }
}
/// Maximum over all splits `order = k0 + k1` of a circle factor times a line factor.
///
/// The circle factor is `(TAU * circle_harmonics)^k0` (or `1` when `k0 == 0`). The line
/// factor is `max(rho^line_degree, 1)` when `k1 == 0`, otherwise
/// `line_degree^k1 * max(rho^(line_degree - k1), 1)` with a saturating subtraction,
/// where `rho = patch_rho(chart)`.
pub(crate) fn cylinder_jet_sup(
    circle_harmonics: usize,
    line_degree: usize,
    chart: &ChartRegion,
    order: u32,
) -> f64 {
    let angular_frequency = std::f64::consts::TAU * circle_harmonics as f64;
    let degree_f = line_degree as f64;
    let extent = patch_rho(chart);

    (0..=order)
        .map(|circle_order| {
            let line_order = order - circle_order;
            let circle_factor = match circle_order {
                0 => 1.0,
                k => angular_frequency.powi(k as i32),
            };
            let line_factor = match line_order {
                0 => extent.powi(line_degree as i32).max(1.0),
                k => {
                    let remaining = line_degree.saturating_sub(k as usize) as i32;
                    degree_f.powi(k as i32) * extent.powi(remaining).max(1.0)
                }
            };
            circle_factor * line_factor
        })
        .fold(0.0_f64, |running, candidate| running.max(candidate))
}
/// Returns the largest absolute center coordinate plus the chart radius.
pub(crate) fn patch_rho(chart: &ChartRegion) -> f64 {
    let mut largest_abs = 0.0_f64;
    for &coordinate in chart.center.iter() {
        largest_abs = largest_abs.max(coordinate.abs());
    }
    largest_abs + chart.radius
}
/// Returns `latent_dim^order * max_degree^order * max(rho^(max_degree - order), 1)`,
/// where the exponent subtraction saturates at zero and `rho = patch_rho(chart)`.
pub(crate) fn patch_jet_sup(
    latent_dim: usize,
    max_degree: usize,
    chart: &ChartRegion,
    order: u32,
) -> f64 {
    let exponent = order as i32;
    let dimension_term = (latent_dim as f64).powi(exponent);
    let degree_term = (max_degree as f64).powi(exponent);
    let extent = patch_rho(chart);
    let remaining = max_degree.saturating_sub(order as usize) as i32;
    dimension_term * degree_term * extent.powi(remaining).max(1.0)
}
impl BasisHessianLipschitz for DuchonCoordinateEvaluator {
    /// Larger of the kernel term `r_max^3` and the polynomial order-0 bound.
    /// `r_max` falls back to the chart radius when `radial_r_max` is unset.
    fn value_sup(&self, chart: &ChartRegion) -> f64 {
        let outer = chart.radial_r_max.unwrap_or(chart.radius);
        let polynomial = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 0);
        outer.powi(3).max(polynomial)
    }

    /// Larger of the kernel term `3 r_max^2` and the polynomial order-1 bound.
    fn jacobian_sup(&self, chart: &ChartRegion) -> f64 {
        let outer = chart.radial_r_max.unwrap_or(chart.radius);
        let kernel_term = 3.0 * outer * outer;
        let polynomial = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 1);
        kernel_term.max(polynomial)
    }

    /// Larger of the kernel term `6 r_max + 3 r_max^2 / r_min` and the polynomial
    /// order-2 bound. `r_min` falls back to the chart radius and is clamped to at
    /// least `f64::MIN_POSITIVE` so the division is defined.
    fn hessian_sup(&self, chart: &ChartRegion) -> f64 {
        let outer = chart.radial_r_max.unwrap_or(chart.radius);
        let inner = chart
            .exclusion_r_min
            .unwrap_or(chart.radius)
            .max(f64::MIN_POSITIVE);
        let kernel_term = 6.0 * outer + 3.0 * outer * outer / inner;
        let polynomial = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 2);
        kernel_term.max(polynomial)
    }

    /// Larger of the kernel term `6 + 18 r_max / r_min + 9 r_max^2 / r_min^2` and the
    /// polynomial order-3 bound, with the same `r_min` clamping as `hessian_sup`.
    fn third_sup(&self, chart: &ChartRegion) -> f64 {
        let outer = chart.radial_r_max.unwrap_or(chart.radius);
        let inner = chart
            .exclusion_r_min
            .unwrap_or(chart.radius)
            .max(f64::MIN_POSITIVE);
        let kernel_term = 6.0 + 18.0 * outer / inner + 9.0 * outer * outer / (inner * inner);
        let polynomial = duchon_poly_jet_sup(self.centers.ncols(), self.order_degree(), chart, 3);
        kernel_term.max(polynomial)
    }
}
/// Private extension trait exposing a null-space order as a plain polynomial degree.
trait DuchonOrderDegree {
/// Polynomial degree corresponding to the evaluator's null-space order.
fn order_degree(&self) -> usize;
}
impl DuchonOrderDegree for DuchonCoordinateEvaluator {
    /// Maps `Zero` to 0, `Linear` to 1 and `Degree(d)` to `d`.
    fn order_degree(&self) -> usize {
        use gam_terms::basis::DuchonNullspaceOrder as NullspaceOrder;
        match self.order {
            NullspaceOrder::Zero => 0,
            NullspaceOrder::Linear => 1,
            NullspaceOrder::Degree(degree) => degree,
        }
    }
}
/// Polynomial part of the Duchon jet bound.
///
/// With `order_degree == 0` the result is `1.0` at order 0 and `0.0` at every higher
/// order; otherwise the bound is delegated to `patch_jet_sup`.
pub(crate) fn duchon_poly_jet_sup(
    latent_dim: usize,
    order_degree: usize,
    chart: &ChartRegion,
    order: u32,
) -> f64 {
    match (order_degree, order) {
        (0, 0) => 1.0,
        (0, _) => 0.0,
        _ => patch_jet_sup(latent_dim, order_degree, chart, order),
    }
}
/// Sum of the Euclidean norms of the rows of `decoder`, accumulated in row order.
pub(crate) fn decoder_row_norm_sum(decoder: ArrayView2<'_, f64>) -> f64 {
    // Explicit fold from +0.0 keeps the accumulation order and the empty-input result.
    decoder
        .rows()
        .into_iter()
        .fold(0.0, |total, row| total + row.dot(&row).sqrt())
}
/// Bounds on the decoded reconstruction and its first three derivatives, as produced by
/// `reconstruction_jet_sups` and consumed by `hessian_lipschitz_constant`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ReconstructionJetSups {
/// Bound on the reconstruction value.
pub(crate) value: f64,
/// Bound on the first derivative of the reconstruction.
pub(crate) jacobian: f64,
/// Bound on the second derivative of the reconstruction.
pub(crate) hessian: f64,
/// Bound on the third derivative of the reconstruction.
pub(crate) third: f64,
}
/// Square root of the larger eigenvalue of the 2x2 Gram matrix
/// `[[s·s, s·c], [s·c, c·c]]` of the two decoder rows.
pub(crate) fn pair_trig_decoder_sup(
    sin_row: ArrayView1<'_, f64>,
    cos_row: ArrayView1<'_, f64>,
) -> f64 {
    let sin_sq = sin_row.dot(&sin_row);
    let cos_sq = cos_row.dot(&cos_row);
    let cross = sin_row.dot(&cos_row);
    let gap = sin_sq - cos_sq;
    let discriminant = (gap * gap + 4.0 * cross * cross).sqrt();
    let largest_eigenvalue = 0.5 * (sin_sq + cos_sq + discriminant);
    largest_eigenvalue.sqrt()
}
/// Reconstruction jet bounds for a periodic decoder laid out as
/// `[row 0, (sin, cos) pair for harmonic 1, (sin, cos) pair for harmonic 2, ...]`.
///
/// Row 0 contributes its norm to the value bound only. Each pair `h` contributes its
/// `pair_trig_decoder_sup` amplitude scaled by `(TAU * h)^k` to the order-`k` bound.
/// A trailing unpaired row (even row count) contributes its norm at the frequency of
/// the highest complete harmonic, using at least harmonic 1.
pub(crate) fn periodic_reconstruction_jet_sups(
    decoder: ArrayView2<'_, f64>,
) -> ReconstructionJetSups {
    let row_count = decoder.nrows();
    let row_norm = |index: usize| {
        let row = decoder.row(index);
        row.dot(&row).sqrt()
    };
    let mut sups = ReconstructionJetSups {
        value: 0.0,
        jacobian: 0.0,
        hessian: 0.0,
        third: 0.0,
    };
    if row_count > 0 {
        sups.value += row_norm(0);
    }

    let harmonics = row_count.saturating_sub(1) / 2;
    for harmonic in 1..=harmonics {
        let cos_index = 2 * harmonic;
        let sin_index = cos_index - 1;
        let amplitude = pair_trig_decoder_sup(decoder.row(sin_index), decoder.row(cos_index));
        let omega = std::f64::consts::TAU * harmonic as f64;
        sups.value += amplitude;
        sups.jacobian += omega * amplitude;
        sups.hessian += omega.powi(2) * amplitude;
        sups.third += omega.powi(3) * amplitude;
    }

    // Rows left over after the complete (sin, cos) pairs.
    let tail_omega = std::f64::consts::TAU * harmonics.max(1) as f64;
    for index in (1 + 2 * harmonics)..row_count {
        let amplitude = row_norm(index);
        sups.value += amplitude;
        sups.jacobian += tail_omega * amplitude;
        sups.hessian += tail_omega.powi(2) * amplitude;
        sups.third += tail_omega.powi(3) * amplitude;
    }
    sups
}
pub(crate) fn reconstruction_jet_sups(
atom: &SaeManifoldAtom,
sups: JetSups,
) -> ReconstructionJetSups {
if matches!(
atom.basis_kind,
crate::manifold::SaeAtomBasisKind::Periodic
) {
periodic_reconstruction_jet_sups(atom.decoder_coefficients.view())
} else {
let decoder_norm_sum = decoder_row_norm_sum(atom.decoder_coefficients.view());
ReconstructionJetSups {
value: decoder_norm_sum * sups.value,
jacobian: decoder_norm_sum * sups.jacobian,
hessian: decoder_norm_sum * sups.hessian,
third: decoder_norm_sum * sups.third,
}
}
}
/// Combines reconstruction jet bounds into a single constant.
///
/// With `z = |amplitude|`, returns
/// `3 (z J)(z H) + (target_norm + z V)(z T) + prior_lipschitz`, where `V, J, H, T` are
/// the value / Jacobian / Hessian / third bounds in `recon_sups`.
pub(crate) fn hessian_lipschitz_constant(
    recon_sups: ReconstructionJetSups,
    amplitude: f64,
    target_norm: f64,
    prior_lipschitz: f64,
) -> f64 {
    let scale = amplitude.abs();
    let jacobian_bound = scale * recon_sups.jacobian;
    let hessian_bound = scale * recon_sups.hessian;
    let third_bound = scale * recon_sups.third;
    let residual_bound = target_norm + scale * recon_sups.value;
    3.0 * jacobian_bound * hessian_bound + residual_bound * third_bound + prior_lipschitz
}
/// One chart of an atom's encode atlas together with the constants computed for it at
/// build time.
#[derive(Debug, Clone)]
pub struct CertifiedChart {
/// Region of latent space the chart covers.
pub region: ChartRegion,
/// Constant produced by `hessian_lipschitz_constant` for this chart.
pub lipschitz: f64,
/// Value returned by `center_beta` at the chart center; `f64::INFINITY` when that
/// helper returns `None`.
pub beta_center: f64,
/// Radius recorded at build time: `0.0` when `center_beta` fails, otherwise
/// `min(0.5 / (beta_center * lipschitz), region.radius)`, or `region.radius` when the
/// product is unusable.
pub certified_radius: f64,
/// First element returned by `center_amortized_jacobian`, if available.
pub amortized_jacobian: Option<Array2<f64>>,
/// Second element returned by `center_amortized_jacobian`; zeros when unavailable.
pub recon_center: Array1<f64>,
}
/// All charts built for a single atom.
#[derive(Debug, Clone)]
pub struct AtomEncodeAtlas {
/// Index of the atom this atlas was built for.
pub atom_index: usize,
/// Latent dimension of the atom.
pub latent_dim: usize,
/// Sum of Euclidean norms of the atom's decoder rows.
pub decoder_norm_sum: f64,
/// Charts in the order their centers were supplied.
pub charts: Vec<CertifiedChart>,
}
/// Result of encoding a batch of rows.
#[derive(Debug, Clone)]
pub struct EncodeResult {
/// Encoded latent coordinates, one row per input row.
pub coords: Array2<f64>,
/// Per-row flag: `true` when the row's certificate was accepted.
pub certified: Vec<bool>,
/// Number of `false` entries in `certified`.
pub encode_uncertified_count: usize,
}
impl EncodeResult {
pub(crate) fn from_rows(coords: Array2<f64>, certified: Vec<bool>) -> Self {
let encode_uncertified_count = certified.iter().filter(|c| !**c).count();
Self {
coords,
certified,
encode_uncertified_count,
}
}
}
/// Quantities of the per-row certificate computed by `row_certificate`.
#[derive(Debug, Clone, Copy)]
pub struct RowCertificate {
/// Reciprocal of the smallest Hessian eigenvalue (`f64::INFINITY` when uncertified).
pub beta: f64,
/// Euclidean norm of the Newton step (`f64::INFINITY` when uncertified).
pub eta: f64,
/// Lipschitz constant the certificate was evaluated with.
pub lipschitz: f64,
/// Product `beta * eta * lipschitz`; compared against `KANTOROVICH_THRESHOLD`.
pub h: f64,
}
impl RowCertificate {
    /// `true` when `h` is finite and does not exceed `KANTOROVICH_THRESHOLD`.
    pub fn certified(&self) -> bool {
        let h = self.h;
        if !h.is_finite() {
            return false;
        }
        h <= KANTOROVICH_THRESHOLD
    }
}
/// Outcome of a successful certified refinement.
#[derive(Debug, Clone)]
struct CertifiedEncodeProbe {
/// Latent coordinate after the refinement steps.
coord: Array1<f64>,
/// Certificate at the point where refinement started.
initial_cert: RowCertificate,
/// Certificate evaluated after the last refinement step (equals `initial_cert` when
/// no step is taken).
final_cert: RowCertificate,
}
/// Line-axis degree used for cylinder atoms in `family_jet_sups`; the basis width must
/// be divisible by this value plus one.
pub(crate) const SAE_CYLINDER_LINE_DEGREE: usize = 2;
/// Builds the evaluator matching the atom's basis kind and collects its jet bounds
/// over `chart`.
///
/// # Errors
/// Propagates evaluator-constructor errors, rejects cylinder atoms whose latent
/// dimension is not 2 or whose width is not divisible by the line-axis width, and
/// rejects precomputed bases, for which no closed-form bound is available.
pub(crate) fn family_jet_sups(
    atom: &SaeManifoldAtom,
    chart: &ChartRegion,
) -> Result<JetSups, String> {
    use crate::manifold::SaeAtomBasisKind::*;
    let m = atom.basis_size();
    let d = atom.latent_dim;
    match &atom.basis_kind {
        Periodic => {
            let family = PeriodicHarmonicEvaluator::new(m)?;
            Ok(JetSups::from_family(&family, chart))
        }
        Torus => {
            // Per-axis width is the integer d-th root of the total width.
            let per_axis_width = integer_root(m, d.max(1));
            let harmonics = (per_axis_width.saturating_sub(1) / 2).max(1);
            let family = TorusHarmonicEvaluator::new(d, harmonics)?;
            Ok(JetSups::from_family(&family, chart))
        }
        Sphere => {
            let family = SphereChartEvaluator;
            Ok(JetSups::from_family(&family, chart))
        }
        Cylinder => {
            let ml = SAE_CYLINDER_LINE_DEGREE + 1;
            if d != 2 || ml == 0 || m % ml != 0 {
                return Err(format!(
                    "EncodeAtlas: Cylinder atom requires latent_dim == 2 and width divisible by {ml}; got dim={d}, m={m}"
                ));
            }
            let circle_width = m / ml;
            let harmonics = (circle_width.saturating_sub(1) / 2).max(1);
            let family = CylinderHarmonicEvaluator::new(harmonics, SAE_CYLINDER_LINE_DEGREE)?;
            Ok(JetSups::from_family(&family, chart))
        }
        Linear | EuclideanPatch | Poincare => {
            let family = EuclideanPatchEvaluator::new(d, euclidean_patch_degree(d, m))?;
            Ok(JetSups::from_family(&family, chart))
        }
        Duchon => {
            let placeholder_centers = duchon_centers_from_atom(atom);
            let family = DuchonCoordinateEvaluator::new(placeholder_centers, m.max(1))?;
            Ok(JetSups::from_family(&family, chart))
        }
        Precomputed(name) => Err(format!(
            "EncodeAtlas: precomputed basis '{name}' has no closed-form jet sup; route to exact encode"
        )),
    }
}
/// Smallest degree whose `patch_column_count` reaches `m`, capped at `m`.
pub(crate) fn euclidean_patch_degree(latent_dim: usize, m: usize) -> usize {
    let mut candidate = 0usize;
    loop {
        // The column count is evaluated before the cap check, as in a `while` condition.
        let columns = patch_column_count(latent_dim, candidate);
        if columns >= m || candidate >= m {
            return candidate;
        }
        candidate += 1;
    }
}
/// Floor of the `k`-th root of `n`: the largest `a` with `a^k <= n`.
///
/// Conventions: `k == 0` returns `1`, `k == 1` returns `n`. For `k >= 2` the search
/// starts from zero so that `n == 0` yields `0`, consistent with the `k == 1` branch
/// (a start value of one would report `1` even though `1^k > 0`).
///
/// The search is linear in the result, which is at most `sqrt(n)` for `k >= 2`.
pub(crate) fn integer_root(n: usize, k: usize) -> usize {
    match k {
        0 => return 1,
        1 => return n,
        _ => {}
    }
    let limit = n as u128;
    let mut root = 0usize;
    loop {
        let candidate = root + 1;
        // Accumulate candidate^k, bailing out as soon as it exceeds `n`. The running
        // product never exceeds `n * candidate` before the check, so the saturating
        // multiplication is only a safety net.
        let mut power: u128 = 1;
        for _ in 0..k {
            power = power.saturating_mul(candidate as u128);
            if power > limit {
                return root;
            }
        }
        root = candidate;
    }
}
/// Binomial coefficient `C(latent_dim + degree, degree)`.
///
/// The coefficient is built incrementally, `C(d + i, i) = C(d + i - 1, i - 1) * (d + i) / i`,
/// which is an exact division at every step. A separate numerator/denominator product
/// would overflow `u128` already for moderate degrees (for example `latent_dim = 1`,
/// `degree >= 34`). Results that do not fit in `usize` saturate to `usize::MAX`.
pub(crate) fn patch_column_count(latent_dim: usize, degree: usize) -> usize {
    let mut count: u128 = 1;
    for i in 1..=degree {
        let factor = latent_dim as u128 + i as u128;
        count = match count.checked_mul(factor) {
            Some(scaled) => scaled / i as u128,
            // The intermediate is `i * C(d + i, i)`; overflowing `u128` means the
            // coefficient is already far beyond `usize::MAX`.
            None => return usize::MAX,
        };
    }
    usize::try_from(count).unwrap_or(usize::MAX)
}
/// Returns a single all-zero center row with `max(latent_dim, 1)` columns.
pub(crate) fn duchon_centers_from_atom(atom: &SaeManifoldAtom) -> Array2<f64> {
    let columns = atom.latent_dim.max(1);
    let shape = (1, columns);
    Array2::<f64>::zeros(shape)
}
/// The four bounds reported by a `BasisHessianLipschitz` implementation for one chart.
#[derive(Debug, Clone, Copy)]
pub(crate) struct JetSups {
/// Result of `value_sup`.
pub(crate) value: f64,
/// Result of `jacobian_sup`.
pub(crate) jacobian: f64,
/// Result of `hessian_sup`.
pub(crate) hessian: f64,
/// Result of `third_sup`.
pub(crate) third: f64,
}
impl JetSups {
pub(crate) fn from_family<B: BasisHessianLipschitz>(family: &B, chart: &ChartRegion) -> Self {
Self {
value: family.value_sup(chart),
jacobian: family.jacobian_sup(chart),
hessian: family.hessian_sup(chart),
third: family.third_sup(chart),
}
}
}
/// Gradient and ridge-regularized Hessian, with respect to the latent coordinate `t`,
/// of the encode objective built from the residual `amplitude * phi(t) * decoder - x`.
///
/// Returns `Ok(None)` when the evaluator exposes no second jet (`second_jet_dyn`
/// returns `None`).
///
/// # Errors
/// Fails when `t` cannot be reshaped to `(1, latent_dim)`, when the evaluator fails, or
/// when the returned basis row does not have shape `(1, basis_size)`.
pub(crate) fn encode_grad_hess(
atom: &SaeManifoldAtom,
evaluator: &dyn SaeBasisEvaluator,
t: ArrayView1<'_, f64>,
x: ArrayView1<'_, f64>,
amplitude: f64,
ridge: f64,
) -> Result<Option<(Array1<f64>, Array2<f64>)>, String> {
let d = atom.latent_dim;
let p = atom.output_dim();
let m = atom.basis_size();
// The evaluator takes a batch of coordinates; wrap `t` as a single-row batch.
let coords = t.to_shape((1, d)).map_err(|e| e.to_string())?.to_owned();
let (phi, jet) = evaluator.evaluate(coords.view())?;
if phi.dim() != (1, m) {
return Err(format!(
"encode_grad_hess: evaluator returned phi {:?}, expected (1, {m})",
phi.dim()
));
}
let decoder = &atom.decoder_coefficients;
// recon = amplitude * phi * decoder, skipping exactly-zero basis values.
let mut recon = Array1::<f64>::zeros(p);
for basis_col in 0..m {
let phi_v = phi[[0, basis_col]];
if phi_v == 0.0 {
continue;
}
for out in 0..p {
recon[out] += amplitude * phi_v * decoder[[basis_col, out]];
}
}
let residual = &recon - &x;
// jm[axis, out] = amplitude * sum_k d(phi_k)/d(t_axis) * decoder[k, out].
let mut jm = Array2::<f64>::zeros((d, p));
for axis in 0..d {
for basis_col in 0..m {
let dphi = jet[[0, basis_col, axis]];
if dphi == 0.0 {
continue;
}
for out in 0..p {
jm[[axis, out]] += amplitude * dphi * decoder[[basis_col, out]];
}
}
}
// Without second derivatives no Hessian can be formed; report that as `None`.
let second = match evaluator.second_jet_dyn(coords.view()) {
Some(result) => result?,
None => return Ok(None),
};
let mut g = Array1::<f64>::zeros(d);
let mut h = Array2::<f64>::zeros((d, d));
for a in 0..d {
let ja = jm.row(a);
g[a] = ja.dot(&residual);
for b in 0..d {
// First term: product of the two Jacobian rows.
let mut hab = ja.dot(&jm.row(b));
// Second term: residual contracted with the second derivatives of the basis.
let mut curv = 0.0;
for basis_col in 0..m {
let d2phi = second[[0, basis_col, a, b]];
if d2phi == 0.0 {
continue;
}
// NOTE(review): this dot product depends only on `basis_col`, not on (a, b);
// it could be hoisted out of the double loop if profiling warrants.
let mut dot = 0.0;
for out in 0..p {
dot += residual[out] * decoder[[basis_col, out]];
}
curv += amplitude * d2phi * dot;
}
hab += curv;
h[[a, b]] = hab;
}
}
// Ridge term on the diagonal.
for a in 0..d {
h[[a, a]] += ridge;
}
Ok(Some((g, h)))
}
/// Eigendecomposes the symmetric matrix `h` and returns
/// `(beta, eta, delta)` with `beta = 1 / lambda_min`, `delta = -h^{-1} g` assembled
/// from the eigenpairs, and `eta = ||delta||_2`.
///
/// Returns `Ok(None)` when the smallest eigenvalue is not finite and strictly positive.
///
/// # Errors
/// Fails when the eigendecomposition itself fails.
pub(crate) fn beta_eta_newton(
    h: ArrayView2<'_, f64>,
    g: ArrayView1<'_, f64>,
) -> Result<Option<(f64, f64, Array1<f64>)>, String> {
    let (eigenvalues, eigenvectors) = h
        .eigh(Side::Lower)
        .map_err(|e| format!("beta_eta_newton: eigh failed: {e:?}"))?;

    let smallest = eigenvalues.iter().copied().fold(f64::INFINITY, f64::min);
    if !smallest.is_finite() || smallest <= 0.0 {
        return Ok(None);
    }

    let dim = h.nrows();
    let mut step = Array1::<f64>::zeros(dim);
    for (index, &eigenvalue) in eigenvalues.iter().enumerate() {
        if eigenvalue <= 0.0 {
            return Ok(None);
        }
        let direction = eigenvectors.column(index);
        let weight = direction.dot(&g) / eigenvalue;
        for row in 0..dim {
            step[row] -= weight * direction[row];
        }
    }

    let step_norm = step.dot(&step).sqrt();
    Ok(Some((1.0 / smallest, step_norm, step)))
}
/// Evaluates the certificate at `t0` and returns it together with the Newton step.
///
/// When no Hessian is available or it is not positive definite, the certificate has
/// `beta = eta = h = INFINITY` (so it is not certified) and the step is all zeros.
///
/// # Errors
/// Propagates errors from `encode_grad_hess` and `beta_eta_newton`.
pub fn row_certificate(
    atom: &SaeManifoldAtom,
    evaluator: &dyn SaeBasisEvaluator,
    t0: ArrayView1<'_, f64>,
    x: ArrayView1<'_, f64>,
    amplitude: f64,
    lipschitz: f64,
    ridge: f64,
) -> Result<(RowCertificate, Array1<f64>), String> {
    let rejected = || {
        let cert = RowCertificate {
            lipschitz,
            beta: f64::INFINITY,
            eta: f64::INFINITY,
            h: f64::INFINITY,
        };
        (cert, Array1::<f64>::zeros(atom.latent_dim))
    };

    let (gradient, hessian) = match encode_grad_hess(atom, evaluator, t0, x, amplitude, ridge)? {
        Some(pair) => pair,
        None => return Ok(rejected()),
    };
    let Some((beta, eta, delta)) = beta_eta_newton(hessian.view(), gradient.view())? else {
        return Ok(rejected());
    };
    let cert = RowCertificate {
        beta,
        eta,
        lipschitz,
        h: beta * eta * lipschitz,
    };
    Ok((cert, delta))
}
/// Certificate that never certifies: `beta`, `eta` and `h` are all infinite, while the
/// supplied `lipschitz` is recorded unchanged.
fn uncertified_certificate(lipschitz: f64) -> RowCertificate {
    let unbounded = f64::INFINITY;
    RowCertificate {
        lipschitz,
        beta: unbounded,
        eta: unbounded,
        h: unbounded,
    }
}
fn refine_certified_start(
atom: &SaeManifoldAtom,
evaluator: &dyn SaeBasisEvaluator,
mut t: Array1<f64>,
x: ArrayView1<'_, f64>,
amplitude: f64,
lipschitz: f64,
ridge: f64,
newton_steps: usize,
initial_cert: RowCertificate,
mut delta: Array1<f64>,
) -> Result<Option<CertifiedEncodeProbe>, String> {
assert!(initial_cert.certified());
let mut final_cert = initial_cert;
for _ in 0..newton_steps {
t = &t + δ
let (cert, next_delta) =
row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz, ridge)?;
if !cert.certified() {
return Ok(None);
}
final_cert = cert;
delta = next_delta;
}
Ok(Some(CertifiedEncodeProbe {
coord: t,
initial_cert,
final_cert,
}))
}
fn certify_with_basin_warmup(
atom: &SaeManifoldAtom,
evaluator: &dyn SaeBasisEvaluator,
t_start: Array1<f64>,
x: ArrayView1<'_, f64>,
amplitude: f64,
lipschitz: f64,
ridge: f64,
newton_steps: usize,
chart_center: ArrayView1<'_, f64>,
chart_radius: f64,
) -> Result<Option<CertifiedEncodeProbe>, String> {
let in_chart = |t: &Array1<f64>| -> bool {
let r2: f64 = t
.iter()
.zip(chart_center.iter())
.map(|(a, b)| (a - b) * (a - b))
.sum();
r2 <= chart_radius * chart_radius
};
let mut t = t_start;
if !in_chart(&t) {
return Ok(None);
}
let (mut cert, mut delta) =
row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz, ridge)?;
while !cert.certified() {
if !(cert.h.is_finite() && cert.beta.is_finite() && cert.eta.is_finite()) {
return Ok(None);
}
let prev_h = cert.h;
let next = &t + δ
if !in_chart(&next) {
return Ok(None);
}
t = next;
let (next_cert, next_delta) =
row_certificate(atom, evaluator, t.view(), x, amplitude, lipschitz, ridge)?;
cert = next_cert;
delta = next_delta;
if !cert.h.is_finite() || cert.h >= prev_h {
return Ok(None);
}
}
refine_certified_start(
atom,
evaluator,
t,
x,
amplitude,
lipschitz,
ridge,
newton_steps,
cert,
delta,
)
}
/// Radius `2 eta / (1 + sqrt(1 - 2h))` derived from a certificate, with `h` clamped to
/// `KANTOROVICH_THRESHOLD` and the square-root argument clamped at zero.
///
/// Returns `INFINITY` for certificates that are not certified, have an unusable `eta`
/// or a negative `h`, or yield a non-finite radius; returns `0.0` when `eta == 0`.
fn kantorovich_root_radius(cert: RowCertificate) -> f64 {
    let eta = cert.eta;
    let eta_usable = eta.is_finite() && eta >= 0.0;
    if !(cert.certified() && eta_usable) {
        return f64::INFINITY;
    }
    if eta == 0.0 {
        return 0.0;
    }
    let raw_h = cert.h;
    if !raw_h.is_finite() || raw_h < 0.0 {
        return f64::INFINITY;
    }
    let clamped_h = raw_h.min(KANTOROVICH_THRESHOLD);
    let root = (1.0 - 2.0 * clamped_h).max(0.0).sqrt();
    let radius = 2.0 * eta / (1.0 + root);
    match radius.is_finite() {
        true => radius,
        false => f64::INFINITY,
    }
}
/// Allowed gap between the amortized and cold-start probes: the sum of both
/// `kantorovich_root_radius` values plus a rounding allowance of
/// `1024 * EPSILON * (||amortized|| + ||cold|| + ||x|| + |amplitude| + 1)`.
fn distilled_probe_tolerance(
    amortized: &CertifiedEncodeProbe,
    cold: &CertifiedEncodeProbe,
    amplitude: f64,
    x: ArrayView1<'_, f64>,
) -> f64 {
    let amortized_radius = kantorovich_root_radius(amortized.final_cert);
    let cold_radius = kantorovich_root_radius(cold.final_cert);
    let radius_budget = amortized_radius + cold_radius;

    let amortized_norm = amortized.coord.dot(&amortized.coord).sqrt();
    let cold_norm = cold.coord.dot(&cold.coord).sqrt();
    let target_norm = x.dot(&x).sqrt();
    let magnitude = amortized_norm + cold_norm + target_norm + amplitude.abs() + 1.0;

    radius_budget + 1024.0 * f64::EPSILON * magnitude
}
/// Euclidean distance between two latent coordinates, where axes reported as periodic
/// by `latent_axis_period` use the shorter way around the period. Only the common
/// prefix of the two vectors is compared.
fn latent_coordinate_distance(
    atom: &SaeManifoldAtom,
    lhs: ArrayView1<'_, f64>,
    rhs: ArrayView1<'_, f64>,
) -> f64 {
    let squared_sum = lhs
        .iter()
        .zip(rhs.iter())
        .enumerate()
        .fold(0.0, |total, (axis, (&left, &right))| {
            let raw = (left - right).abs();
            let separation = match latent_axis_period(atom, axis) {
                Some(period) => {
                    let wrapped = raw.rem_euclid(period);
                    wrapped.min(period - wrapped)
                }
                None => raw,
            };
            total + separation * separation
        });
    squared_sum.sqrt()
}
/// Period of latent axis `axis` for the atom's basis kind, or `None` when the axis is
/// not periodic: `1.0` for every periodic/torus axis and for cylinder axis 0, `TAU`
/// for sphere axis 1.
fn latent_axis_period(atom: &SaeManifoldAtom, axis: usize) -> Option<f64> {
    use crate::manifold::SaeAtomBasisKind::*;
    match (&atom.basis_kind, axis) {
        (Periodic | Torus, _) => Some(1.0),
        (Cylinder, 0) => Some(1.0),
        (Sphere, 1) => Some(std::f64::consts::TAU),
        _ => None,
    }
}
/// Tunables for building and querying an `EncodeAtlas`.
#[derive(Debug, Clone, Copy)]
pub struct AtlasConfig {
/// Resolution handed to `chart_center_grid` and `chart_nominal_radius` when building
/// grid-based charts.
pub grid_resolution: usize,
/// Ridge added to the Hessian diagonal when certificates are evaluated.
pub ridge: f64,
/// Number of Newton refinement steps taken once a start point is certified.
pub newton_steps: usize,
}
impl Default for AtlasConfig {
fn default() -> Self {
Self {
grid_resolution: 16,
ridge: 1.0e-9,
newton_steps: 2,
}
}
}
/// Per-atom chart atlases plus the configuration they were built with.
#[derive(Debug, Clone)]
pub struct EncodeAtlas {
/// One atlas per atom, indexed by atom position.
pub atoms: Vec<AtomEncodeAtlas>,
/// Configuration used at build time and reused for encoding.
pub config: AtlasConfig,
}
impl EncodeAtlas {
/// Builds a grid-based atlas for every atom.
///
/// # Errors
/// Fails when `amplitude_bound` does not have one entry per atom, or when building
/// any single atom atlas fails.
pub fn build(
    atoms: &[SaeManifoldAtom],
    amplitude_bound: &[f64],
    target_norm_bound: f64,
    config: AtlasConfig,
) -> Result<Self, String> {
    let atom_count = atoms.len();
    if amplitude_bound.len() != atom_count {
        return Err(format!(
            "EncodeAtlas::build: amplitude_bound length {} != atom count {}",
            amplitude_bound.len(),
            atom_count
        ));
    }
    let built = atoms
        .iter()
        .zip(amplitude_bound)
        .enumerate()
        .map(|(index, (atom, &bound))| {
            Self::build_atom_atlas(index, atom, bound, target_norm_bound, &config)
        })
        .collect::<Result<Vec<_>, String>>()?;
    Ok(Self {
        atoms: built,
        config,
    })
}
/// Builds one atom's atlas on the regular center grid, giving every chart the same
/// nominal radius.
pub(crate) fn build_atom_atlas(
    atom_index: usize,
    atom: &SaeManifoldAtom,
    amplitude_bound: f64,
    target_norm_bound: f64,
    config: &AtlasConfig,
) -> Result<AtomEncodeAtlas, String> {
    let resolution = config.grid_resolution;
    let grid = chart_center_grid(atom, resolution);
    let shared_radius = chart_nominal_radius(atom, resolution);
    let per_chart_radii = vec![shared_radius; grid.nrows()];
    Self::build_atom_atlas_from_centers(
        atom_index,
        atom,
        grid.view(),
        &per_chart_radii,
        amplitude_bound,
        target_norm_bound,
        config,
    )
}
/// Builds one atom's atlas from explicit chart centers and per-chart nominal radii.
///
/// For each center the chart region, the family jet bounds, the reconstruction jet
/// bounds and the resulting Lipschitz constant are computed. When `center_beta`
/// yields no value the chart is kept but marked unusable (`beta_center = INFINITY`,
/// `certified_radius = 0`).
///
/// # Errors
/// Fails when `centers` does not have `latent_dim` columns, when `radii` does not
/// have one entry per center, or when `family_jet_sups` fails.
pub(crate) fn build_atom_atlas_from_centers(
    atom_index: usize,
    atom: &SaeManifoldAtom,
    centers: ArrayView2<'_, f64>,
    radii: &[f64],
    amplitude_bound: f64,
    target_norm_bound: f64,
    config: &AtlasConfig,
) -> Result<AtomEncodeAtlas, String> {
    let d = atom.latent_dim;
    if centers.ncols() != d {
        return Err(format!(
            "build_atom_atlas_from_centers: centers have {} cols but atom latent_dim is {d}",
            centers.ncols()
        ));
    }
    if radii.len() != centers.nrows() {
        return Err(format!(
            "build_atom_atlas_from_centers: {} radii != {} centers",
            radii.len(),
            centers.nrows()
        ));
    }
    let decoder_norm_sum = decoder_row_norm_sum(atom.decoder_coefficients.view());
    let mut charts = Vec::with_capacity(centers.nrows());
    for c in 0..centers.nrows() {
        let center = centers.row(c).to_owned();
        let nominal_radius = radii[c];
        let region = chart_region(atom, center.clone(), nominal_radius);
        // `&region` / `&center` below had been mangled by an HTML-entity decode
        // (`&reg`, `&cent`), which left the function unparseable.
        let sups = family_jet_sups(atom, &region)?;
        let recon_sups = reconstruction_jet_sups(atom, sups);
        let lipschitz =
            hessian_lipschitz_constant(recon_sups, amplitude_bound, target_norm_bound, 0.0);
        let beta_center = match center_beta(atom, &center, config.ridge) {
            Some(b) => b,
            None => {
                // Keep the chart so indices stay aligned with `centers`, but make it
                // unusable for certified encoding.
                charts.push(CertifiedChart {
                    region,
                    lipschitz,
                    beta_center: f64::INFINITY,
                    certified_radius: 0.0,
                    amortized_jacobian: None,
                    recon_center: Array1::<f64>::zeros(atom.output_dim()),
                });
                continue;
            }
        };
        let (amortized_jacobian, recon_center) =
            match center_amortized_jacobian(atom, &center, config.ridge) {
                Some((a1, m1)) => (Some(a1), m1),
                None => (None, Array1::<f64>::zeros(atom.output_dim())),
            };
        // Never let the recorded radius exceed the chart's own radius.
        let certified_radius = if lipschitz > 0.0 && beta_center.is_finite() {
            (0.5 / (beta_center * lipschitz)).min(region.radius)
        } else {
            region.radius
        };
        charts.push(CertifiedChart {
            region,
            lipschitz,
            beta_center,
            certified_radius,
            amortized_jacobian,
            recon_center,
        });
    }
    Ok(AtomEncodeAtlas {
        atom_index,
        latent_dim: d,
        decoder_norm_sum,
        charts,
    })
}
pub fn build_data_driven(
atoms: &[SaeManifoldAtom],
coords: &[Array2<f64>],
amplitude_bound: &[f64],
target_norm_bound: f64,
max_charts: usize,
config: AtlasConfig,
) -> Result<Self, String> {
if amplitude_bound.len() != atoms.len() || coords.len() != atoms.len() {
return Err(format!(
"build_data_driven: amplitude_bound {} / coords {} must match atom count {}",
amplitude_bound.len(),
coords.len(),
atoms.len()
));
}
let mut atom_atlases = Vec::with_capacity(atoms.len());
for (k, atom) in atoms.iter().enumerate() {
let (centers, radii) =
data_driven_chart_centers(atom, coords[k].view(), max_charts.max(1))?;
let atlas = Self::build_atom_atlas_from_centers(
k,
atom,
centers.view(),
&radii,
amplitude_bound[k],
target_norm_bound,
&config,
)?;
atom_atlases.push(atlas);
}
Ok(Self {
atoms: atom_atlases,
config,
})
}
/// Runs the warm-up-then-refine procedure for one chart and start point.
///
/// On success returns the refined coordinate together with the certificate at the
/// point where refinement started; otherwise returns a zero coordinate and an
/// uncertified certificate carrying the chart's Lipschitz constant.
fn refine_certified_encode_start(
    &self,
    atom: &SaeManifoldAtom,
    evaluator: &dyn SaeBasisEvaluator,
    chart: &CertifiedChart,
    t: Array1<f64>,
    x: ArrayView1<'_, f64>,
    amplitude: f64,
) -> Result<(Array1<f64>, RowCertificate), String> {
    let outcome = certify_with_basin_warmup(
        atom,
        evaluator,
        t,
        x,
        amplitude,
        chart.lipschitz,
        self.config.ridge,
        self.config.newton_steps,
        chart.region.center.view(),
        chart.region.radius,
    )?;
    Ok(match outcome {
        Some(probe) => (probe.coord, probe.initial_cert),
        None => (
            Array1::<f64>::zeros(atom.latent_dim),
            uncertified_certificate(chart.lipschitz),
        ),
    })
}
/// Encodes one row by trying the top candidate charts and keeping the certified
/// result with the smallest reconstruction error.
///
/// When no candidate certifies, the outcome of the first candidate visited is returned
/// instead. When the atom has no evaluator or no candidate chart exists, a zero
/// coordinate with a fully infinite certificate is returned.
///
/// # Errors
/// Fails when `atom_index` is not in the atlas or when refinement fails.
pub fn certified_encode_row(
    &self,
    atom: &SaeManifoldAtom,
    atom_index: usize,
    x: ArrayView1<'_, f64>,
    amplitude: f64,
) -> Result<(Array1<f64>, RowCertificate), String> {
    let atom_atlas = match self.atoms.get(atom_index) {
        Some(atlas) => atlas,
        None => return Err(format!("certified_encode_row: atom {atom_index} not in atlas")),
    };
    let d = atom.latent_dim;
    // Shared "nothing to route to" answer: every certificate field is infinite.
    let unroutable = || {
        (
            Array1::<f64>::zeros(d),
            uncertified_certificate(f64::INFINITY),
        )
    };
    let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
        return Ok(unroutable());
    };
    let candidates =
        nearest_charts_topk(atom_atlas, x, atom, evaluator.as_ref(), CERTIFIED_ROUTING_TOPK);
    if candidates.is_empty() {
        return Ok(unroutable());
    }

    let mut best: Option<(Array1<f64>, RowCertificate, f64)> = None;
    let mut nearest_fallback: Option<(Array1<f64>, RowCertificate)> = None;
    for chart_idx in candidates {
        let chart = &atom_atlas.charts[chart_idx];
        let start = match amortized_warm_start(chart, x, amplitude) {
            Some(start) => start,
            None => {
                if nearest_fallback.is_none() {
                    nearest_fallback = Some((
                        Array1::<f64>::zeros(d),
                        uncertified_certificate(chart.lipschitz),
                    ));
                }
                continue;
            }
        };
        let (coord, cert) = self.refine_certified_encode_start(
            atom,
            evaluator.as_ref(),
            chart,
            start,
            x,
            amplitude,
        )?;
        // The first candidate visited always seeds the fallback.
        if nearest_fallback.is_none() {
            nearest_fallback = Some((coord.clone(), cert));
        }
        if !cert.certified() {
            continue;
        }
        let err =
            encode_reconstruction_error(atom, evaluator.as_ref(), coord.view(), x, amplitude);
        let improves = match &best {
            Some((_, _, best_err)) => err < *best_err,
            None => true,
        };
        if improves {
            best = Some((coord, cert, err));
        }
    }

    let chosen = best
        .map(|(coord, cert, _)| (coord, cert))
        .or(nearest_fallback)
        .unwrap_or_else(unroutable);
    Ok(chosen)
}
/// Encodes one row from the amortized warm start of its nearest chart and
/// cross-checks the result against a cold start from the chart center.
///
/// The amortized coordinate is reported as certified only when both probes certify
/// and their latent distance is within `distilled_probe_tolerance`. If the amortized
/// probe itself fails, a zero coordinate is returned; if only the cross-check fails,
/// the amortized coordinate is returned with an uncertified certificate.
///
/// # Errors
/// Fails when `atom_index` is not in the atlas or when certification fails.
pub fn amortized_encode_row(
    &self,
    atom: &SaeManifoldAtom,
    atom_index: usize,
    x: ArrayView1<'_, f64>,
    amplitude: f64,
) -> Result<(Array1<f64>, RowCertificate), String> {
    let atom_atlas = match self.atoms.get(atom_index) {
        Some(atlas) => atlas,
        None => return Err(format!("amortized_encode_row: atom {atom_index} not in atlas")),
    };
    let d = atom.latent_dim;
    let unroutable = || {
        let cert = RowCertificate {
            beta: f64::INFINITY,
            eta: f64::INFINITY,
            lipschitz: f64::INFINITY,
            h: f64::INFINITY,
        };
        (Array1::<f64>::zeros(d), cert)
    };
    let Some(evaluator) = atom.basis_evaluator.as_ref().cloned() else {
        return Ok(unroutable());
    };
    let Some((chart_idx, _)) = nearest_chart(atom_atlas, x, atom, evaluator.as_ref()) else {
        return Ok(unroutable());
    };
    let chart = &atom_atlas.charts[chart_idx];
    let Some(t_hat) = amortized_warm_start(chart, x, amplitude) else {
        return Ok(unroutable());
    };

    // Both probes share every argument except the start point.
    let probe_from = |start: Array1<f64>| {
        certify_with_basin_warmup(
            atom,
            evaluator.as_ref(),
            start,
            x,
            amplitude,
            chart.lipschitz,
            self.config.ridge,
            self.config.newton_steps,
            chart.region.center.view(),
            chart.region.radius,
        )
    };

    let amortized_probe = match probe_from(t_hat)? {
        Some(probe) => probe,
        None => {
            return Ok((
                Array1::<f64>::zeros(d),
                uncertified_certificate(chart.lipschitz),
            ));
        }
    };
    let cold_probe = match probe_from(chart.region.center.clone())? {
        Some(probe) => probe,
        None => {
            return Ok((
                amortized_probe.coord,
                uncertified_certificate(chart.lipschitz),
            ));
        }
    };

    let gap =
        latent_coordinate_distance(atom, amortized_probe.coord.view(), cold_probe.coord.view());
    let tolerance = distilled_probe_tolerance(&amortized_probe, &cold_probe, amplitude, x);
    let probes_agree = gap.is_finite() && gap <= tolerance;
    let cert = if probes_agree {
        amortized_probe.initial_cert
    } else {
        uncertified_certificate(chart.lipschitz)
    };
    Ok((amortized_probe.coord, cert))
}
/// Applies `amortized_encode_row` to every row of `targets`.
///
/// Batches of at least `ENCODE_BATCH_PARALLEL_ROW_MIN` rows are processed in parallel
/// chunks of 256 rows, unless the caller is already on a rayon worker thread.
///
/// # Errors
/// Fails when `amplitudes` does not have one entry per row or when any row fails.
pub fn amortized_encode_batch(
    &self,
    atom: &SaeManifoldAtom,
    atom_index: usize,
    targets: ArrayView2<'_, f64>,
    amplitudes: ArrayView1<'_, f64>,
) -> Result<EncodeResult, String> {
    let n = targets.nrows();
    if amplitudes.len() != n {
        return Err(format!(
            "amortized_encode_batch: amplitudes len {} != rows {n}",
            amplitudes.len()
        ));
    }

    // Encodes a contiguous span of rows, stopping at the first error.
    let encode_span =
        |span: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
            let mut encoded = Vec::with_capacity(span.len());
            for row in span {
                let (coord, cert) = self.amortized_encode_row(
                    atom,
                    atom_index,
                    targets.row(row),
                    amplitudes[row],
                )?;
                encoded.push((coord, cert.certified()));
            }
            Ok(encoded)
        };

    let run_parallel =
        n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
    let rows: Vec<(Array1<f64>, bool)> = if run_parallel {
        use rayon::prelude::*;
        const CHUNK: usize = 256;
        let per_chunk: Vec<Vec<(Array1<f64>, bool)>> = (0..n.div_ceil(CHUNK))
            .into_par_iter()
            .map(|chunk| {
                let first = chunk * CHUNK;
                encode_span(first..(first + CHUNK).min(n))
            })
            .collect::<Result<_, _>>()?;
        per_chunk.into_iter().flatten().collect()
    } else {
        encode_span(0..n)?
    };

    let mut coords = Array2::<f64>::zeros((n, atom.latent_dim));
    let mut certified = Vec::with_capacity(n);
    for (row, (coord, ok)) in rows.into_iter().enumerate() {
        coords.row_mut(row).assign(&coord);
        certified.push(ok);
    }
    Ok(EncodeResult::from_rows(coords, certified))
}
/// Applies `certified_encode_row` to every row of `targets`.
///
/// Batches of at least `ENCODE_BATCH_PARALLEL_ROW_MIN` rows are processed in parallel
/// chunks of 256 rows, unless the caller is already on a rayon worker thread.
///
/// # Errors
/// Fails when `amplitudes` does not have one entry per row or when any row fails.
pub fn certified_encode_batch(
    &self,
    atom: &SaeManifoldAtom,
    atom_index: usize,
    targets: ArrayView2<'_, f64>,
    amplitudes: ArrayView1<'_, f64>,
) -> Result<EncodeResult, String> {
    let n = targets.nrows();
    if amplitudes.len() != n {
        return Err(format!(
            "certified_encode_batch: amplitudes len {} != rows {n}",
            amplitudes.len()
        ));
    }

    // Encodes a contiguous span of rows, stopping at the first error.
    let encode_span =
        |span: std::ops::Range<usize>| -> Result<Vec<(Array1<f64>, bool)>, String> {
            let mut encoded = Vec::with_capacity(span.len());
            for row in span {
                let (coord, cert) = self.certified_encode_row(
                    atom,
                    atom_index,
                    targets.row(row),
                    amplitudes[row],
                )?;
                encoded.push((coord, cert.certified()));
            }
            Ok(encoded)
        };

    let run_parallel =
        n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
    let rows: Vec<(Array1<f64>, bool)> = if run_parallel {
        use rayon::prelude::*;
        const CHUNK: usize = 256;
        let per_chunk: Vec<Vec<(Array1<f64>, bool)>> = (0..n.div_ceil(CHUNK))
            .into_par_iter()
            .map(|chunk| {
                let first = chunk * CHUNK;
                encode_span(first..(first + CHUNK).min(n))
            })
            .collect::<Result<_, _>>()?;
        per_chunk.into_iter().flatten().collect()
    } else {
        encode_span(0..n)?
    };

    let mut coords = Array2::<f64>::zeros((n, atom.latent_dim));
    let mut certified = Vec::with_capacity(n);
    for (row, (coord, ok)) in rows.into_iter().enumerate() {
        coords.row_mut(row).assign(&coord);
        certified.push(ok);
    }
    Ok(EncodeResult::from_rows(coords, certified))
}
/// Batched closed-form encode of `x` against one atom's charts.
///
/// Only charts with `certified_radius > 0` take part. Each row is routed to
/// the chart whose decoded center `r_c` minimizes `|r_c|^2 - 2 x.r_c` (the
/// squared distance to `x` with the per-row constant `|x|^2` dropped), and is
/// then mapped to `center - A1*m1 + (A1*x) / amplitude`, where `A1` is the
/// chart's `amortized_jacobian` and `m1` its `recon_center`.
///
/// Returns `(coords, valid)`. A row stays at zero coordinates with
/// `valid[row] == false` when the atom has no basis evaluator, no chart
/// qualifies, the routed chart has no amortized Jacobian, or the row's
/// amplitude is zero or non-finite.
///
/// # Errors
/// Fails on a column/length mismatch, when `atom_index` is not in the atlas,
/// or when evaluating the basis at the chart centers fails.
pub fn amortized_encode_batch_fast(
    &self,
    atom: &SaeManifoldAtom,
    atom_index: usize,
    x: ArrayView2<'_, f64>,
    amplitudes: ArrayView1<'_, f64>,
) -> Result<(Array2<f64>, Vec<bool>), String> {
    let n = x.nrows();
    let p = atom.output_dim();
    let d = atom.latent_dim;
    if x.ncols() != p {
        return Err(format!(
            "amortized_encode_batch_fast: x has {} cols but atom output dim is {p}",
            x.ncols()
        ));
    }
    if amplitudes.len() != n {
        return Err(format!(
            "amortized_encode_batch_fast: amplitudes len {} != rows {n}",
            amplitudes.len()
        ));
    }
    let atom_atlas = self.atoms.get(atom_index).ok_or_else(|| {
        format!("amortized_encode_batch_fast: atom {atom_index} not in atlas")
    })?;

    let mut coords = Array2::<f64>::zeros((n, d));
    let mut valid = vec![false; n];
    let evaluator = match atom.basis_evaluator.as_ref().cloned() {
        Some(evaluator) => evaluator,
        None => return Ok((coords, valid)),
    };

    // Indices (into the atlas) of the charts that are usable for routing.
    let valid_charts: Vec<usize> = atom_atlas
        .charts
        .iter()
        .enumerate()
        .filter(|(_, chart)| chart.certified_radius > 0.0)
        .map(|(c, _)| c)
        .collect();
    if valid_charts.is_empty() {
        return Ok((coords, valid));
    }
    let chart_count = valid_charts.len();

    // Decode every usable chart center in one basis evaluation.
    let mut centers = Array2::<f64>::zeros((chart_count, d));
    for (slot, &c) in valid_charts.iter().enumerate() {
        centers
            .row_mut(slot)
            .assign(&atom_atlas.charts[c].region.center);
    }
    let (phi_centers, _jet) = evaluator
        .evaluate(centers.view())
        .map_err(|err| format!("amortized_encode_batch_fast: center eval: {err}"))?;
    let recon_centers = phi_centers.dot(&atom.decoder_coefficients);

    // Per-row slot (position in `valid_charts`) of the nearest decoded center.
    let route_idx: Vec<usize> = if chart_count == 1 {
        vec![0usize; n]
    } else {
        let cross = x.dot(&recon_centers.t());
        let center_norm_sq: Vec<f64> = (0..chart_count)
            .map(|c| recon_centers.row(c).dot(&recon_centers.row(c)))
            .collect();
        (0..n)
            .map(|row| {
                // Strict `<` keeps the first minimum; NaN scores never win.
                let mut winner = (0usize, f64::INFINITY);
                for (c, &norm_sq) in center_norm_sq.iter().enumerate() {
                    let score = norm_sq - 2.0 * cross[[row, c]];
                    if score < winner.1 {
                        winner = (c, score);
                    }
                }
                winner.0
            })
            .collect()
    };

    // Rows whose amplitude can be divided by.
    let usable_amplitude: Vec<bool> = amplitudes
        .iter()
        .map(|z| z.is_finite() && z.abs() > 0.0)
        .collect();

    for (slot, &c) in valid_charts.iter().enumerate() {
        let chart = &atom_atlas.charts[c];
        let a1 = match chart.amortized_jacobian.as_ref() {
            Some(a1) => a1,
            None => continue,
        };
        let members: Vec<usize> = (0..n)
            .filter(|&row| route_idx[row] == slot && usable_amplitude[row])
            .collect();
        if members.is_empty() {
            continue;
        }

        // Gather this chart's rows so the projection is a single matmul.
        let mut x_members = Array2::<f64>::zeros((members.len(), p));
        for (i, &row) in members.iter().enumerate() {
            x_members.row_mut(i).assign(&x.row(row));
        }
        let projected = x_members.dot(&a1.t());
        let a1_m1 = a1.dot(&chart.recon_center);
        let offset = &chart.region.center - &a1_m1;

        for (i, &row) in members.iter().enumerate() {
            let inv_z = 1.0 / amplitudes[row];
            for axis in 0..d {
                coords[[row, axis]] = offset[axis] + projected[[i, axis]] * inv_z;
            }
            valid[row] = true;
        }
    }
    Ok((coords, valid))
}
/// Encodes `x` with `amortized_encode_batch_fast` and decodes the result.
///
/// For every row flagged valid by the encode step, the output row is
/// `amplitudes[row] * (phi(coords[row]) * decoder_coefficients)`. Rows that
/// are not valid stay zero. The validity mask from the encode step is
/// returned unchanged.
///
/// # Errors
/// Propagates encode errors and fails when the basis evaluation at the
/// encoded coordinates fails.
pub fn amortized_reconstruct_batch_fast(
    &self,
    atom: &SaeManifoldAtom,
    atom_index: usize,
    x: ArrayView2<'_, f64>,
    amplitudes: ArrayView1<'_, f64>,
) -> Result<(Array2<f64>, Vec<bool>), String> {
    let n = x.nrows();
    let p = atom.output_dim();
    let (coords, valid) = self.amortized_encode_batch_fast(atom, atom_index, x, amplitudes)?;
    let mut recon = Array2::<f64>::zeros((n, p));
    let evaluator = match atom.basis_evaluator.as_ref().cloned() {
        Some(evaluator) => evaluator,
        None => return Ok((recon, valid)),
    };
    // The basis is evaluated for all rows at once; invalid rows are masked below.
    let (phi, _jet) = evaluator
        .evaluate(coords.view())
        .map_err(|err| format!("amortized_reconstruct_batch_fast: basis eval: {err}"))?;
    let decoded = phi.dot(&atom.decoder_coefficients);
    for row in (0..n).filter(|&row| valid[row]) {
        let z = amplitudes[row];
        for col in 0..p {
            recon[[row, col]] = z * decoded[[row, col]];
        }
    }
    Ok((recon, valid))
}
/// Routes each target row to its first proposed atom and encodes it with
/// `certified_encode_row`.
///
/// For each row the candidate `index` is queried with a budget of
/// `auto_candidate_budget(atoms.len().max(1))`, and the first proposed atom is
/// used. A row is left at zero coordinates and marked uncertified when nothing
/// is proposed, or when `sketch.alignment` for that atom is non-finite or
/// below `CANDIDATE_ROUTING_MIN_ALIGNMENT`.
///
/// Batches with at least `ENCODE_BATCH_PARALLEL_ROW_MIN` rows are processed in
/// 256-row chunks through rayon unless the caller is already on a rayon
/// worker.
///
/// # Errors
/// Fails when `amplitudes.len() != targets.nrows()`, when a proposed atom id
/// is outside `atoms`, when a row encode fails, or when an encoded coordinate
/// does not have length `latent_dim`.
pub fn certified_encode_with_index<S: AtomFrameSketch + Sync>(
    &self,
    atoms: &[SaeManifoldAtom],
    index: &SaeCandidateIndex,
    sketch: &S,
    targets: ArrayView2<'_, f64>,
    amplitudes: ArrayView1<'_, f64>,
    latent_dim: usize,
) -> Result<EncodeResult, String> {
    let n = targets.nrows();
    if amplitudes.len() != n {
        return Err(format!(
            "certified_encode_with_index: amplitudes len {} != rows {n}",
            amplitudes.len()
        ));
    }
    let budget = auto_candidate_budget(atoms.len().max(1));

    // Routes and encodes one row; `Ok(None)` means "no acceptable atom".
    let encode_one = |row: usize| -> Result<Option<(Array1<f64>, bool)>, String> {
        let proposal = index.propose(sketch, targets.row(row), budget, true);
        let best_atom = match proposal.proposed.first() {
            Some(&atom_id) => atom_id,
            None => return Ok(None),
        };
        let routing_alignment = sketch.alignment(best_atom, targets.row(row));
        let aligned = routing_alignment.is_finite()
            && routing_alignment >= CANDIDATE_ROUTING_MIN_ALIGNMENT;
        if !aligned {
            return Ok(None);
        }
        let atom = atoms.get(best_atom).ok_or_else(|| {
            format!("certified_encode_with_index: proposed atom {best_atom} out of range")
        })?;
        let (t, cert) =
            self.certified_encode_row(atom, best_atom, targets.row(row), amplitudes[row])?;
        if t.len() != latent_dim {
            return Err(format!(
                "certified_encode_with_index: atom {best_atom} returned t.len()={} \
                 but declared latent_dim={latent_dim}; heterogeneous-dim \
                 dictionaries are not supported by this batched encode path",
                t.len()
            ));
        }
        Ok(Some((t, cert.certified())))
    };
    let encode_span =
        |span: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
            span.map(|row| encode_one(row)).collect()
        };

    let go_parallel =
        n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
    let slots: Vec<Option<(Array1<f64>, bool)>> = if go_parallel {
        use rayon::prelude::*;
        const CHUNK: usize = 256;
        let per_chunk: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n.div_ceil(CHUNK))
            .into_par_iter()
            .map(|chunk| {
                let first = chunk * CHUNK;
                encode_span(first..(first + CHUNK).min(n))
            })
            .collect::<Result<_, _>>()?;
        per_chunk.into_iter().flatten().collect()
    } else {
        encode_span(0..n)?
    };

    // Unrouted rows keep their zero coordinates and are reported as uncertified.
    let mut coords = Array2::<f64>::zeros((n, latent_dim));
    let certified: Vec<bool> = slots
        .into_iter()
        .enumerate()
        .map(|(row, slot)| match slot {
            Some((t, cert)) => {
                coords.row_mut(row).assign(&t);
                cert
            }
            None => false,
        })
        .collect();
    Ok(EncodeResult::from_rows(coords, certified))
}
/// Routes each target row to its first proposed atom and encodes it with
/// `amortized_encode_row`.
///
/// For each row the candidate `index` is queried with a budget of
/// `auto_candidate_budget(atoms.len().max(1))`, and the first proposed atom is
/// used. A row is left at zero coordinates and marked uncertified when nothing
/// is proposed, or when `sketch.alignment` for that atom is non-finite or
/// below `CANDIDATE_ROUTING_MIN_ALIGNMENT`.
///
/// Batches with at least `ENCODE_BATCH_PARALLEL_ROW_MIN` rows are processed in
/// 256-row chunks through rayon unless the caller is already on a rayon
/// worker.
///
/// # Errors
/// Fails when `amplitudes.len() != targets.nrows()`, when a proposed atom id
/// is outside `atoms`, when a row encode fails, or when an encoded coordinate
/// does not have length `latent_dim`.
pub fn amortized_encode_with_index<S: AtomFrameSketch + Sync>(
    &self,
    atoms: &[SaeManifoldAtom],
    index: &SaeCandidateIndex,
    sketch: &S,
    targets: ArrayView2<'_, f64>,
    amplitudes: ArrayView1<'_, f64>,
    latent_dim: usize,
) -> Result<EncodeResult, String> {
    let n = targets.nrows();
    if amplitudes.len() != n {
        return Err(format!(
            "amortized_encode_with_index: amplitudes len {} != rows {n}",
            amplitudes.len()
        ));
    }
    let budget = auto_candidate_budget(atoms.len().max(1));

    // Routes and encodes one row; `Ok(None)` means "no acceptable atom".
    let encode_one = |row: usize| -> Result<Option<(Array1<f64>, bool)>, String> {
        let proposal = index.propose(sketch, targets.row(row), budget, true);
        let best_atom = match proposal.proposed.first() {
            Some(&atom_id) => atom_id,
            None => return Ok(None),
        };
        let routing_alignment = sketch.alignment(best_atom, targets.row(row));
        let aligned = routing_alignment.is_finite()
            && routing_alignment >= CANDIDATE_ROUTING_MIN_ALIGNMENT;
        if !aligned {
            return Ok(None);
        }
        let atom = atoms.get(best_atom).ok_or_else(|| {
            format!("amortized_encode_with_index: proposed atom {best_atom} out of range")
        })?;
        let (t, cert) =
            self.amortized_encode_row(atom, best_atom, targets.row(row), amplitudes[row])?;
        if t.len() != latent_dim {
            return Err(format!(
                "amortized_encode_with_index: atom {best_atom} returned t.len()={} \
                 but declared latent_dim={latent_dim}; heterogeneous-dim \
                 dictionaries are not supported by this batched encode path",
                t.len()
            ));
        }
        Ok(Some((t, cert.certified())))
    };
    let encode_span =
        |span: std::ops::Range<usize>| -> Result<Vec<Option<(Array1<f64>, bool)>>, String> {
            span.map(|row| encode_one(row)).collect()
        };

    let go_parallel =
        n >= ENCODE_BATCH_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
    let slots: Vec<Option<(Array1<f64>, bool)>> = if go_parallel {
        use rayon::prelude::*;
        const CHUNK: usize = 256;
        let per_chunk: Vec<Vec<Option<(Array1<f64>, bool)>>> = (0..n.div_ceil(CHUNK))
            .into_par_iter()
            .map(|chunk| {
                let first = chunk * CHUNK;
                encode_span(first..(first + CHUNK).min(n))
            })
            .collect::<Result<_, _>>()?;
        per_chunk.into_iter().flatten().collect()
    } else {
        encode_span(0..n)?
    };

    // Unrouted rows keep their zero coordinates and are reported as uncertified.
    let mut coords = Array2::<f64>::zeros((n, latent_dim));
    let certified: Vec<bool> = slots
        .into_iter()
        .enumerate()
        .map(|(row, slot)| match slot {
            Some((t, cert)) => {
                coords.row_mut(row).assign(&t);
                cert
            }
            None => false,
        })
        .collect();
    Ok(EncodeResult::from_rows(coords, certified))
}
/// Index-routed variant of `amortized_encode_batch_fast`.
///
/// Every row is routed to the first atom proposed by `index` (budget
/// `auto_candidate_budget(atoms.len().max(1))`). Rows with no proposal, or
/// whose `sketch.alignment` is non-finite or below
/// `CANDIDATE_ROUTING_MIN_ALIGNMENT`, are skipped. The remaining rows are
/// grouped per atom, each group is encoded with one
/// `amortized_encode_batch_fast` call, and the results are scattered back.
///
/// Returns `(coords, valid)` with `coords` shaped `(n, latent_dim)`. Rows that
/// were skipped or that the per-atom encode did not mark valid stay zero with
/// `valid[row] == false`.
///
/// # Errors
/// Fails when `amplitudes.len() != targets.nrows()`, when a proposed atom id
/// is outside `atoms`, when an atom's `latent_dim` differs from `latent_dim`,
/// when an atom's `output_dim()` differs from `targets.ncols()`, or when the
/// per-atom encode fails.
pub fn amortized_encode_with_index_fast<S: AtomFrameSketch + Sync>(
    &self,
    atoms: &[SaeManifoldAtom],
    index: &SaeCandidateIndex,
    sketch: &S,
    targets: ArrayView2<'_, f64>,
    amplitudes: ArrayView1<'_, f64>,
    latent_dim: usize,
) -> Result<(Array2<f64>, Vec<bool>), String> {
    let n = targets.nrows();
    if amplitudes.len() != n {
        return Err(format!(
            "amortized_encode_with_index_fast: amplitudes len {} != rows {n}",
            amplitudes.len()
        ));
    }
    let budget = auto_candidate_budget(atoms.len().max(1));

    // Bucket the accepted rows by the atom they were routed to.
    let mut groups: std::collections::HashMap<usize, Vec<usize>> =
        std::collections::HashMap::new();
    for row in 0..n {
        let proposal = index.propose(sketch, targets.row(row), budget, true);
        let Some(&best_atom) = proposal.proposed.first() else {
            continue;
        };
        let routing_alignment = sketch.alignment(best_atom, targets.row(row));
        if !routing_alignment.is_finite()
            || routing_alignment < CANDIDATE_ROUTING_MIN_ALIGNMENT
        {
            continue;
        }
        groups.entry(best_atom).or_default().push(row);
    }

    let mut coords = Array2::<f64>::zeros((n, latent_dim));
    let mut valid = vec![false; n];
    for (atom_idx, rows_here) in groups {
        let atom = atoms.get(atom_idx).ok_or_else(|| {
            format!("amortized_encode_with_index_fast: proposed atom {atom_idx} out of range")
        })?;
        if atom.latent_dim != latent_dim {
            return Err(format!(
                "amortized_encode_with_index_fast: atom {atom_idx} latent_dim {} != declared \
                 {latent_dim}; heterogeneous-dim dictionaries are not supported by this path",
                atom.latent_dim
            ));
        }
        let p = atom.output_dim();
        // Without this guard the row copy below would panic (or silently
        // broadcast a one-column target) instead of reporting the mismatch.
        if targets.ncols() != p {
            return Err(format!(
                "amortized_encode_with_index_fast: atom {atom_idx} output_dim {p} != target dim {}",
                targets.ncols()
            ));
        }

        // Gather this atom's rows into dense sub-batches.
        let mut x_sub = Array2::<f64>::zeros((rows_here.len(), p));
        let mut amp_sub = Array1::<f64>::zeros(rows_here.len());
        for (i, &row) in rows_here.iter().enumerate() {
            x_sub.row_mut(i).assign(&targets.row(row));
            amp_sub[i] = amplitudes[row];
        }
        let (sub_coords, sub_valid) =
            self.amortized_encode_batch_fast(atom, atom_idx, x_sub.view(), amp_sub.view())?;

        // Scatter the valid results back to their original row positions.
        for (i, &row) in rows_here.iter().enumerate() {
            if sub_valid[i] {
                coords.row_mut(row).assign(&sub_coords.row(i));
                valid[row] = true;
            }
        }
    }
    Ok((coords, valid))
}
/// Index-routed variant of `amortized_reconstruct_batch_fast`.
///
/// Every row is routed to the first atom proposed by `index` (budget
/// `auto_candidate_budget(atoms.len().max(1))`). Rows with no proposal, or
/// whose `sketch.alignment` is non-finite or below
/// `CANDIDATE_ROUTING_MIN_ALIGNMENT`, are skipped. The remaining rows are
/// grouped per atom, reconstructed with one `amortized_reconstruct_batch_fast`
/// call per group, and scattered back.
///
/// Returns `(recon, valid)` with `recon` shaped like `targets`. Rows that were
/// skipped or not marked valid by the per-atom call stay zero.
///
/// # Errors
/// Fails when `amplitudes.len() != targets.nrows()`, when a proposed atom id
/// is outside `atoms`, when an atom's `output_dim()` differs from
/// `targets.ncols()`, or when the per-atom reconstruction fails.
pub fn amortized_reconstruct_with_index_fast<S: AtomFrameSketch + Sync>(
    &self,
    atoms: &[SaeManifoldAtom],
    index: &SaeCandidateIndex,
    sketch: &S,
    targets: ArrayView2<'_, f64>,
    amplitudes: ArrayView1<'_, f64>,
) -> Result<(Array2<f64>, Vec<bool>), String> {
    let n = targets.nrows();
    let p = targets.ncols();
    if amplitudes.len() != n {
        return Err(format!(
            "amortized_reconstruct_with_index_fast: amplitudes len {} != rows {n}",
            amplitudes.len()
        ));
    }
    let budget = auto_candidate_budget(atoms.len().max(1));

    // Bucket the accepted rows by the atom they were routed to.
    let mut groups: std::collections::HashMap<usize, Vec<usize>> =
        std::collections::HashMap::new();
    for row in 0..n {
        let proposal = index.propose(sketch, targets.row(row), budget, true);
        let best_atom = match proposal.proposed.first() {
            Some(&atom_id) => atom_id,
            None => continue,
        };
        let routing_alignment = sketch.alignment(best_atom, targets.row(row));
        let aligned = routing_alignment.is_finite()
            && routing_alignment >= CANDIDATE_ROUTING_MIN_ALIGNMENT;
        if aligned {
            groups.entry(best_atom).or_default().push(row);
        }
    }

    let mut recon = Array2::<f64>::zeros((n, p));
    let mut valid = vec![false; n];
    for (atom_idx, members) in groups {
        let atom = atoms.get(atom_idx).ok_or_else(|| {
            format!(
                "amortized_reconstruct_with_index_fast: proposed atom {atom_idx} out of range"
            )
        })?;
        if atom.output_dim() != p {
            return Err(format!(
                "amortized_reconstruct_with_index_fast: atom {atom_idx} output_dim {} != target \
                 dim {p}",
                atom.output_dim()
            ));
        }

        // Gather this atom's rows into dense sub-batches.
        let mut x_sub = Array2::<f64>::zeros((members.len(), p));
        let mut amp_sub = Array1::<f64>::zeros(members.len());
        for (i, &row) in members.iter().enumerate() {
            x_sub.row_mut(i).assign(&targets.row(row));
            amp_sub[i] = amplitudes[row];
        }
        let (sub_recon, sub_valid) = self.amortized_reconstruct_batch_fast(
            atom,
            atom_idx,
            x_sub.view(),
            amp_sub.view(),
        )?;

        // Scatter the valid reconstructions back to their original rows.
        for (i, &row) in members.iter().enumerate() {
            if !sub_valid[i] {
                continue;
            }
            recon.row_mut(row).assign(&sub_recon.row(i));
            valid[row] = true;
        }
    }
    Ok((recon, valid))
}
}
/// Returns `1 / lambda_min(J J^T + ridge * I)` at `center`.
///
/// `J` is the `(latent_dim, output_dim)` matrix with
/// `J[axis, out] = sum_b jet[0, b, axis] * decoder[b, out]`, built from the
/// basis jet evaluated at `center`.
///
/// Returns `None` when the atom has no basis evaluator, `center` cannot be
/// viewed as a `(1, latent_dim)` row, the basis evaluation or the symmetric
/// eigendecomposition fails, or the smallest eigenvalue is not finite and
/// strictly positive.
pub(crate) fn center_beta(atom: &SaeManifoldAtom, center: &Array1<f64>, ridge: f64) -> Option<f64> {
    let evaluator = atom.basis_evaluator.as_ref()?.clone();
    let (d, p, m) = (atom.latent_dim, atom.output_dim(), atom.basis_size());
    let center_row = center.view().to_shape((1, d)).ok()?.to_owned();
    let (_phi, jet) = evaluator.evaluate(center_row.view()).ok()?;
    let decoder = &atom.decoder_coefficients;

    // Decoded Jacobian; exact-zero jet entries are skipped.
    let mut jac = Array2::<f64>::zeros((d, p));
    for axis in 0..d {
        for basis_col in 0..m {
            let dphi = jet[[0, basis_col, axis]];
            if dphi != 0.0 {
                for out in 0..p {
                    jac[[axis, out]] += dphi * decoder[[basis_col, out]];
                }
            }
        }
    }

    // Ridge-regularized Gram matrix of the Jacobian rows.
    let mut gram = Array2::<f64>::zeros((d, d));
    for a in 0..d {
        let row_a = jac.row(a);
        for b in 0..d {
            gram[[a, b]] = row_a.dot(&jac.row(b));
        }
        gram[[a, a]] += ridge;
    }

    let (vals, _vecs) = gram.eigh(Side::Lower).ok()?;
    let lambda_min = vals.iter().copied().fold(f64::INFINITY, f64::min);
    (lambda_min.is_finite() && lambda_min > 0.0).then(|| 1.0 / lambda_min)
}
/// First-order coordinate estimate for `x` on one chart.
///
/// Starting from the chart center, adds
/// `A1[:, out] * (x[out] - amplitude * m1[out]) / amplitude` for each output
/// index shared by `chart.recon_center` (`m1`) and the columns of
/// `chart.amortized_jacobian` (`A1`).
///
/// Returns `None` when the chart has no amortized Jacobian or when
/// `amplitude` is zero or non-finite.
pub(crate) fn amortized_warm_start(
    chart: &CertifiedChart,
    x: ArrayView1<'_, f64>,
    amplitude: f64,
) -> Option<Array1<f64>> {
    let a1 = chart.amortized_jacobian.as_ref()?;
    if !amplitude.is_finite() || amplitude.abs() <= 0.0 {
        return None;
    }
    let latent_dim = a1.nrows();
    let shared_outputs = chart.recon_center.len().min(a1.ncols());
    let mut estimate = chart.region.center.clone();
    for out_idx in 0..shared_outputs {
        let resid = x[out_idx] - amplitude * chart.recon_center[out_idx];
        for axis in 0..latent_dim {
            estimate[axis] += a1[[axis, out_idx]] * resid / amplitude;
        }
    }
    Some(estimate)
}
/// Builds the amortized inverse map `A1 = (J J^T + ridge * I)^-1 J` and the
/// decoded value at `center`.
///
/// `J` is the `(latent_dim, output_dim)` decoded Jacobian
/// (`J[axis, out] = sum_b jet[0, b, axis] * decoder[b, out]`). The inverse is
/// applied through the symmetric eigendecomposition of the regularized Gram
/// matrix: each column of `J` is expanded in the eigenvectors and scaled by
/// `1 / lambda_i`.
///
/// Returns `(A1, recon)` where `recon[out] = sum_b phi[0, b] * decoder[b, out]`.
/// Returns `None` when the atom has no basis evaluator, `center` cannot be
/// viewed as a `(1, latent_dim)` row, the basis evaluation or the
/// eigendecomposition fails, the smallest eigenvalue is not finite and
/// strictly positive, or a non-finite / non-positive eigenvalue is met while
/// applying the inverse.
pub(crate) fn center_amortized_jacobian(
    atom: &SaeManifoldAtom,
    center: &Array1<f64>,
    ridge: f64,
) -> Option<(Array2<f64>, Array1<f64>)> {
    let evaluator = atom.basis_evaluator.as_ref()?.clone();
    let (d, p, m) = (atom.latent_dim, atom.output_dim(), atom.basis_size());
    let center_row = center.view().to_shape((1, d)).ok()?.to_owned();
    let (phi, jet) = evaluator.evaluate(center_row.view()).ok()?;
    let decoder = &atom.decoder_coefficients;

    // Decoded value and decoded Jacobian, accumulated basis column by basis
    // column; exact-zero basis values / derivatives are skipped.
    let mut recon = Array1::<f64>::zeros(p);
    let mut jac = Array2::<f64>::zeros((d, p));
    for basis_col in 0..m {
        let phi_v = phi[[0, basis_col]];
        if phi_v != 0.0 {
            for out in 0..p {
                recon[out] += phi_v * decoder[[basis_col, out]];
            }
        }
    }
    for axis in 0..d {
        for basis_col in 0..m {
            let dphi = jet[[0, basis_col, axis]];
            if dphi != 0.0 {
                for out in 0..p {
                    jac[[axis, out]] += dphi * decoder[[basis_col, out]];
                }
            }
        }
    }

    // Ridge-regularized Gram matrix of the Jacobian rows.
    let mut gram = Array2::<f64>::zeros((d, d));
    for a in 0..d {
        let row_a = jac.row(a);
        for b in 0..d {
            gram[[a, b]] = row_a.dot(&jac.row(b));
        }
        gram[[a, a]] += ridge;
    }

    let (vals, vecs) = gram.eigh(Side::Lower).ok()?;
    let lambda_min = vals.iter().copied().fold(f64::INFINITY, f64::min);
    let well_conditioned = lambda_min.is_finite() && lambda_min > 0.0;
    if !well_conditioned {
        return None;
    }

    // Apply the inverse one output column at a time.
    let mut a1 = Array2::<f64>::zeros((d, p));
    for out in 0..p {
        let jac_col = jac.column(out);
        for (i, &lam) in vals.iter().enumerate() {
            // `f64::min` skips NaN, so individual eigenvalues are re-checked here.
            if !lam.is_finite() || lam <= 0.0 {
                return None;
            }
            let eigvec = vecs.column(i);
            let coeff = eigvec.dot(&jac_col) / lam;
            for row in 0..d {
                a1[[row, out]] += coeff * eigvec[row];
            }
        }
    }
    Some((a1, recon))
}
/// Finds the chart whose decoded center is closest to `x`.
///
/// Charts with `certified_radius <= 0` are ignored, as are charts whose center
/// cannot be reshaped to `(1, latent_dim)` or whose basis evaluation fails.
/// The score is the squared Euclidean distance between the decoded center
/// (`phi(center) * decoder_coefficients`) and `x`; on ties the lowest chart
/// index wins.
///
/// A NaN distance is treated as worse than any non-NaN distance, so a chart
/// with a NaN score is only returned when no candidate has a comparable score.
///
/// Returns `Some((chart_index, squared_distance))`, or `None` when no chart
/// was eligible.
pub(crate) fn nearest_chart(
    atom_atlas: &AtomEncodeAtlas,
    x: ArrayView1<'_, f64>,
    atom: &SaeManifoldAtom,
    evaluator: &dyn SaeBasisEvaluator,
) -> Option<(usize, f64)> {
    if atom_atlas.charts.is_empty() {
        return None;
    }
    let d = atom.latent_dim;
    let p = atom.output_dim();
    let m = atom.basis_size();
    let mut best: Option<(usize, f64)> = None;
    for (idx, chart) in atom_atlas.charts.iter().enumerate() {
        if chart.certified_radius <= 0.0 {
            continue;
        }
        let coords = match chart.region.center.view().to_shape((1, d)) {
            Ok(c) => c.to_owned(),
            Err(_) => continue,
        };
        let Ok((phi, _jet)) = evaluator.evaluate(coords.view()) else {
            continue;
        };
        // Decode the chart center; exact-zero basis values are skipped.
        let mut recon = Array1::<f64>::zeros(p);
        for basis_col in 0..m {
            let phi_v = phi[[0, basis_col]];
            if phi_v == 0.0 {
                continue;
            }
            for out in 0..p {
                recon[out] += phi_v * atom.decoder_coefficients[[basis_col, out]];
            }
        }
        let diff = &recon - &x;
        let dist = diff.dot(&diff);
        // `dist < incumbent` is always false once the incumbent is NaN, so a
        // NaN incumbent must be displaced explicitly by any comparable score.
        let replace = match best {
            None => true,
            Some((_, incumbent)) => {
                dist < incumbent || (incumbent.is_nan() && !dist.is_nan())
            }
        };
        if replace {
            best = Some((idx, dist));
        }
    }
    best
}
/// Returns the indices of up to `k` charts whose decoded centers are closest
/// to `x`, nearest first.
///
/// Charts with `certified_radius <= 0` are ignored, as are charts whose center
/// cannot be reshaped to `(1, latent_dim)` or whose basis evaluation fails.
/// Charts are ranked by the squared Euclidean distance between the decoded
/// center (`phi(center) * decoder_coefficients`) and `x`, with ties broken by
/// ascending chart index. Charts whose distance is NaN are ranked after every
/// chart with a comparable distance.
///
/// Returns an empty vector when the atlas has no charts or `k == 0`.
pub(crate) fn nearest_charts_topk(
    atom_atlas: &AtomEncodeAtlas,
    x: ArrayView1<'_, f64>,
    atom: &SaeManifoldAtom,
    evaluator: &dyn SaeBasisEvaluator,
    k: usize,
) -> Vec<usize> {
    if atom_atlas.charts.is_empty() || k == 0 {
        return Vec::new();
    }
    let d = atom.latent_dim;
    let p = atom.output_dim();
    let m = atom.basis_size();
    let mut scored: Vec<(usize, f64)> = Vec::new();
    for (idx, chart) in atom_atlas.charts.iter().enumerate() {
        if chart.certified_radius <= 0.0 {
            continue;
        }
        let coords = match chart.region.center.view().to_shape((1, d)) {
            Ok(c) => c.to_owned(),
            Err(_) => continue,
        };
        let Ok((phi, _jet)) = evaluator.evaluate(coords.view()) else {
            continue;
        };
        // Decode the chart center; exact-zero basis values are skipped.
        let mut recon = Array1::<f64>::zeros(p);
        for basis_col in 0..m {
            let phi_v = phi[[0, basis_col]];
            if phi_v == 0.0 {
                continue;
            }
            for out in 0..p {
                recon[out] += phi_v * atom.decoder_coefficients[[basis_col, out]];
            }
        }
        let diff = &recon - &x;
        scored.push((idx, diff.dot(&diff)));
    }
    // Treating NaN as "equal to everything" is not a total order, which
    // `sort_by` requires. Ordering NaN scores strictly last restores a total
    // order and leaves the ranking of NaN-free inputs unchanged.
    scored.sort_by(|a, b| {
        a.1.is_nan()
            .cmp(&b.1.is_nan())
            .then_with(|| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .then(a.0.cmp(&b.0))
    });
    scored.into_iter().take(k).map(|(idx, _)| idx).collect()
}
/// Euclidean norm of `x - amplitude * decode(coord)` for one atom.
///
/// `decode(coord)[out]` is `sum_b phi[0, b] * decoder_coefficients[b, out]`
/// with the basis evaluated at `coord`.
///
/// Returns `f64::INFINITY` when `coord` cannot be reshaped to
/// `(1, latent_dim)`, when the basis evaluation fails, or when the squared
/// error is not finite.
pub(crate) fn encode_reconstruction_error(
    atom: &SaeManifoldAtom,
    evaluator: &dyn SaeBasisEvaluator,
    coord: ArrayView1<'_, f64>,
    x: ArrayView1<'_, f64>,
    amplitude: f64,
) -> f64 {
    let (d, p, m) = (atom.latent_dim, atom.output_dim(), atom.basis_size());
    let Ok(shaped) = coord.to_shape((1, d)) else {
        return f64::INFINITY;
    };
    let coords = shaped.to_owned();
    let phi = match evaluator.evaluate(coords.view()) {
        Ok((phi, _jet)) => phi,
        Err(_) => return f64::INFINITY,
    };
    // Sum of squared residuals over the output dimensions.
    let err2 = (0..p).fold(0.0, |total, out| {
        let recon = (0..m).fold(0.0, |sum, basis_col| {
            sum + phi[[0, basis_col]] * atom.decoder_coefficients[[basis_col, out]]
        });
        let r = x[out] - amplitude * recon;
        total + r * r
    });
    if err2.is_finite() {
        err2.sqrt()
    } else {
        f64::INFINITY
    }
}
/// Point budget for the chart-center grids built in this file:
/// `regular_product_grid` and `cylinder_chart_center_grid` lower their
/// per-axis resolution (never below 2) while the total point count exceeds it.
pub(crate) const SHAPE_BAND_MAX_POINTS: usize = 512;
/// Squared distance between two latent coordinates of `atom`.
///
/// Non-periodic axes contribute `(a - b)^2`. On periodic axes the absolute
/// difference is first reduced modulo 1 and folded to the shorter way around,
/// i.e. the axis is treated as having period 1. All axes are periodic for
/// `Periodic`, `Torus` and `Sphere`; only axis 0 is periodic for `Cylinder`.
///
/// NOTE(review): `Sphere` is wrapped with period 1 on every axis, while
/// `sphere_latlon_grid` in this file produces coordinates in radians. Confirm
/// which coordinate convention callers pass here.
pub(crate) fn coord_dist_sq(atom: &SaeManifoldAtom, a: ArrayView1<'_, f64>, b: ArrayView1<'_, f64>) -> f64 {
    use crate::manifold::SaeAtomBasisKind::*;
    let wraps = |axis: usize| -> bool {
        match &atom.basis_kind {
            Periodic | Torus | Sphere => true,
            Cylinder => axis == 0,
            Linear | Duchon | EuclideanPatch | Poincare | Precomputed(_) => false,
        }
    };
    (0..a.len()).fold(0.0, |total, axis| {
        let gap = (a[axis] - b[axis]).abs();
        let gap = if wraps(axis) {
            // Keep the fractional part, then take the shorter arc.
            let frac = gap - gap.floor();
            frac.min(1.0 - frac)
        } else {
            gap
        };
        total + gap * gap
    })
}
/// Picks chart centers from observed coordinates by farthest-point sampling.
///
/// Row 0 is always the first center. Each further center is the row with the
/// largest `coord_dist_sq` to its nearest already-chosen center (first such
/// row on ties). Selection stops at `max_charts.min(n)` centers or as soon as
/// no row has a positive distance left. Because row 0 is pushed
/// unconditionally, a non-empty input yields one center even when
/// `max_charts == 0`.
///
/// Each center's radius is half the distance to its nearest other center,
/// limited to `[1e-3, 0.5]`; a lone center gets `0.5`.
///
/// Returns `(centers, radii)`; both are empty when `coords` has no rows.
///
/// # Errors
/// Fails when `coords.ncols() != atom.latent_dim`.
pub(crate) fn data_driven_chart_centers(
    atom: &SaeManifoldAtom,
    coords: ArrayView2<'_, f64>,
    max_charts: usize,
) -> Result<(Array2<f64>, Vec<f64>), String> {
    let n = coords.nrows();
    let d = coords.ncols();
    if d != atom.latent_dim {
        return Err(format!(
            "data_driven_chart_centers: coords have {d} cols but atom latent_dim is {}",
            atom.latent_dim
        ));
    }
    if n == 0 {
        return Ok((Array2::<f64>::zeros((0, d)), Vec::new()));
    }

    let budget = max_charts.min(n);
    let mut picked: Vec<usize> = Vec::with_capacity(budget);
    picked.push(0);
    // Squared distance from every row to its nearest picked center so far.
    let mut gap_sq: Vec<f64> = (0..n)
        .map(|r| coord_dist_sq(atom, coords.row(r), coords.row(0)))
        .collect();
    while picked.len() < budget {
        // Strict `>` keeps the first row on ties and never selects a NaN gap.
        let (far_row, far_gap) =
            gap_sq
                .iter()
                .enumerate()
                .fold((0usize, -1.0_f64), |best, (r, &gap)| {
                    if gap > best.1 {
                        (r, gap)
                    } else {
                        best
                    }
                });
        if far_gap <= 0.0 {
            break;
        }
        picked.push(far_row);
        for (r, slot) in gap_sq.iter_mut().enumerate() {
            let gap = coord_dist_sq(atom, coords.row(r), coords.row(far_row));
            if gap < *slot {
                *slot = gap;
            }
        }
    }

    let m = picked.len();
    let mut centers = Array2::<f64>::zeros((m, d));
    for (i, &row) in picked.iter().enumerate() {
        centers.row_mut(i).assign(&coords.row(row));
    }

    let radii: Vec<f64> = (0..m)
        .map(|i| {
            let nearest_sq = (0..m)
                .filter(|&j| j != i)
                .map(|j| coord_dist_sq(atom, centers.row(i), centers.row(j)))
                .fold(f64::INFINITY, |lo, dsq| if dsq < lo { dsq } else { lo });
            let half_gap = if nearest_sq.is_finite() {
                0.5 * nearest_sq.sqrt()
            } else {
                0.5
            };
            half_gap.max(1.0e-3).min(0.5)
        })
        .collect();
    Ok((centers, radii))
}
/// Builds the default grid of chart centers for `atom`'s basis kind.
///
/// * `Periodic` / `Torus`: product grid on `[0, 1)` per axis (endpoint excluded).
/// * `Cylinder`: `cylinder_chart_center_grid` when `latent_dim == 2`,
///   otherwise a product grid on `[-0.5, 0.5]`.
/// * `Sphere`: `sphere_latlon_grid` when `latent_dim == 2`, otherwise a
///   product grid on `[-0.5, 0.5]`.
/// * every other kind: product grid on `[-0.5, 0.5]` (endpoints included).
pub(crate) fn chart_center_grid(atom: &SaeManifoldAtom, resolution: usize) -> Array2<f64> {
    use crate::manifold::SaeAtomBasisKind::*;
    let dim = atom.latent_dim;
    let centered_box = || regular_product_grid(dim, resolution, -0.5, 0.5, true);
    match &atom.basis_kind {
        Periodic | Torus => regular_product_grid(dim, resolution, 0.0, 1.0, false),
        Cylinder => {
            if dim == 2 {
                cylinder_chart_center_grid(resolution)
            } else {
                centered_box()
            }
        }
        Sphere => {
            if dim == 2 {
                sphere_latlon_grid(resolution)
            } else {
                centered_box()
            }
        }
        Linear | Duchon | EuclideanPatch | Poincare | Precomputed(_) => centered_box(),
    }
}
/// Regular tensor-product grid over `[lo, hi]^d`.
///
/// The per-axis point count starts at `resolution.max(2)` and is decreased
/// (never below 2) while `per_axis^d` exceeds `SHAPE_BAND_MAX_POINTS`. With
/// `include_endpoint` the points span `lo..=hi`; otherwise the last point
/// stops one step short of `hi`. Rows are emitted with the last axis varying
/// fastest.
///
/// For `d == 0` a single row with zero columns is returned.
pub(crate) fn regular_product_grid(
    d: usize,
    resolution: usize,
    lo: f64,
    hi: f64,
    include_endpoint: bool,
) -> Array2<f64> {
    if d == 0 {
        return Array2::<f64>::zeros((1, 0));
    }
    let exponent = d as u32;
    let mut per_axis = resolution.max(2);
    while per_axis > 2 && per_axis.saturating_pow(exponent) > SHAPE_BAND_MAX_POINTS {
        per_axis -= 1;
    }
    let total = per_axis.saturating_pow(exponent).max(1);
    let steps = if include_endpoint {
        per_axis.max(2) - 1
    } else {
        per_axis
    };
    let denom = steps as f64;

    let mut grid = Array2::<f64>::zeros((total, d));
    // Mixed-radix counter over the axes, least significant digit last.
    let mut digits = vec![0usize; d];
    for flat in 0..total {
        for (axis, &digit) in digits.iter().enumerate() {
            let frac = digit as f64 / denom;
            grid[[flat, axis]] = lo + (hi - lo) * frac;
        }
        // Advance the counter, carrying into earlier axes on wrap-around.
        for digit in digits.iter_mut().rev() {
            *digit += 1;
            if *digit < per_axis {
                break;
            }
            *digit = 0;
        }
    }
    grid
}
/// Latitude/longitude grid of `r * r` points, with `r` being `resolution`
/// limited to `2..=22`.
///
/// Column 0 holds latitudes at the midpoints of `r` equal bands of
/// `[-pi/2, pi/2]`; column 1 holds longitudes `-pi + 2*pi*j/r` for
/// `j = 0..r`. Rows are ordered latitude-major.
pub(crate) fn sphere_latlon_grid(resolution: usize) -> Array2<f64> {
    use std::f64::consts::PI;
    let r = resolution.clamp(2, 22);
    let mut grid = Array2::<f64>::zeros((r * r, 2));
    for flat in 0..r * r {
        let (i, j) = (flat / r, flat % r);
        grid[[flat, 0]] = -PI / 2.0 + PI * (i as f64 + 0.5) / r as f64;
        grid[[flat, 1]] = -PI + 2.0 * PI * (j as f64) / r as f64;
    }
    grid
}
/// Chart-center grid for a 2-D cylinder: `per_axis * per_axis` points.
///
/// `per_axis` is `resolution.max(2)`, capped at the largest side length whose
/// square does not exceed `SHAPE_BAND_MAX_POINTS` (and never below 2). The cap
/// is computed directly rather than by decrementing from `resolution`, so the
/// cost does not grow with `resolution` and squaring cannot overflow.
///
/// Column 0 is the circular coordinate `i / per_axis` in `[0, 1)`; column 1 is
/// the line coordinate, evenly spaced over `[-0.5, 0.5]` with both endpoints
/// included. Rows are ordered circle-major.
pub(crate) fn cylinder_chart_center_grid(resolution: usize) -> Array2<f64> {
    // Largest side >= 2 with side^2 <= SHAPE_BAND_MAX_POINTS; the loop runs at
    // most sqrt(SHAPE_BAND_MAX_POINTS) times regardless of `resolution`.
    let mut side_cap = 2usize;
    while (side_cap + 1).saturating_mul(side_cap + 1) <= SHAPE_BAND_MAX_POINTS {
        side_cap += 1;
    }
    let per_axis = resolution.max(2).min(side_cap);
    let total = per_axis * per_axis;
    // `per_axis >= 2`, so the endpoint-inclusive step count is at least 1.
    let line_denom = (per_axis - 1) as f64;
    let mut grid = Array2::<f64>::zeros((total, 2));
    for i in 0..per_axis {
        let circle = i as f64 / per_axis as f64;
        for j in 0..per_axis {
            let line = -0.5 + (j as f64) / line_denom;
            grid[[i * per_axis + j, 0]] = circle;
            grid[[i * per_axis + j, 1]] = line;
        }
    }
    grid
}
pub(crate) fn chart_nominal_radius(atom: &SaeManifoldAtom, resolution: usize) -> f64 {
use crate::manifold::SaeAtomBasisKind::*;
match &atom.basis_kind {
Periodic | Torus => 0.5 / (resolution.max(2) as f64),
Sphere => std::f64::consts::PI / (resolution.max(2) as f64),
Cylinder => 0.5 / (resolution.max(2) as f64),
Linear | Duchon | EuclideanPatch | Poincare | Precomputed(_) => {
1.0 / (resolution.max(2) as f64)
}
}
}
/// Builds the [`ChartRegion`] for a chart of `atom` centered at `center`.
///
/// For `Duchon` atoms the region also carries radial bounds derived from the
/// Euclidean norm of the center: `r_min = max(|center| - radius,
/// f64::MIN_POSITIVE)` and `r_max = |center| + radius`. All other basis kinds
/// get a plain region without radial bounds.
pub(crate) fn chart_region(
    atom: &SaeManifoldAtom,
    center: Array1<f64>,
    radius: f64,
) -> ChartRegion {
    use crate::manifold::SaeAtomBasisKind::*;
    // Derive the bounds before `center` is moved into the region, so no copy
    // of the center is needed.
    let radial_bounds = match &atom.basis_kind {
        Duchon => {
            let center_norm = center.dot(&center).sqrt();
            // Keep the inner bound strictly positive.
            let r_min = (center_norm - radius).max(f64::MIN_POSITIVE);
            let r_max = center_norm + radius;
            Some((r_min, r_max))
        }
        Periodic | Sphere | Torus | Cylinder | Linear | EuclideanPatch | Poincare
        | Precomputed(_) => None,
    };
    let region = ChartRegion::new(center, radius);
    match radial_bounds {
        Some((r_min, r_max)) => region.with_radial_bounds(r_min, r_max),
        None => region,
    }
}