use ndarray::{Array2, Array3, Array4, Array5, ArrayView2};
use std::sync::Arc;
/// Dynamically dispatched basis evaluator for an SAE atom's latent coordinate chart.
///
/// [`evaluate`](Self::evaluate) maps an `(n, latent_dim)` coordinate matrix to basis
/// values `phi` and first derivatives `jet`; the evaluators in this module return `phi`
/// shaped `(n, n_basis)` and `jet` shaped `(n, n_basis, latent_dim)`. Second and third
/// jets are optional and exposed through [`second_jet_dyn`](Self::second_jet_dyn) and
/// [`third_jet_dyn`](Self::third_jet_dyn).
pub trait SaeBasisEvaluator: Send + Sync + std::fmt::Debug {
    /// Evaluates basis values and their first derivatives at every row of `coords`.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String>;

    /// Returns an evaluator for affinely re-parameterised coordinates, if supported.
    ///
    /// The default supports no transformation and returns `Ok(None)`. The width guard
    /// exists only so every argument is consumed; it rejects nothing but an explicit
    /// `usize::MAX`.
    fn affine_transformed_evaluator(
        &self,
        shift: &[f64],
        scale: &[f64],
        n_basis: usize,
    ) -> Result<Option<Arc<dyn SaeBasisSecondJet>>, String> {
        if shift.len() == usize::MAX || scale.len() == usize::MAX || n_basis == usize::MAX {
            return Err(
                "SaeBasisEvaluator::affine_transformed_evaluator: unreachable affine metadata width"
                    .to_string(),
            );
        }
        Ok(None)
    }

    /// Partitions the `n_basis` columns into `eta`-linear and `eta`-curved sets.
    ///
    /// The default treats every column as linear.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        Ok(PhiEtaSplit::all_linear(n_basis))
    }

    /// Returns per-factor basis widths when the basis is assembled from two factors
    /// (the 2-D torus reports its per-axis widths); `None` by default.
    fn factor_basis_sizes(&self) -> Option<(usize, usize)> {
        None
    }

    /// Evaluates the basis with every curved column scaled by `eta`.
    ///
    /// For each column in `phi_eta_split(..).curved_cols`, `phi` and `jet` are multiplied
    /// by `eta`, and `dphi_deta` / `djet_deta` hold the unscaled values. Linear columns are
    /// unchanged and have a zero `eta` derivative.
    ///
    /// # Errors
    /// Returns an error if `eta` is not a finite value in `[0, 1]`, if `evaluate` or
    /// `phi_eta_split` fail, or if the jet's leading dimensions disagree with `phi`. It
    /// also errors if a curved column is out of range or listed more than once.
    fn evaluate_phi_eta(
        &self,
        coords: ArrayView2<'_, f64>,
        eta: f64,
    ) -> Result<PhiEtaEvaluation, String> {
        if !(eta.is_finite() && (0.0..=1.0).contains(&eta)) {
            return Err(format!(
                "SaeBasisEvaluator::evaluate_phi_eta: eta must be finite in [0, 1]; got {eta}"
            ));
        }
        let (mut phi, mut jet) = self.evaluate(coords)?;
        let (n_rows, n_basis) = phi.dim();
        // The scaling loop indexes `jet` with `phi`'s rows and columns; without this check
        // a mismatched implementation would panic instead of reporting an error.
        if jet.shape()[0] != n_rows || jet.shape()[1] != n_basis {
            return Err(format!(
                "SaeBasisEvaluator::evaluate_phi_eta: jet shape {:?} disagrees with basis shape ({n_rows}, {n_basis})",
                jet.shape()
            ));
        }
        let n_axes = jet.shape()[2];
        let split = self.phi_eta_split(n_basis)?;
        let mut dphi_deta = Array2::<f64>::zeros(phi.dim());
        let mut djet_deta = Array3::<f64>::zeros(jet.dim());
        // A duplicated curved column would be scaled by `eta` twice and would record the
        // already-scaled values as its derivative, so duplicates are rejected.
        let mut seen = vec![false; n_basis];
        for &col in &split.curved_cols {
            if col >= n_basis {
                return Err(format!(
                    "SaeBasisEvaluator::evaluate_phi_eta: curved column {col} exceeds basis width {}",
                    n_basis
                ));
            }
            if std::mem::replace(&mut seen[col], true) {
                return Err(format!(
                    "SaeBasisEvaluator::evaluate_phi_eta: curved column {col} listed more than once"
                ));
            }
            for row in 0..n_rows {
                dphi_deta[[row, col]] = phi[[row, col]];
                if eta != 1.0 {
                    phi[[row, col]] *= eta;
                }
                for axis in 0..n_axes {
                    djet_deta[[row, col, axis]] = jet[[row, col, axis]];
                    if eta != 1.0 {
                        jet[[row, col, axis]] *= eta;
                    }
                }
            }
        }
        Ok(PhiEtaEvaluation {
            phi,
            jet,
            dphi_deta,
            djet_deta,
            split,
        })
    }

    /// Second derivatives of the basis, or `None` when this evaluator does not provide them.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>>;

    /// Third derivatives of the basis, or `None` when this evaluator does not provide them.
    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>>;
}
/// Partition of basis columns by how they depend on the curvature parameter `eta`.
///
/// `linear_cols` are left unchanged by the default `evaluate_phi_eta`, while
/// `curved_cols` are multiplied by `eta`. The constructors below produce both lists in
/// ascending column order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PhiEtaSplit {
    /// Columns whose values do not depend on `eta`.
    pub linear_cols: Vec<usize>,
    /// Columns that are scaled by `eta`.
    pub curved_cols: Vec<usize>,
}
impl PhiEtaSplit {
    /// Treats all `n_basis` columns as linear.
    pub fn all_linear(n_basis: usize) -> Self {
        Self {
            curved_cols: Vec::new(),
            linear_cols: (0..n_basis).collect(),
        }
    }

    /// Builds a split from a per-column mask in which `true` marks a curved column.
    fn from_curved_mask(mask: Vec<bool>) -> Self {
        let mut split = Self {
            linear_cols: Vec::new(),
            curved_cols: Vec::new(),
        };
        for (col, is_curved) in mask.into_iter().enumerate() {
            let bucket = if is_curved {
                &mut split.curved_cols
            } else {
                &mut split.linear_cols
            };
            bucket.push(col);
        }
        split
    }
}
/// Basis evaluation with curved columns scaled by `eta`, together with the `eta` derivatives.
///
/// Returned by `SaeBasisEvaluator::evaluate_phi_eta`.
#[derive(Debug, Clone)]
pub struct PhiEtaEvaluation {
    /// Basis values; columns listed in `split.curved_cols` are multiplied by `eta`.
    pub phi: Array2<f64>,
    /// First-derivative jet, with curved columns multiplied by `eta`.
    pub jet: Array3<f64>,
    /// Derivative of `phi` with respect to `eta`: the unscaled curved columns, zero elsewhere.
    pub dphi_deta: Array2<f64>,
    /// Derivative of `jet` with respect to `eta`, laid out like `dphi_deta`.
    pub djet_deta: Array3<f64>,
    /// Column partition used to build this evaluation.
    pub split: PhiEtaSplit,
}
/// Classifies each monomial of total degree at most `max_total_degree` in `dimension`
/// variables as linear (`true`, total degree 0 or 1) or non-linear (`false`).
///
/// The mask follows the column order of `crate::basis::monomial_exponents`.
fn monomial_linear_mask(dimension: usize, max_total_degree: usize) -> Vec<bool> {
    let exponents = crate::basis::monomial_exponents(dimension, max_total_degree);
    exponents
        .iter()
        .map(|alpha| {
            let total_degree: usize = alpha.iter().sum();
            total_degree < 2
        })
        .collect()
}
fn duchon_effective_order_for_eta(
centers: ArrayView2<'_, f64>,
order: crate::basis::DuchonNullspaceOrder,
) -> crate::basis::DuchonNullspaceOrder {
let mut effective = order;
while effective != crate::basis::DuchonNullspaceOrder::Zero
&& centers.nrows() <= duchon_polynomial_column_count(centers.ncols(), effective)
{
effective = match effective {
crate::basis::DuchonNullspaceOrder::Zero => crate::basis::DuchonNullspaceOrder::Zero,
crate::basis::DuchonNullspaceOrder::Linear => crate::basis::DuchonNullspaceOrder::Zero,
crate::basis::DuchonNullspaceOrder::Degree(2) => {
crate::basis::DuchonNullspaceOrder::Linear
}
crate::basis::DuchonNullspaceOrder::Degree(k) => {
crate::basis::DuchonNullspaceOrder::Degree(k - 1)
}
};
}
effective
}
fn duchon_polynomial_column_count(
dimension: usize,
order: crate::basis::DuchonNullspaceOrder,
) -> usize {
match order {
crate::basis::DuchonNullspaceOrder::Zero => 1,
crate::basis::DuchonNullspaceOrder::Linear => dimension + 1,
crate::basis::DuchonNullspaceOrder::Degree(degree) => {
crate::basis::monomial_exponents(dimension, degree).len()
}
}
}
/// Evaluators that can also produce second derivatives of every basis column.
pub trait SaeBasisSecondJet: SaeBasisEvaluator {
    /// Second derivatives at each row of `coords`; the evaluators in this module shape
    /// the result `(n, n_basis, latent_dim, latent_dim)`.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String>;
}
/// Evaluators that can also produce third derivatives of every basis column.
pub trait SaeBasisThirdJet: SaeBasisSecondJet {
    /// Third derivatives at each row of `coords`; the evaluators in this module shape
    /// the result `(n, n_basis, latent_dim, latent_dim, latent_dim)`.
    fn third_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array5<f64>, String>;
}
/// Fourier basis on one periodic coordinate `t`:
/// `[1, sin(2πt), cos(2πt), sin(4πt), cos(4πt), …]`.
#[derive(Debug, Clone)]
pub struct PeriodicHarmonicEvaluator {
    /// Total column count, `1 + 2 * num_harmonics`; odd when built through `new`.
    pub num_basis: usize,
}
impl PeriodicHarmonicEvaluator {
    /// Creates the evaluator.
    ///
    /// # Errors
    /// Returns an error unless `num_basis` is odd (which also excludes zero).
    pub fn new(num_basis: usize) -> Result<Self, String> {
        let is_odd = num_basis % 2 == 1;
        if is_odd {
            Ok(Self { num_basis })
        } else {
            Err(format!(
                "PeriodicHarmonicEvaluator requires odd num_basis >= 1; got {num_basis}"
            ))
        }
    }
}
impl SaeBasisEvaluator for PeriodicHarmonicEvaluator {
    /// Evaluates `[1, sin(2πht), cos(2πht), …]` and its derivative in `t`.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        let (n, d) = coords.dim();
        if d != 1 {
            return Err(format!(
                "PeriodicHarmonicEvaluator: expected latent_dim == 1, got {d}"
            ));
        }
        let width = self.num_basis;
        let num_harmonics = (width - 1) / 2;
        let base = 2.0 * std::f64::consts::PI;
        let mut phi = Array2::<f64>::zeros((n, width));
        let mut jet = Array3::<f64>::zeros((n, width, 1));
        for (row, &t) in coords.column(0).iter().enumerate() {
            phi[[row, 0]] = 1.0;
            for harmonic in 1..=num_harmonics {
                let omega = base * (harmonic as f64);
                let (s, c) = (omega * t).sin_cos();
                let sin_col = 2 * harmonic - 1;
                phi[[row, sin_col]] = s;
                phi[[row, sin_col + 1]] = c;
                jet[[row, sin_col, 0]] = omega * c;
                jet[[row, sin_col + 1, 0]] = -omega * s;
            }
        }
        Ok((phi, jet))
    }

    /// The constant and the first harmonic are linear in `eta`; harmonics `>= 2` are curved.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        if n_basis != self.num_basis {
            return Err(format!(
                "PeriodicHarmonicEvaluator::phi_eta_split: n_basis {n_basis} != evaluator width {}",
                self.num_basis
            ));
        }
        let mut curved = vec![false; n_basis];
        let num_harmonics = (n_basis - 1) / 2;
        for harmonic in 2..=num_harmonics {
            let sin_col = 2 * harmonic - 1;
            curved[sin_col] = true;
            curved[sin_col + 1] = true;
        }
        Ok(PhiEtaSplit::from_curved_mask(curved))
    }

    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        Some(<Self as SaeBasisSecondJet>::second_jet(self, coords))
    }

    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        Some(<Self as SaeBasisThirdJet>::third_jet(self, coords))
    }
}
impl SaeBasisSecondJet for PeriodicHarmonicEvaluator {
    /// Second `t`-derivatives, shaped `(n, num_basis, 1, 1)`; the constant column is zero.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        let (n, d) = coords.dim();
        if d != 1 {
            return Err(format!(
                "PeriodicHarmonicEvaluator::second_jet: expected latent_dim == 1, got {d}"
            ));
        }
        let width = self.num_basis;
        let num_harmonics = (width - 1) / 2;
        let base = 2.0 * std::f64::consts::PI;
        let mut hess = Array4::<f64>::zeros((n, width, 1, 1));
        for (row, &t) in coords.column(0).iter().enumerate() {
            for harmonic in 1..=num_harmonics {
                let omega = base * (harmonic as f64);
                let omega_sq = omega * omega;
                let (s, c) = (omega * t).sin_cos();
                let sin_col = 2 * harmonic - 1;
                hess[[row, sin_col, 0, 0]] = -omega_sq * s;
                hess[[row, sin_col + 1, 0, 0]] = -omega_sq * c;
            }
        }
        Ok(hess)
    }
}
impl SaeBasisThirdJet for PeriodicHarmonicEvaluator {
    /// Third `t`-derivatives, shaped `(n, num_basis, 1, 1, 1)`; the constant column is zero.
    fn third_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array5<f64>, String> {
        let (n, d) = coords.dim();
        if d != 1 {
            return Err(format!(
                "PeriodicHarmonicEvaluator::third_jet: expected latent_dim == 1, got {d}"
            ));
        }
        let width = self.num_basis;
        let num_harmonics = (width - 1) / 2;
        let base = 2.0 * std::f64::consts::PI;
        let mut third = Array5::<f64>::zeros((n, width, 1, 1, 1));
        for (row, &t) in coords.column(0).iter().enumerate() {
            for harmonic in 1..=num_harmonics {
                let omega = base * (harmonic as f64);
                let omega_cubed = omega * omega * omega;
                let (s, c) = (omega * t).sin_cos();
                let sin_col = 2 * harmonic - 1;
                third[[row, sin_col, 0, 0, 0]] = -omega_cubed * c;
                third[[row, sin_col + 1, 0, 0, 0]] = omega_cubed * s;
            }
        }
        Ok(third)
    }
}
/// Embeds the angle in coordinate column 0 (radians) as the unit-circle point
/// `(cos t, sin t)`.
#[derive(Debug, Clone)]
pub struct RawPeriodicCircleEvaluator {
    /// Expected number of coordinate columns; only column 0 is read.
    pub latent_dim: usize,
}
impl RawPeriodicCircleEvaluator {
    /// Creates the evaluator.
    ///
    /// # Errors
    /// Returns an error when `latent_dim` is zero.
    pub fn new(latent_dim: usize) -> Result<Self, String> {
        match latent_dim {
            0 => Err("RawPeriodicCircleEvaluator requires latent_dim >= 1".to_string()),
            _ => Ok(Self { latent_dim }),
        }
    }
}
impl SaeBasisEvaluator for RawPeriodicCircleEvaluator {
    /// Both circle columns are linear in `eta`; the basis is always two columns wide.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        match n_basis {
            2 => Ok(PhiEtaSplit::all_linear(n_basis)),
            _ => Err(format!(
                "RawPeriodicCircleEvaluator::phi_eta_split: n_basis {n_basis} != 2"
            )),
        }
    }

    /// No analytic second jet: `None` for well-shaped input, `Some(Err(..))` on a width mismatch.
    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        (coords.ncols() != self.latent_dim).then(|| {
            Err(format!(
                "RawPeriodicCircleEvaluator::second_jet_dyn: expected latent_dim {}, got {}",
                self.latent_dim,
                coords.ncols()
            ))
        })
    }

    /// No analytic third jet: `None` for well-shaped input, `Some(Err(..))` on a width mismatch.
    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        (coords.ncols() != self.latent_dim).then(|| {
            Err(format!(
                "RawPeriodicCircleEvaluator::third_jet_dyn: expected latent_dim {}, got {}",
                self.latent_dim,
                coords.ncols()
            ))
        })
    }

    /// Evaluates `[cos t, sin t]` from column 0; the derivative is stored only along axis 0.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        if coords.ncols() != self.latent_dim {
            return Err(format!(
                "RawPeriodicCircleEvaluator: expected latent_dim {}, got {}",
                self.latent_dim,
                coords.ncols()
            ));
        }
        let n = coords.nrows();
        let mut phi = Array2::<f64>::zeros((n, 2));
        let mut jet = Array3::<f64>::zeros((n, 2, self.latent_dim));
        for (row, point) in coords.outer_iter().enumerate() {
            let (s, c) = point[0].sin_cos();
            phi[[row, 0]] = c;
            phi[[row, 1]] = s;
            jet[[row, 0, 0]] = -s;
            jet[[row, 1, 0]] = c;
        }
        Ok((phi, jet))
    }
}
/// Per-column diagonal weights aligned with the seven columns of `sphere_chart_basis_jet`:
/// `[1, x, y, z, x*y, y*z, x*z]`.
///
/// NOTE(review): a near-zero weight on the constant, unit weight on the linear columns and
/// a larger weight on the quadratic products looks like a roughness prior; confirm against
/// the consumer of this constant.
pub const SPHERE_CHART_PENALTY_DIAGONAL: [f64; 7] = [1e-8, 1.0, 1.0, 1.0, 4.0, 4.0, 4.0];
/// Evaluates the seven-column sphere-chart basis and its gradient.
///
/// Each row of `coords` is `(lat, lon)` in radians, mapped to the unit-sphere point
/// `x = cos(lat) cos(lon)`, `y = cos(lat) sin(lon)`, `z = sin(lat)`. The columns are
/// `[1, x, y, z, x*y, y*z, x*z]`. `jet[[row, col, 0]]` is the `lat` derivative and
/// `jet[[row, col, 1]]` is the `lon` derivative.
///
/// # Errors
/// Returns an error unless `coords` has exactly two columns.
pub fn sphere_chart_basis_jet(
    coords: ArrayView2<'_, f64>,
) -> Result<(Array2<f64>, Array3<f64>), String> {
    if coords.ncols() != 2 {
        return Err(format!(
            "sphere_chart_basis_jet expects latent_dim == 2, got {}",
            coords.ncols()
        ));
    }
    let n = coords.nrows();
    let mut phi = Array2::<f64>::zeros((n, 7));
    let mut jet = Array3::<f64>::zeros((n, 7, 2));
    for (row, point) in coords.outer_iter().enumerate() {
        let (slat, clat) = point[0].sin_cos();
        let (slon, clon) = point[1].sin_cos();
        let (x, y, z) = (clat * clon, clat * slon, slat);
        // Gradients laid out as [d/dlat, d/dlon].
        let gx = [-slat * clon, -clat * slon];
        let gy = [-slat * slon, clat * clon];
        let gz = [clat, 0.0];
        let values = [1.0, x, y, z, x * y, y * z, x * z];
        let grads = [
            [0.0, 0.0],
            gx,
            gy,
            gz,
            [gx[0] * y + x * gy[0], gx[1] * y + x * gy[1]],
            [gy[0] * z + y * gz[0], gy[1] * z],
            [gx[0] * z + x * gz[0], gx[1] * z],
        ];
        for (col, (value, grad)) in values.iter().zip(grads.iter()).enumerate() {
            phi[[row, col]] = *value;
            jet[[row, col, 0]] = grad[0];
            jet[[row, col, 1]] = grad[1];
        }
    }
    Ok((phi, jet))
}
/// Sphere-chart evaluator over `(lat, lon)` coordinates.
///
/// Its basis is the seven columns produced by `sphere_chart_basis_jet`.
#[derive(Debug, Clone)]
pub struct SphereChartEvaluator;
impl SaeBasisEvaluator for SphereChartEvaluator {
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        sphere_chart_basis_jet(coords)
    }

    /// The constant and the three embedding coordinates are linear in `eta`; the three
    /// quadratic products (columns 4, 5 and 6) are curved.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        if n_basis == 7 {
            let curved = (0..n_basis).map(|col| col >= 4).collect();
            return Ok(PhiEtaSplit::from_curved_mask(curved));
        }
        Err(format!(
            "SphereChartEvaluator::phi_eta_split: n_basis {n_basis} != 7"
        ))
    }

    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        Some(<Self as SaeBasisSecondJet>::second_jet(self, coords))
    }

    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        Some(<Self as SaeBasisThirdJet>::third_jet(self, coords))
    }
}
impl SaeBasisSecondJet for SphereChartEvaluator {
    /// Hessians of the seven sphere-chart columns with respect to `(lat, lon)`, shaped
    /// `(n, 7, 2, 2)`; the constant column's Hessian stays zero.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        // Value, gradient and Hessian of one scalar field at a point.
        type Jet2 = (f64, [f64; 2], [[f64; 2]; 2]);
        // Product rule for Hessians: H(fg) = Hf g + df ⊗ dg + dg ⊗ df + f Hg.
        fn product_hessian((f, df, hf): Jet2, (g, dg, hg): Jet2) -> [[f64; 2]; 2] {
            let entry = |a: usize, b: usize| {
                hf[a][b] * g + df[a] * dg[b] + df[b] * dg[a] + f * hg[a][b]
            };
            [[entry(0, 0), entry(0, 1)], [entry(1, 0), entry(1, 1)]]
        }

        if coords.ncols() != 2 {
            return Err(format!(
                "SphereChartEvaluator::second_jet expects latent_dim == 2, got {}",
                coords.ncols()
            ));
        }
        let n = coords.nrows();
        let mut hess = Array4::<f64>::zeros((n, 7, 2, 2));
        for (row, point) in coords.outer_iter().enumerate() {
            let (slat, clat) = point[0].sin_cos();
            let (slon, clon) = point[1].sin_cos();
            let x = clat * clon;
            let y = clat * slon;
            let z = slat;
            let jx: Jet2 = (
                x,
                [-slat * clon, -clat * slon],
                [[-x, slat * slon], [slat * slon, -x]],
            );
            let jy: Jet2 = (
                y,
                [-slat * slon, clat * clon],
                [[-y, -slat * clon], [-slat * clon, -y]],
            );
            let jz: Jet2 = (z, [clat, 0.0], [[-z, 0.0], [0.0, 0.0]]);
            // Columns 1..=6 in order: x, y, z, x*y, y*z, x*z.
            let blocks = [
                jx.2,
                jy.2,
                jz.2,
                product_hessian(jx, jy),
                product_hessian(jy, jz),
                product_hessian(jx, jz),
            ];
            for (offset, block) in blocks.iter().enumerate() {
                for (a, block_row) in block.iter().enumerate() {
                    for (b, &value) in block_row.iter().enumerate() {
                        hess[[row, offset + 1, a, b]] = value;
                    }
                }
            }
        }
        Ok(hess)
    }
}
impl SaeBasisThirdJet for SphereChartEvaluator {
    /// Third derivatives of the seven sphere-chart columns with respect to `(lat, lon)`,
    /// shaped `(n, 7, 2, 2, 2)`; axis 0 is `lat` and axis 1 is `lon`. The constant column
    /// stays zero.
    fn third_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array5<f64>, String> {
        if coords.ncols() != 2 {
            return Err(format!(
                "SphereChartEvaluator::third_jet expects latent_dim == 2, got {}",
                coords.ncols()
            ));
        }
        let n = coords.nrows();
        let mut t3 = Array5::<f64>::zeros((n, 7, 2, 2, 2));
        // Third derivative of a separable field f(lat) * g(lon). `lat[k]` / `lon[k]` hold
        // the k-th derivative of each factor; the axis triple `ax` puts
        // `n_lat = #(axis == 0)` derivatives on the lat factor and the rest on the lon factor.
        let single = |lat: &[f64; 4], lon: &[f64; 4], ax: [usize; 3]| -> f64 {
            let n_lat = ax.iter().filter(|&&q| q == 0).count();
            lat[n_lat] * lon[3 - n_lat]
        };
        // Third derivative of a product of two separable fields F = f_lat*f_lon and
        // G = g_lat*g_lon by the general Leibniz rule. Each bit of `mask` decides whether
        // the corresponding derivative in `ax` is applied to F (bit set) or to G, and all
        // 8 assignments are summed.
        let product = |f_lat: &[f64; 4],
                       f_lon: &[f64; 4],
                       g_lat: &[f64; 4],
                       g_lon: &[f64; 4],
                       ax: [usize; 3]|
         -> f64 {
            let mut acc = 0.0;
            for mask in 0u8..8 {
                let (mut f_lat_n, mut f_lon_n, mut g_lat_n, mut g_lon_n) = (0, 0, 0, 0);
                for (i, &axis) in ax.iter().enumerate() {
                    let to_f = (mask >> i) & 1 == 1;
                    match (to_f, axis == 0) {
                        (true, true) => f_lat_n += 1,
                        (true, false) => f_lon_n += 1,
                        (false, true) => g_lat_n += 1,
                        (false, false) => g_lon_n += 1,
                    }
                }
                acc += f_lat[f_lat_n] * f_lon[f_lon_n] * g_lat[g_lat_n] * g_lon[g_lon_n];
            }
            acc
        };
        for row in 0..n {
            let lat = coords[[row, 0]];
            let lon = coords[[row, 1]];
            let clat = lat.cos();
            let slat = lat.sin();
            let clon = lon.cos();
            let slon = lon.sin();
            // Derivative tables of orders 0..=3 for cos, sin and the constant 1.
            let cos_lat = [clat, -slat, -clat, slat];
            let sin_lat = [slat, clat, -slat, -clat];
            let cos_lon = [clon, -slon, -clon, slon];
            let sin_lon = [slon, clon, -slon, -clon];
            let const_lon = [1.0, 0.0, 0.0, 0.0];
            // x = cos(lat) cos(lon), y = cos(lat) sin(lon), z = sin(lat) * 1.
            let (x_lat, x_lon) = (&cos_lat, &cos_lon);
            let (y_lat, y_lon) = (&cos_lat, &sin_lon);
            let (z_lat, z_lon) = (&sin_lat, &const_lon);
            for axis_a in 0..2 {
                for axis_b in 0..2 {
                    for axis_c in 0..2 {
                        let ax = [axis_a, axis_b, axis_c];
                        t3[[row, 1, axis_a, axis_b, axis_c]] = single(x_lat, x_lon, ax);
                        t3[[row, 2, axis_a, axis_b, axis_c]] = single(y_lat, y_lon, ax);
                        t3[[row, 3, axis_a, axis_b, axis_c]] = single(z_lat, z_lon, ax);
                        // Columns 4, 5, 6 are x*y, y*z and x*z.
                        t3[[row, 4, axis_a, axis_b, axis_c]] =
                            product(x_lat, x_lon, y_lat, y_lon, ax);
                        t3[[row, 5, axis_a, axis_b, axis_c]] =
                            product(y_lat, y_lon, z_lat, z_lon, ax);
                        t3[[row, 6, axis_a, axis_b, axis_c]] =
                            product(x_lat, x_lon, z_lat, z_lon, ax);
                    }
                }
            }
        }
        Ok(t3)
    }
}
/// Tensor-product Fourier basis on the `latent_dim`-dimensional torus.
///
/// Each axis contributes the 1-D columns `[1, sin(2πt), cos(2πt), …, sin(2πHt), cos(2πHt)]`
/// with `H = num_harmonics`. The full basis is every product of one column per axis, with
/// the last axis varying fastest.
#[derive(Debug, Clone)]
pub struct TorusHarmonicEvaluator {
    /// Number of periodic coordinates.
    pub latent_dim: usize,
    /// Highest harmonic on each axis.
    pub num_harmonics: usize,
}
impl TorusHarmonicEvaluator {
    /// Creates a torus evaluator.
    ///
    /// # Errors
    /// Returns an error when `latent_dim` or `num_harmonics` is zero, or when the basis
    /// size `(2 * num_harmonics + 1)^latent_dim` does not fit in `usize`.
    pub fn new(latent_dim: usize, num_harmonics: usize) -> Result<Self, String> {
        if latent_dim == 0 {
            return Err("TorusHarmonicEvaluator requires latent_dim >= 1".to_string());
        }
        if num_harmonics == 0 {
            return Err("TorusHarmonicEvaluator requires num_harmonics >= 1".to_string());
        }
        // Reject sizes that would overflow in `axis_basis_size` / `basis_size`. Those
        // would otherwise panic on first use instead of failing here with an error.
        let total = num_harmonics
            .checked_mul(2)
            .and_then(|twice| twice.checked_add(1))
            .and_then(|axis_m| {
                u32::try_from(latent_dim)
                    .ok()
                    .and_then(|exponent| axis_m.checked_pow(exponent))
            });
        if total.is_none() {
            return Err(format!(
                "TorusHarmonicEvaluator: basis size (2 * {num_harmonics} + 1)^{latent_dim} overflows usize"
            ));
        }
        Ok(Self {
            latent_dim,
            num_harmonics,
        })
    }

    /// Number of 1-D columns per axis: `2 * num_harmonics + 1`.
    pub fn axis_basis_size(&self) -> usize {
        2 * self.num_harmonics + 1
    }

    /// Total number of tensor-product columns, `axis_basis_size()^latent_dim`.
    ///
    /// # Panics
    /// Panics if the product overflows `usize`. `new` rejects such sizes, so this can only
    /// happen for directly constructed values.
    pub fn basis_size(&self) -> usize {
        let axis_m = self.axis_basis_size();
        let mut total: usize = 1;
        for _ in 0..self.latent_dim {
            total = total
                .checked_mul(axis_m)
                .expect("TorusHarmonicEvaluator: basis size overflowed usize");
        }
        total
    }
}
impl SaeBasisEvaluator for TorusHarmonicEvaluator {
    /// A column is curved when it uses any harmonic `>= 2` (per-axis column index > 2) or
    /// varies along more than one axis; all other columns are linear in `eta`.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        let expected = self.basis_size();
        if n_basis != expected {
            return Err(format!(
                "TorusHarmonicEvaluator::phi_eta_split: n_basis {n_basis} != evaluator width {expected}"
            ));
        }
        let axis_m = self.axis_basis_size();
        let mut digits = vec![0usize; self.latent_dim];
        let mut curved = Vec::with_capacity(n_basis);
        for _ in 0..n_basis {
            let varying_axes = digits.iter().filter(|&&k| k > 0).count();
            let uses_higher_harmonic = digits.iter().any(|&k| k > 2);
            curved.push(uses_higher_harmonic || varying_axes > 1);
            // Advance the mixed-radix multi-index, last axis fastest.
            for digit in digits.iter_mut().rev() {
                *digit += 1;
                if *digit < axis_m {
                    break;
                }
                *digit = 0;
            }
        }
        Ok(PhiEtaSplit::from_curved_mask(curved))
    }

    /// Reports the per-axis widths for the 2-D torus only.
    fn factor_basis_sizes(&self) -> Option<(usize, usize)> {
        (self.latent_dim == 2).then(|| {
            let m = self.axis_basis_size();
            (m, m)
        })
    }

    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        Some(<Self as SaeBasisSecondJet>::second_jet(self, coords))
    }

    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        Some(<Self as SaeBasisThirdJet>::third_jet(self, coords))
    }

    /// Evaluates every tensor-product column and its gradient along each torus axis.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        let d = self.latent_dim;
        if coords.ncols() != d {
            return Err(format!(
                "TorusHarmonicEvaluator: expected latent_dim {d}, got {}",
                coords.ncols()
            ));
        }
        let n = coords.nrows();
        let axis_m = self.axis_basis_size();
        let m = self.basis_size();
        let base = 2.0 * std::f64::consts::PI;
        let mut phi = Array2::<f64>::zeros((n, m));
        let mut jet = Array3::<f64>::zeros((n, m, d));
        // Per-axis 1-D Fourier values and first derivatives, refreshed for every row.
        let mut values = vec![vec![0.0_f64; axis_m]; d];
        let mut slopes = vec![vec![0.0_f64; axis_m]; d];
        for row in 0..n {
            for (axis, (value, slope)) in values.iter_mut().zip(slopes.iter_mut()).enumerate() {
                let t = coords[[row, axis]];
                value[0] = 1.0;
                slope[0] = 0.0;
                for harmonic in 1..=self.num_harmonics {
                    let omega = base * (harmonic as f64);
                    let (s, c) = (omega * t).sin_cos();
                    let (sin_col, cos_col) = (2 * harmonic - 1, 2 * harmonic);
                    value[sin_col] = s;
                    value[cos_col] = c;
                    slope[sin_col] = omega * c;
                    slope[cos_col] = -omega * s;
                }
            }
            let mut digits = vec![0usize; d];
            for flat in 0..m {
                phi[[row, flat]] = digits
                    .iter()
                    .zip(&values)
                    .map(|(&k, table)| table[k])
                    .product::<f64>();
                for target in 0..d {
                    jet[[row, flat, target]] = digits
                        .iter()
                        .enumerate()
                        .map(|(axis, &k)| {
                            if axis == target {
                                slopes[axis][k]
                            } else {
                                values[axis][k]
                            }
                        })
                        .product::<f64>();
                }
                for digit in digits.iter_mut().rev() {
                    *digit += 1;
                    if *digit < axis_m {
                        break;
                    }
                    *digit = 0;
                }
            }
        }
        Ok((phi, jet))
    }
}
impl SaeBasisSecondJet for TorusHarmonicEvaluator {
    /// Second derivatives of every tensor-product column, shaped `(n, basis_size, d, d)`.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        let d = self.latent_dim;
        if coords.ncols() != d {
            return Err(format!(
                "TorusHarmonicEvaluator::second_jet expects latent_dim == {d}, got {}",
                coords.ncols()
            ));
        }
        let n = coords.nrows();
        let axis_m = self.axis_basis_size();
        let m = self.basis_size();
        let base = 2.0 * std::f64::consts::PI;
        let mut hess = Array4::<f64>::zeros((n, m, d, d));
        // tables[axis][k][col]: k-th derivative (k = 0, 1, 2) of 1-D column `col` on `axis`.
        let mut tables = vec![vec![vec![0.0_f64; axis_m]; 3]; d];
        for row in 0..n {
            for (axis, table) in tables.iter_mut().enumerate() {
                let t = coords[[row, axis]];
                table[0][0] = 1.0;
                table[1][0] = 0.0;
                table[2][0] = 0.0;
                for harmonic in 1..=self.num_harmonics {
                    let omega = base * (harmonic as f64);
                    let omega_sq = omega * omega;
                    let (s, c) = (omega * t).sin_cos();
                    let (sin_col, cos_col) = (2 * harmonic - 1, 2 * harmonic);
                    table[0][sin_col] = s;
                    table[0][cos_col] = c;
                    table[1][sin_col] = omega * c;
                    table[1][cos_col] = -omega * s;
                    table[2][sin_col] = -omega_sq * s;
                    table[2][cos_col] = -omega_sq * c;
                }
            }
            let mut digits = vec![0usize; d];
            for flat in 0..m {
                for axis_a in 0..d {
                    for axis_b in 0..d {
                        hess[[row, flat, axis_a, axis_b]] = digits
                            .iter()
                            .enumerate()
                            .map(|(axis, &k)| {
                                // How many of the two derivatives fall on this axis.
                                let order = usize::from(axis == axis_a) + usize::from(axis == axis_b);
                                tables[axis][order][k]
                            })
                            .product::<f64>();
                    }
                }
                for digit in digits.iter_mut().rev() {
                    *digit += 1;
                    if *digit < axis_m {
                        break;
                    }
                    *digit = 0;
                }
            }
        }
        Ok(hess)
    }
}
impl SaeBasisThirdJet for TorusHarmonicEvaluator {
    /// Third derivatives of every tensor-product column, shaped `(n, basis_size, d, d, d)`.
    fn third_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array5<f64>, String> {
        let d = self.latent_dim;
        if coords.ncols() != d {
            return Err(format!(
                "TorusHarmonicEvaluator::third_jet expects latent_dim == {d}, got {}",
                coords.ncols()
            ));
        }
        let n = coords.nrows();
        let axis_m = self.axis_basis_size();
        let m = self.basis_size();
        let base = 2.0 * std::f64::consts::PI;
        let mut third = Array5::<f64>::zeros((n, m, d, d, d));
        // tables[axis][k][col]: k-th derivative (k = 0..=3) of 1-D column `col` on `axis`.
        let mut tables = vec![vec![vec![0.0_f64; axis_m]; 4]; d];
        for row in 0..n {
            for (axis, table) in tables.iter_mut().enumerate() {
                let t = coords[[row, axis]];
                for derivative in table.iter_mut() {
                    derivative[0] = 0.0;
                }
                table[0][0] = 1.0;
                for harmonic in 1..=self.num_harmonics {
                    let omega = base * (harmonic as f64);
                    let omega_sq = omega * omega;
                    let omega_cubed = omega_sq * omega;
                    let (s, c) = (omega * t).sin_cos();
                    let (sin_col, cos_col) = (2 * harmonic - 1, 2 * harmonic);
                    table[0][sin_col] = s;
                    table[0][cos_col] = c;
                    table[1][sin_col] = omega * c;
                    table[1][cos_col] = -omega * s;
                    table[2][sin_col] = -omega_sq * s;
                    table[2][cos_col] = -omega_sq * c;
                    table[3][sin_col] = -omega_cubed * c;
                    table[3][cos_col] = omega_cubed * s;
                }
            }
            let mut digits = vec![0usize; d];
            for flat in 0..m {
                for ia in 0..d {
                    for ib in 0..d {
                        for ic in 0..d {
                            third[[row, flat, ia, ib, ic]] = digits
                                .iter()
                                .enumerate()
                                .map(|(axis, &k)| {
                                    // How many of the three derivatives fall on this axis.
                                    let order = usize::from(axis == ia)
                                        + usize::from(axis == ib)
                                        + usize::from(axis == ic);
                                    tables[axis][order][k]
                                })
                                .product::<f64>();
                        }
                    }
                }
                for digit in digits.iter_mut().rev() {
                    *digit += 1;
                    if *digit < axis_m {
                        break;
                    }
                    *digit = 0;
                }
            }
        }
        Ok(third)
    }
}
/// Affine basis `[1, z_1, …, z_d]` over `latent_dim` Euclidean coordinates.
#[derive(Debug, Clone)]
pub struct AffineCoordinateEvaluator {
    /// Number of coordinate columns `d`; the basis has `d + 1` columns.
    pub latent_dim: usize,
}
impl AffineCoordinateEvaluator {
    /// Creates an affine evaluator for `latent_dim` coordinates (no validation is performed).
    pub fn new(latent_dim: usize) -> Self {
        AffineCoordinateEvaluator { latent_dim }
    }
}
impl SaeBasisEvaluator for AffineCoordinateEvaluator {
    /// Every affine column is linear in `eta`.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        let expected = self.latent_dim + 1;
        if n_basis == expected {
            Ok(PhiEtaSplit::all_linear(n_basis))
        } else {
            Err(format!(
                "AffineCoordinateEvaluator::phi_eta_split: n_basis {n_basis} != {expected}"
            ))
        }
    }

    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        Some(<Self as SaeBasisSecondJet>::second_jet(self, coords))
    }

    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        Some(<Self as SaeBasisThirdJet>::third_jet(self, coords))
    }

    /// Evaluates `[1, z_1, …, z_d]`; the gradient of column `axis + 1` is the unit vector `e_axis`.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        if coords.ncols() != self.latent_dim {
            return Err(format!(
                "AffineCoordinateEvaluator: expected latent_dim {}, got {}",
                self.latent_dim,
                coords.ncols()
            ));
        }
        let d = self.latent_dim;
        let n = coords.nrows();
        let mut phi = Array2::<f64>::zeros((n, d + 1));
        let mut jet = Array3::<f64>::zeros((n, d + 1, d));
        for (row, point) in coords.outer_iter().enumerate() {
            phi[[row, 0]] = 1.0;
            for (axis, &value) in point.iter().enumerate() {
                phi[[row, axis + 1]] = value;
                jet[[row, axis + 1, axis]] = 1.0;
            }
        }
        Ok((phi, jet))
    }
}
impl SaeBasisSecondJet for AffineCoordinateEvaluator {
    /// Affine columns have identically zero second derivatives.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        let d = self.latent_dim;
        match coords.ncols() {
            got if got == d => Ok(Array4::<f64>::zeros((coords.nrows(), d + 1, d, d))),
            got => Err(format!(
                "AffineCoordinateEvaluator::second_jet: expected latent_dim {}, got {}",
                d, got
            )),
        }
    }
}
impl SaeBasisThirdJet for AffineCoordinateEvaluator {
    /// Affine columns have identically zero third derivatives.
    fn third_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array5<f64>, String> {
        let d = self.latent_dim;
        match coords.ncols() {
            got if got == d => Ok(Array5::<f64>::zeros((coords.nrows(), d + 1, d, d, d))),
            got => Err(format!(
                "AffineCoordinateEvaluator::third_jet: expected latent_dim {}, got {}",
                d, got
            )),
        }
    }
}
#[derive(Debug, Clone)]
pub struct DuchonCoordinateEvaluator {
pub centers: Array2<f64>,
pub order: crate::basis::DuchonNullspaceOrder,
}
impl DuchonCoordinateEvaluator {
pub fn new(centers: Array2<f64>, m: usize) -> Result<Self, String> {
if centers.ncols() == 0 {
return Err("DuchonCoordinateEvaluator: centers must have at least one column".into());
}
if m == 0 {
return Err("DuchonCoordinateEvaluator: Duchon m must be at least 1".into());
}
let order = match m {
1 => crate::basis::DuchonNullspaceOrder::Zero,
2 => crate::basis::DuchonNullspaceOrder::Linear,
other => crate::basis::DuchonNullspaceOrder::Degree(other - 1),
};
Ok(Self { centers, order })
}
}
impl SaeBasisEvaluator for DuchonCoordinateEvaluator {
    /// For one-dimensional centers, returns a copy whose centers are mapped to
    /// `(c - shift) / scale`.
    ///
    /// Returns `Ok(None)` for multi-dimensional centers, or when `scale` is not a finite
    /// positive value or `shift` is not finite. Errors when the affine vectors do not
    /// match the center dimension.
    fn affine_transformed_evaluator(
        &self,
        shift: &[f64],
        scale: &[f64],
        n_basis: usize,
    ) -> Result<Option<Arc<dyn SaeBasisSecondJet>>, String> {
        let dim = self.centers.ncols();
        if shift.len() != dim || scale.len() != dim {
            return Err(format!(
                "DuchonCoordinateEvaluator::affine_transformed_evaluator: affine vectors must have length {dim}; got shift={} scale={}",
                shift.len(),
                scale.len()
            ));
        }
        if n_basis == usize::MAX {
            return Err(
                "DuchonCoordinateEvaluator::affine_transformed_evaluator: unreachable basis width"
                    .to_string(),
            );
        }
        // Both slices have length `dim` here, so this matches exactly the 1-D case.
        let (offset, factor) = match (shift, scale) {
            ([offset], [factor]) => (*offset, *factor),
            _ => return Ok(None),
        };
        if !(factor.is_finite() && factor > 0.0 && offset.is_finite()) {
            return Ok(None);
        }
        let centers = self.centers.mapv(|center| (center - offset) / factor);
        Ok(Some(Arc::new(Self {
            centers,
            order: self.order,
        })))
    }

    /// Kernel columns come first and are all curved. The trailing polynomial block (sized
    /// by the order reduced for the number of centers) has its non-linear monomials marked
    /// curved.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        let dim = self.centers.ncols();
        let effective = duchon_effective_order_for_eta(self.centers.view(), self.order);
        let n_poly = duchon_polynomial_column_count(dim, effective);
        let n_kernel = match n_basis.checked_sub(n_poly) {
            Some(count) => count,
            None => {
                return Err(format!(
                    "DuchonCoordinateEvaluator::phi_eta_split: n_basis {n_basis} smaller than polynomial block {n_poly}"
                ));
            }
        };
        let mut curved: Vec<bool> = (0..n_basis).map(|col| col < n_kernel).collect();
        if let crate::basis::DuchonNullspaceOrder::Degree(degree) = effective {
            let linear_mask = monomial_linear_mask(dim, degree);
            if linear_mask.len() != n_poly {
                return Err(format!(
                    "DuchonCoordinateEvaluator::phi_eta_split: polynomial mask width {} != {n_poly}",
                    linear_mask.len()
                ));
            }
            for (flag, linear) in curved[n_kernel..].iter_mut().zip(linear_mask) {
                *flag |= !linear;
            }
        }
        Ok(PhiEtaSplit::from_curved_mask(curved))
    }

    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        Some(<Self as SaeBasisSecondJet>::second_jet(self, coords))
    }

    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        Some(<Self as SaeBasisThirdJet>::third_jet(self, coords))
    }

    /// Delegates to `crate::basis::duchon_sae_atom_basis_with_jet` after checking the width.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        let (expected, got) = (self.centers.ncols(), coords.ncols());
        if got != expected {
            return Err(format!(
                "DuchonCoordinateEvaluator: expected latent_dim {}, got {}",
                expected, got
            ));
        }
        crate::basis::duchon_sae_atom_basis_with_jet(coords, self.centers.view(), self.order)
            .map_err(|err| err.to_string())
    }
}
impl SaeBasisSecondJet for DuchonCoordinateEvaluator {
    /// Delegates to `crate::basis::duchon_sae_atom_second_jet` after checking the width.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        let expected = self.centers.ncols();
        let got = coords.ncols();
        if got == expected {
            crate::basis::duchon_sae_atom_second_jet(coords, self.centers.view(), self.order)
                .map_err(|err| err.to_string())
        } else {
            Err(format!(
                "DuchonCoordinateEvaluator::second_jet: expected latent_dim {}, got {}",
                expected, got
            ))
        }
    }
}
impl SaeBasisThirdJet for DuchonCoordinateEvaluator {
    /// Delegates to `crate::basis::duchon_sae_atom_third_jet` after checking the width.
    fn third_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array5<f64>, String> {
        let expected = self.centers.ncols();
        let got = coords.ncols();
        if got == expected {
            crate::basis::duchon_sae_atom_third_jet(coords, self.centers.view(), self.order)
                .map_err(|err| err.to_string())
        } else {
            Err(format!(
                "DuchonCoordinateEvaluator::third_jet: expected latent_dim {}, got {}",
                expected, got
            ))
        }
    }
}
#[derive(Debug, Clone)]
pub struct EuclideanPatchEvaluator {
pub latent_dim: usize,
pub max_degree: usize,
}
impl EuclideanPatchEvaluator {
pub fn new(latent_dim: usize, max_degree: usize) -> Result<Self, String> {
if latent_dim == 0 {
return Err("EuclideanPatchEvaluator: latent_dim must be positive".into());
}
Ok(Self {
latent_dim,
max_degree,
})
}
pub fn basis_size(&self) -> usize {
crate::basis::monomial_exponents(self.latent_dim, self.max_degree).len()
}
fn order(&self) -> crate::basis::DuchonNullspaceOrder {
match self.max_degree {
0 => crate::basis::DuchonNullspaceOrder::Zero,
1 => crate::basis::DuchonNullspaceOrder::Linear,
k => crate::basis::DuchonNullspaceOrder::Degree(k),
}
}
}
impl SaeBasisEvaluator for EuclideanPatchEvaluator {
    /// For a finite affine map with positive scales, returns an evaluator with the same
    /// `latent_dim` and `max_degree`. A total-degree polynomial space is unchanged by
    /// affine reparameterisation. Returns `Ok(None)` for non-finite or non-positive
    /// parameters.
    ///
    /// # Errors
    /// Errors when the affine vectors or `n_basis` do not match this evaluator.
    fn affine_transformed_evaluator(
        &self,
        shift: &[f64],
        scale: &[f64],
        n_basis: usize,
    ) -> Result<Option<Arc<dyn SaeBasisSecondJet>>, String> {
        if shift.len() != self.latent_dim || scale.len() != self.latent_dim {
            return Err(format!(
                "EuclideanPatchEvaluator::affine_transformed_evaluator: affine vectors must have length {}; got shift={} scale={}",
                self.latent_dim,
                shift.len(),
                scale.len()
            ));
        }
        let width = self.basis_size();
        if n_basis != width {
            return Err(format!(
                "EuclideanPatchEvaluator::affine_transformed_evaluator: n_basis {n_basis} != evaluator width {}",
                width
            ));
        }
        let well_formed = shift.iter().chain(scale).all(|v| v.is_finite())
            && scale.iter().all(|&v| v > 0.0);
        if !well_formed {
            return Ok(None);
        }
        Ok(Some(Arc::new(self.clone())))
    }

    /// Constant and degree-1 monomials are linear in `eta`; higher-degree monomials are curved.
    fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
        let linear_mask = monomial_linear_mask(self.latent_dim, self.max_degree);
        if linear_mask.len() == n_basis {
            let curved = linear_mask.iter().map(|&linear| !linear).collect();
            Ok(PhiEtaSplit::from_curved_mask(curved))
        } else {
            Err(format!(
                "EuclideanPatchEvaluator::phi_eta_split: polynomial mask width {} != n_basis {n_basis}",
                linear_mask.len()
            ))
        }
    }

    fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
        Some(<Self as SaeBasisSecondJet>::second_jet(self, coords))
    }

    fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
        Some(<Self as SaeBasisThirdJet>::third_jet(self, coords))
    }

    /// Evaluates every monomial directly. The gradient comes from
    /// `crate::basis::duchon_polynomial_first_derivative_nd` and is shape-checked against
    /// the monomial count.
    fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
        if coords.ncols() != self.latent_dim {
            return Err(format!(
                "EuclideanPatchEvaluator: expected latent_dim {}, got {}",
                self.latent_dim,
                coords.ncols()
            ));
        }
        let exponents = crate::basis::monomial_exponents(self.latent_dim, self.max_degree);
        let n = coords.nrows();
        let m = exponents.len();
        let mut phi = Array2::<f64>::zeros((n, m));
        for (row, point) in coords.outer_iter().enumerate() {
            for (col, alpha) in exponents.iter().enumerate() {
                phi[[row, col]] = alpha
                    .iter()
                    .enumerate()
                    .filter(|&(_, &exp)| exp != 0)
                    .fold(1.0_f64, |acc, (axis, &exp)| acc * point[axis].powi(exp as i32));
            }
        }
        let jet = crate::basis::duchon_polynomial_first_derivative_nd(coords, self.order());
        if jet.shape() != [n, m, self.latent_dim] {
            return Err(format!(
                "EuclideanPatchEvaluator: monomial jet shape {:?} disagrees with ({n}, {m}, {})",
                jet.shape(),
                self.latent_dim
            ));
        }
        Ok((phi, jet))
    }
}
impl SaeBasisSecondJet for EuclideanPatchEvaluator {
    /// Analytic Hessians of every monomial, shaped `(n, basis_size, d, d)`.
    ///
    /// For `x^alpha`, the `(a, c)` entry is `alpha_a (alpha_a - 1) x^(alpha - 2 e_a)` on the
    /// diagonal and `alpha_a alpha_c x^(alpha - e_a - e_c)` off it.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        if coords.ncols() != self.latent_dim {
            return Err(format!(
                "EuclideanPatchEvaluator::second_jet: expected latent_dim {}, got {}",
                self.latent_dim,
                coords.ncols()
            ));
        }
        let exponents = crate::basis::monomial_exponents(self.latent_dim, self.max_degree);
        let (n, m, d) = (coords.nrows(), exponents.len(), self.latent_dim);
        let mut hess = Array4::<f64>::zeros((n, m, d, d));
        for (col, alpha) in exponents.iter().enumerate() {
            for a in (0..d).filter(|&a| alpha[a] != 0) {
                for c in 0..d {
                    let lead = if a == c {
                        (alpha[a] as f64) * (alpha[a].saturating_sub(1) as f64)
                    } else if alpha[c] == 0 {
                        continue;
                    } else {
                        (alpha[a] as f64) * (alpha[c] as f64)
                    };
                    if lead == 0.0 {
                        continue;
                    }
                    // Exponents after differentiating once along `a` and once along `c`.
                    let reduced: Vec<usize> = (0..d)
                        .map(|axis| {
                            let mut exp = alpha[axis];
                            if axis == a {
                                exp = exp.saturating_sub(1);
                            }
                            if axis == c {
                                exp = exp.saturating_sub(1);
                            }
                            exp
                        })
                        .collect();
                    for (row, point) in coords.outer_iter().enumerate() {
                        hess[[row, col, a, c]] = reduced
                            .iter()
                            .enumerate()
                            .filter(|&(_, &exp)| exp != 0)
                            .fold(lead, |acc, (axis, &exp)| acc * point[axis].powi(exp as i32));
                    }
                }
            }
        }
        Ok(hess)
    }
}
impl SaeBasisThirdJet for EuclideanPatchEvaluator {
    /// Analytic third-derivative tensor of every monomial column.
    ///
    /// Entry `[row, col, a, b, c]` of the `(n, M, d, d, d)` result is
    /// `∂³ x^alpha / ∂x_a ∂x_b ∂x_c` evaluated at `coords[row]`. Here `alpha` is
    /// the exponent vector of column `col`. Entries whose derivative is
    /// identically zero are left at 0.
    ///
    /// # Errors
    /// Fails when `coords` does not have `latent_dim` columns.
    fn third_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array5<f64>, String> {
        if coords.ncols() != self.latent_dim {
            return Err(format!(
                "EuclideanPatchEvaluator::third_jet: expected latent_dim {}, got {}",
                self.latent_dim,
                coords.ncols()
            ));
        }
        let exponents = crate::basis::monomial_exponents(self.latent_dim, self.max_degree);
        let n = coords.nrows();
        let m = exponents.len();
        let d = self.latent_dim;
        let mut t3 = Array5::<f64>::zeros((n, m, d, d, d));
        // Falling factorial alpha (alpha - 1) ... (alpha - k + 1): the coefficient
        // produced by differentiating x^alpha k times.
        let falling = |alpha: usize, k: usize| -> f64 {
            (0..k).fold(1.0_f64, |acc, j| acc * ((alpha as f64) - (j as f64)))
        };
        for (col, alpha) in exponents.iter().enumerate() {
            for a in 0..d {
                // A derivative along an axis with zero exponent annihilates the term.
                if alpha[a] == 0 {
                    continue;
                }
                for b in 0..d {
                    for c in 0..d {
                        // Number of times the (a, b, c) derivative hits `axis`. It is
                        // computed on the fly rather than stored in a freshly
                        // allocated length-d Vec for every triple.
                        let order = |axis: usize| {
                            usize::from(axis == a) + usize::from(axis == b) + usize::from(axis == c)
                        };
                        if (0..d).any(|axis| order(axis) > alpha[axis]) {
                            continue;
                        }
                        let lead = (0..d)
                            .fold(1.0_f64, |acc, axis| acc * falling(alpha[axis], order(axis)));
                        if lead == 0.0 {
                            continue;
                        }
                        for row in 0..n {
                            t3[[row, col, a, b, c]] = (0..d).fold(lead, |acc, axis| {
                                // Safe: the guard above ensures order(axis) <= alpha[axis].
                                let exp = alpha[axis] - order(axis);
                                if exp != 0 {
                                    acc * coords[[row, axis]].powi(exp as i32)
                                } else {
                                    acc
                                }
                            });
                        }
                    }
                }
            }
        }
        Ok(t3)
    }
}
/// Tensor-product basis on the cylinder S¹ × ℝ.
///
/// It combines a truncated Fourier series (constant plus sin/cos pairs) in the
/// periodic coordinate with monomials in the line coordinate.
#[derive(Debug, Clone)]
pub struct CylinderHarmonicEvaluator {
/// Number of sin/cos harmonic pairs on the circle factor.
pub circle_harmonics: usize,
/// Highest monomial degree on the line factor.
pub line_degree: usize,
}
impl CylinderHarmonicEvaluator {
    /// Creates a cylinder evaluator with `circle_harmonics` sin/cos pairs on the
    /// periodic axis and monomials up to degree `line_degree` on the line axis.
    ///
    /// # Errors
    /// Returns an error when `circle_harmonics == 0`.
    pub fn new(circle_harmonics: usize, line_degree: usize) -> Result<Self, String> {
        if circle_harmonics > 0 {
            Ok(Self {
                circle_harmonics,
                line_degree,
            })
        } else {
            Err("CylinderHarmonicEvaluator requires circle_harmonics >= 1 (S¹ needs at least one harmonic pair)"
                .to_string())
        }
    }

    /// Width of the circle factor: the constant plus one sin/cos pair per harmonic.
    pub fn circle_basis_size(&self) -> usize {
        1 + 2 * self.circle_harmonics
    }

    /// Width of the line factor: monomials `t^0 ..= t^line_degree`.
    pub fn line_basis_size(&self) -> usize {
        1 + self.line_degree
    }

    /// Total tensor-product width (circle width times line width).
    pub fn basis_size(&self) -> usize {
        self.line_basis_size() * self.circle_basis_size()
    }

    /// Values and first three derivatives of every circle column at `t`.
    ///
    /// `tables[k][col]` is the k-th derivative with respect to `t`. Column 0 is
    /// the constant. For harmonic `h` with `omega = 2πh`, column `2h - 1` is
    /// `sin(omega t)` and column `2h` is `cos(omega t)`.
    fn circle_tables(&self, t: f64) -> [Vec<f64>; 4] {
        let width = self.circle_basis_size();
        let zeros = || vec![0.0_f64; width];
        let mut tables = [zeros(), zeros(), zeros(), zeros()];
        tables[0][0] = 1.0;
        for harmonic in 1..=self.circle_harmonics {
            let omega = 2.0 * std::f64::consts::PI * (harmonic as f64);
            let omega_sq = omega * omega;
            let omega_cu = omega_sq * omega;
            let (s, c) = (omega * t).sin_cos();
            let sin_derivs = [s, omega * c, -omega_sq * s, -omega_cu * c];
            let cos_derivs = [c, -omega * s, -omega_sq * c, omega_cu * s];
            for (k, table) in tables.iter_mut().enumerate() {
                table[2 * harmonic - 1] = sin_derivs[k];
                table[2 * harmonic] = cos_derivs[k];
            }
        }
        tables
    }

    /// Values and first three derivatives of every line monomial `t^j` at `t`.
    ///
    /// `tables[k][j] = j (j-1) ... (j-k+1) t^(j-k)`, which is zero whenever `k > j`.
    fn line_tables(&self, t: f64) -> [Vec<f64>; 4] {
        let width = self.line_basis_size();
        let zeros = || vec![0.0_f64; width];
        let mut tables = [zeros(), zeros(), zeros(), zeros()];
        for (k, table) in tables.iter_mut().enumerate() {
            // Entries with j < k stay at their zero initial value.
            for (j, slot) in table.iter_mut().enumerate().skip(k) {
                let coeff = (0..k).fold(1.0_f64, |acc, q| acc * ((j - q) as f64));
                let residual = j - k;
                let pow = if residual == 0 {
                    1.0
                } else {
                    t.powi(residual as i32)
                };
                *slot = coeff * pow;
            }
        }
        tables
    }

    /// Roughness penalty matrix over the tensor-product coefficients.
    ///
    /// Computes `S_circle ⊗ G_line + G_circle ⊗ S_line`, where per factor `G`
    /// is the L² Gram matrix on [0, 1] and `S` is the Gram matrix of second
    /// derivatives. No mixed ∂²/∂t0∂t1 term is included. Row and column index
    /// `circle * line_width + line`.
    pub fn roughness_gram(&self) -> Array2<f64> {
        let circle_width = self.circle_basis_size();
        let line_width = self.line_basis_size();
        // Fourier modes are orthogonal on [0, 1], so both circle matrices are diagonal.
        let mut circle_gram = Array2::<f64>::zeros((circle_width, circle_width));
        let mut circle_curv = Array2::<f64>::zeros((circle_width, circle_width));
        circle_gram[[0, 0]] = 1.0;
        for harmonic in 1..=self.circle_harmonics {
            let omega = 2.0 * std::f64::consts::PI * (harmonic as f64);
            let curvature = omega.powi(4) * 0.5;
            for idx in [2 * harmonic - 1, 2 * harmonic] {
                circle_gram[[idx, idx]] = 0.5;
                circle_curv[[idx, idx]] = curvature;
            }
        }
        // Monomials: ∫ t^i t^j dt = 1 / (i + j + 1) over [0, 1], and similarly for
        // the second derivatives (nonzero only when both degrees are at least 2).
        let line_gram = Array2::<f64>::from_shape_fn((line_width, line_width), |(i, j)| {
            1.0 / ((i + j + 1) as f64)
        });
        let line_curv = Array2::<f64>::from_shape_fn((line_width, line_width), |(i, j)| {
            if i < 2 || j < 2 {
                return 0.0;
            }
            let ci = (i * (i - 1)) as f64;
            let cj = (j * (j - 1)) as f64;
            ci * cj / (((i - 2) + (j - 2) + 1) as f64)
        });
        let width = circle_width * line_width;
        Array2::<f64>::from_shape_fn((width, width), |(row, col)| {
            let (ca, la) = (row / line_width, row % line_width);
            let (cb, lb) = (col / line_width, col % line_width);
            circle_curv[[ca, cb]] * line_gram[[la, lb]] + circle_gram[[ca, cb]] * line_curv[[la, lb]]
        })
    }

    /// Ensures `coords` has exactly two columns (circle coordinate, line coordinate).
    fn check_coords(&self, coords: ArrayView2<'_, f64>, what: &str) -> Result<(), String> {
        match coords.ncols() {
            2 => Ok(()),
            got => Err(format!(
                "CylinderHarmonicEvaluator::{what}: expected latent_dim == 2 (S¹ × ℝ), got {got}"
            )),
        }
    }
}
impl SaeBasisEvaluator for CylinderHarmonicEvaluator {
fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
let expected = self.basis_size();
if n_basis != expected {
return Err(format!(
"CylinderHarmonicEvaluator::phi_eta_split: n_basis {n_basis} != evaluator width {expected}"
));
}
let ml = self.line_basis_size();
let mut curved = vec![false; expected];
for c in 0..self.circle_basis_size() {
for l in 0..ml {
let circle_curved = c > 2;
let line_curved = l > 1;
curved[c * ml + l] = circle_curved || line_curved;
}
}
Ok(PhiEtaSplit::from_curved_mask(curved))
}
fn factor_basis_sizes(&self) -> Option<(usize, usize)> {
Some((self.circle_basis_size(), self.line_basis_size()))
}
fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
Some(<Self as SaeBasisSecondJet>::second_jet(self, coords))
}
fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
Some(<Self as SaeBasisThirdJet>::third_jet(self, coords))
}
fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
self.check_coords(coords, "evaluate")?;
let n = coords.nrows();
let mc = self.circle_basis_size();
let ml = self.line_basis_size();
let m = mc * ml;
let mut phi = Array2::<f64>::zeros((n, m));
let mut jet = Array3::<f64>::zeros((n, m, 2));
for row in 0..n {
let t0 = coords[[row, 0]];
let t1 = coords[[row, 1]];
let circ = self.circle_tables(t0);
let line = self.line_tables(t1);
for c in 0..mc {
for l in 0..ml {
let col = c * ml + l;
phi[[row, col]] = circ[0][c] * line[0][l];
jet[[row, col, 0]] = circ[1][c] * line[0][l];
jet[[row, col, 1]] = circ[0][c] * line[1][l];
}
}
}
Ok((phi, jet))
}
}
impl SaeBasisSecondJet for CylinderHarmonicEvaluator {
    /// Hessian `(n, M, 2, 2)` of the tensor-product basis at each `(t0, t1)` row.
    ///
    /// The mixed entries `[.., 0, 1]` and `[.., 1, 0]` are identical.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        self.check_coords(coords, "second_jet")?;
        let line_width = self.line_basis_size();
        let width = self.circle_basis_size() * line_width;
        let mut hessian = Array4::<f64>::zeros((coords.nrows(), width, 2, 2));
        for (row, point) in coords.outer_iter().enumerate() {
            let circ = self.circle_tables(point[0]);
            let line = self.line_tables(point[1]);
            for col in 0..width {
                let (c, l) = (col / line_width, col % line_width);
                let cross = circ[1][c] * line[1][l];
                hessian[[row, col, 0, 0]] = circ[2][c] * line[0][l];
                hessian[[row, col, 1, 1]] = circ[0][c] * line[2][l];
                hessian[[row, col, 0, 1]] = cross;
                hessian[[row, col, 1, 0]] = cross;
            }
        }
        Ok(hessian)
    }
}
impl SaeBasisThirdJet for CylinderHarmonicEvaluator {
    /// Third-derivative tensor `(n, M, 2, 2, 2)` of the tensor-product basis.
    ///
    /// Entry `[row, col, a, b, e]` is `circle^(k0)(t0) * line^(3 - k0)(t1)`, where
    /// `k0` counts how many of `a`, `b`, `e` are axis 0.
    fn third_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array5<f64>, String> {
        self.check_coords(coords, "third_jet")?;
        let line_width = self.line_basis_size();
        let width = self.circle_basis_size() * line_width;
        let mut t3 = Array5::<f64>::zeros((coords.nrows(), width, 2, 2, 2));
        for (row, point) in coords.outer_iter().enumerate() {
            let circ = self.circle_tables(point[0]);
            let line = self.line_tables(point[1]);
            for col in 0..width {
                let (c, l) = (col / line_width, col % line_width);
                for a in 0..2 {
                    for b in 0..2 {
                        for e in 0..2 {
                            let circle_order =
                                [a, b, e].iter().filter(|&&axis| axis == 0).count();
                            t3[[row, col, a, b, e]] =
                                circ[circle_order][c] * line[3 - circle_order][l];
                        }
                    }
                }
            }
        }
        Ok(t3)
    }
}
/// Wraps an inner evaluator and re-expresses its basis through a column map `q`.
///
/// `q` has shape `(inner width, reduced width)`, and the reduced design is
/// `phi · q`.
#[derive(Debug, Clone)]
pub struct SubspaceReducedEvaluator {
/// Inner evaluator whose basis columns are remixed.
inner: Arc<dyn SaeBasisSecondJet>,
/// Column map: rows index inner columns, columns index reduced columns.
q: Array2<f64>,
}
impl SubspaceReducedEvaluator {
    /// Wraps `inner` so that its `M` basis columns are remixed into `R` columns
    /// through the `(M, R)` map `q`.
    ///
    /// # Errors
    /// Rejects an empty `q`, and one with more columns than rows.
    pub fn new(inner: Arc<dyn SaeBasisSecondJet>, q: Array2<f64>) -> Result<Self, String> {
        let (rows, cols) = q.dim();
        if rows == 0 || cols == 0 {
            return Err(format!(
                "SubspaceReducedEvaluator: column map must be non-empty; got {:?}",
                (rows, cols)
            ));
        }
        if cols <= rows {
            Ok(Self { inner, q })
        } else {
            Err(format!(
                "SubspaceReducedEvaluator: retained rank {} exceeds inner basis width {}",
                cols, rows
            ))
        }
    }

    /// Number of inner basis columns (rows of `q`).
    pub fn inner_width(&self) -> usize {
        self.q.dim().0
    }

    /// Number of reduced basis columns (columns of `q`).
    pub fn reduced_width(&self) -> usize {
        self.q.dim().1
    }

    /// Verifies that an inner evaluation produced exactly `q.nrows()` columns.
    fn check_inner_width(&self, got: usize, what: &str) -> Result<(), String> {
        let expected = self.q.nrows();
        if got == expected {
            Ok(())
        } else {
            Err(format!(
                "SubspaceReducedEvaluator::{what}: inner evaluator returned width {got}, column map expects {}",
                expected
            ))
        }
    }
}
fn remix_cols_2(phi: &Array2<f64>, q: &Array2<f64>) -> Array2<f64> {
phi.dot(q)
}
/// Applies the column map `q` along axis 1 (the basis axis) of a jet of any rank.
///
/// `jet` must have shape `(n, M, ...)` with `M == q.nrows()`. The result has
/// shape `(n, R, ...)` with `R == q.ncols()`. For every row and every trailing
/// multi-index, `out[row, r, ..] = Σ_m jet[row, m, ..] * q[m, r]`. Trailing
/// axes of length zero are supported and yield an empty result.
///
/// # Errors
/// Fails in any of these cases:
/// - `jet` has fewer than two axes;
/// - its basis axis does not match `q.nrows()`;
/// - an internal reshape is rejected.
fn remix_cols_along_basis(
    jet: ndarray::ArrayViewD<'_, f64>,
    q: &Array2<f64>,
) -> Result<ndarray::ArrayD<f64>, String> {
    let shape = jet.shape().to_vec();
    if shape.len() < 2 {
        return Err(format!(
            "SubspaceReducedEvaluator: jet must have at least (n, M) axes; got {shape:?}"
        ));
    }
    let n = shape[0];
    let m = shape[1];
    if m != q.nrows() {
        return Err(format!(
            "SubspaceReducedEvaluator: jet basis axis {m} != column-map rows {}",
            q.nrows()
        ));
    }
    let r = q.ncols();
    // The empty product is already 1 for a bare (n, M) jet. A zero-length
    // trailing axis must give 0 so that the reshape keeps the element count;
    // clamping it to 1 made the reshape fail.
    let trailing: usize = shape[2..].iter().product();
    let mut out_shape = shape;
    out_shape[1] = r;
    // Borrows when `jet` is already row-major contiguous and copies only otherwise.
    let jet_flat = jet
        .to_shape((n, m, trailing))
        .map_err(|err| format!("SubspaceReducedEvaluator: jet reshape failed: {err}"))?;
    let q_t = q.t();
    let mut out_flat = Array3::<f64>::zeros((n, r, trailing));
    // For each row: out[row] (R x T) = q^T (R x M) · jet[row] (M x T). This uses
    // the same matrix-product kernel as `phi.dot(q)` in `remix_cols_2`.
    for (src, mut dst) in jet_flat.outer_iter().zip(out_flat.outer_iter_mut()) {
        ndarray::linalg::general_mat_mul(1.0, &q_t, &src, 0.0, &mut dst);
    }
    out_flat
        .into_shape_with_order(ndarray::IxDyn(&out_shape))
        .map_err(|err| format!("SubspaceReducedEvaluator: out reshape failed: {err}"))
}
impl SaeBasisEvaluator for SubspaceReducedEvaluator {
fn phi_eta_split(&self, n_basis: usize) -> Result<PhiEtaSplit, String> {
if n_basis != self.q.ncols() {
return Err(format!(
"SubspaceReducedEvaluator::phi_eta_split: n_basis {n_basis} != reduced width {}",
self.q.ncols()
));
}
let inner_split = self.inner.phi_eta_split(self.q.nrows())?;
let mut inner_curved = vec![false; self.q.nrows()];
for &col in &inner_split.curved_cols {
if col < inner_curved.len() {
inner_curved[col] = true;
}
}
let mut curved = vec![false; self.q.ncols()];
for rc in 0..self.q.ncols() {
for mc in 0..self.q.nrows() {
if inner_curved[mc] && self.q[[mc, rc]] != 0.0 {
curved[rc] = true;
break;
}
}
}
Ok(PhiEtaSplit::from_curved_mask(curved))
}
fn second_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array4<f64>, String>> {
Some(<Self as SaeBasisSecondJet>::second_jet(self, coords))
}
fn third_jet_dyn(&self, coords: ArrayView2<'_, f64>) -> Option<Result<Array5<f64>, String>> {
match self.inner.third_jet_dyn(coords) {
Some(Ok(t3)) => {
if let Err(err) = self.check_inner_width(t3.shape()[1], "third_jet_dyn") {
return Some(Err(err));
}
Some(
remix_cols_along_basis(t3.view().into_dyn(), &self.q).and_then(|out| {
out.into_dimensionality::<ndarray::Ix5>().map_err(|err| {
format!("SubspaceReducedEvaluator: third jet dim: {err}")
})
}),
)
}
Some(Err(err)) => Some(Err(err)),
None => None,
}
}
fn evaluate(&self, coords: ArrayView2<'_, f64>) -> Result<(Array2<f64>, Array3<f64>), String> {
let (phi, jet) = self.inner.evaluate(coords)?;
self.check_inner_width(phi.ncols(), "evaluate")?;
let phi_red = remix_cols_2(&phi, &self.q);
let jet_red = remix_cols_along_basis(jet.view().into_dyn(), &self.q)?
.into_dimensionality::<ndarray::Ix3>()
.map_err(|err| format!("SubspaceReducedEvaluator: jet dim: {err}"))?;
Ok((phi_red, jet_red))
}
}
impl SaeBasisSecondJet for SubspaceReducedEvaluator {
    /// Inner second jet remixed through `q` along the basis axis (axis 1).
    ///
    /// # Errors
    /// Propagates inner failures. Also fails when the inner width disagrees with
    /// `q.nrows()`, or when the remixed array is not 4-D.
    fn second_jet(&self, coords: ArrayView2<'_, f64>) -> Result<Array4<f64>, String> {
        let inner_hessian = self.inner.second_jet(coords)?;
        let inner_cols = inner_hessian.shape()[1];
        self.check_inner_width(inner_cols, "second_jet")?;
        let remixed = remix_cols_along_basis(inner_hessian.view().into_dyn(), &self.q)?;
        remixed
            .into_dimensionality::<ndarray::Ix4>()
            .map_err(|err| format!("SubspaceReducedEvaluator: second jet dim: {err}"))
    }
}