use faer::Side as FaerSide;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::linalg::faer_ndarray::{FaerCholesky, fast_ab, fast_ata, fast_atb};
use crate::terms::sae::manifold::{SaeBasisEvaluator, solve_design_least_squares};
/// Quadrature cell count for a turning computation.
/// NOTE(review): not referenced in the visible part of this file — confirm it is still used.
const TURNING_QUADRATURE_CELLS: usize = 256;
/// Number of uniform cells used by `unit_speed_reparameterization` to tabulate
/// cumulative arc length (each cell is sampled at both ends and its midpoint).
pub const ARC_LENGTH_GRID_CELLS: usize = 2048;
/// Largest relative sup-norm mismatch between the old and recomposed fits that
/// `recompose_decoder_exact_ls` still accepts.
pub const CHART_RECOMPOSITION_REL_TOL: f64 = 1.0e-9;
/// Topology of a one-dimensional chart handled by `unit_speed_reparameterization`.
#[derive(Debug, Clone, PartialEq)]
pub enum CanonicalChartTopology {
/// Periodic chart; coordinates are wrapped into `[0, period)` with `rem_euclid`.
Circle { period: f64 },
/// Non-periodic chart; its extent is taken from the minimum and maximum row coordinate.
Interval,
}
/// Result of `unit_speed_reparameterization`.
#[derive(Debug, Clone)]
pub struct UnitSpeedReparameterization {
/// Row coordinates after mapping through the rescaled cumulative arc length.
pub new_row_coords: Array1<f64>,
/// Decoder re-expressed in the new coordinates (`decoder_transport` times the old decoder).
pub new_decoder: Array2<f64>,
/// Least-squares transport matrix returned by the decoder recomposition.
pub decoder_transport: Array2<f64>,
/// Total arc length accumulated by Simpson's rule over the quadrature grid.
pub total_arc_length: f64,
/// Relative sup-norm mismatch between the old fit and the recomposed fit.
pub recomposition_residual: f64,
}
/// Returns, for every jet row, the Euclidean norm of the decoded tangent
/// `sum_b jet[row, b, 0] * decoder[b, ..]`.
///
/// # Errors
/// Fails when the jet's latent dimension is not 1 or when its basis width
/// differs from the number of decoder rows.
fn curve_speeds(
    jet: &ndarray::Array3<f64>,
    decoder: ArrayView2<'_, f64>,
) -> Result<Vec<f64>, String> {
    let (rows, m, d) = jet.dim();
    if d != 1 {
        return Err(format!(
            "sae_chart_canonicalization: expected a 1-D latent jet, got latent_dim {d}"
        ));
    }
    if decoder.nrows() != m {
        return Err(format!(
            "sae_chart_canonicalization: jet basis width {m} != decoder rows {}",
            decoder.nrows()
        ));
    }
    let width = decoder.ncols();
    // One scratch tangent, cleared and refilled for each row.
    let mut scratch = vec![0.0_f64; width];
    let mut norms = Vec::with_capacity(rows);
    for row in 0..rows {
        scratch.fill(0.0);
        for basis in 0..m {
            let weight = jet[[row, basis, 0]];
            // Exactly-zero derivatives contribute nothing and are skipped outright.
            if weight != 0.0 {
                for (col, acc) in scratch.iter_mut().enumerate() {
                    *acc += weight * decoder[[basis, col]];
                }
            }
        }
        let norm_sq: f64 = scratch.iter().map(|v| v * v).sum();
        norms.push(norm_sq.sqrt());
    }
    Ok(norms)
}
/// Integrates, from `0` to `x`, the quadratic that interpolates `f0` at `0`,
/// `fm` at `h / 2` and `f1` at `h`.
///
/// At `x == h` this equals the Simpson cell value `h * (f0 + 4 fm + f1) / 6`.
/// A cell with `h <= 0` contributes `0.0`.
fn partial_cell_arc(f0: f64, fm: f64, f1: f64, h: f64, x: f64) -> f64 {
    if h <= 0.0 {
        0.0
    } else {
        // Interpolant is f0 + linear * t + quadratic * t^2.
        let quadratic = (2.0 * f0 - 4.0 * fm + 2.0 * f1) / (h * h);
        let linear = (-3.0 * f0 + 4.0 * fm - f1) / h;
        let x_sq = x * x;
        let cubic_part = quadratic * x_sq * x / 3.0;
        let square_part = linear * x_sq / 2.0;
        cubic_part + square_part + f0 * x
    }
}
/// Reparameterizes a 1-D chart so the decoded curve is traversed at constant speed.
///
/// The evaluator is sampled on a uniform grid of `ARC_LENGTH_GRID_CELLS` cells
/// (both ends plus midpoint of each cell). Speeds are integrated with Simpson's
/// rule into a cumulative arc-length table, and every coordinate is mapped through
/// that table rescaled to the target span: `period` for a circle, `1.0` for an
/// interval. The decoder is then recomposed on the mapped grid nodes with
/// `recompose_decoder_exact_ls`.
///
/// Returns `Ok(None)` when there is nothing usable to do: empty inputs,
/// non-finite coordinates or speeds, a collapsed interval, a non-positive total
/// length, or a recomposition that is rejected.
///
/// # Errors
/// Returns `Err` for an invalid circle period, for any evaluator failure, and
/// when the evaluator's output does not have the expected basis width or the
/// expected number of grid rows.
pub fn unit_speed_reparameterization(
    evaluator: &dyn SaeBasisEvaluator,
    decoder: ArrayView2<'_, f64>,
    row_coords: ArrayView1<'_, f64>,
    topology: &CanonicalChartTopology,
) -> Result<Option<UnitSpeedReparameterization>, String> {
    let n = row_coords.len();
    let m = decoder.nrows();
    let p = decoder.ncols();
    if n == 0 || m == 0 || p == 0 {
        return Ok(None);
    }
    if row_coords.iter().any(|t| !t.is_finite()) {
        return Ok(None);
    }
    // (lo, hi) is the sampled parameter range; `span` is the length of the target range.
    let (lo, hi, span) = match topology {
        CanonicalChartTopology::Circle { period } => {
            if !(period.is_finite() && *period > 0.0) {
                return Err(format!(
                    "sae_chart_canonicalization: circle period must be finite and positive; got {period}"
                ));
            }
            (0.0, *period, *period)
        }
        CanonicalChartTopology::Interval => {
            let mut t_min = f64::INFINITY;
            let mut t_max = f64::NEG_INFINITY;
            for &t in row_coords.iter() {
                t_min = t_min.min(t);
                t_max = t_max.max(t);
            }
            let scale = t_min.abs().max(t_max.abs()).max(1.0);
            if !(t_max - t_min > 1.0e-12 * scale) {
                return Ok(None);
            }
            // Interval charts are mapped onto [0, 1].
            (t_min, t_max, 1.0)
        }
    };
    let cells = ARC_LENGTH_GRID_CELLS;
    let quad_rows = 2 * cells + 1;
    let h = (hi - lo) / cells as f64;
    // Even rows hold cell nodes, odd rows hold cell midpoints.
    let mut quad_coords = Array2::<f64>::zeros((quad_rows, 1));
    for j in 0..=cells {
        quad_coords[[2 * j, 0]] = lo + j as f64 * h;
        if j < cells {
            quad_coords[[2 * j + 1, 0]] = lo + (j as f64 + 0.5) * h;
        }
    }
    let (grid_phi_all, grid_jet_all) = evaluator.evaluate(quad_coords.view())?;
    if grid_phi_all.ncols() != m {
        return Err(format!(
            "sae_chart_canonicalization: evaluator basis width {} != decoder rows {m}",
            grid_phi_all.ncols()
        ));
    }
    // The table lookups below index rows up to 2 * cells, so a short result
    // must be reported as an error rather than allowed to panic.
    if grid_phi_all.nrows() != quad_rows || grid_jet_all.dim().0 != quad_rows {
        return Err(format!(
            "sae_chart_canonicalization: evaluator returned {} basis rows / {} jet rows on the arc-length grid; expected {quad_rows}",
            grid_phi_all.nrows(),
            grid_jet_all.dim().0
        ));
    }
    let speeds = curve_speeds(&grid_jet_all, decoder)?;
    if speeds.iter().any(|s| !s.is_finite()) {
        return Ok(None);
    }
    // Simpson's rule per cell, accumulated into a running arc length.
    let mut cumulative = vec![0.0_f64; cells + 1];
    for j in 0..cells {
        let f0 = speeds[2 * j];
        let fm = speeds[2 * j + 1];
        let f1 = speeds[2 * j + 2];
        cumulative[j + 1] = cumulative[j] + h * (f0 + 4.0 * fm + f1) / 6.0;
    }
    let total = cumulative[cells];
    if !(total.is_finite() && total > 0.0) {
        return Ok(None);
    }
    let rescale = span / total;
    // Maps an old coordinate to rescaled arc length, wrapping or clamping per topology.
    let map_coord = |t: f64| -> f64 {
        let local = match topology {
            CanonicalChartTopology::Circle { period } => (t - lo).rem_euclid(*period),
            CanonicalChartTopology::Interval => (t - lo).clamp(0.0, hi - lo),
        };
        let cell = ((local / h).floor() as usize).min(cells - 1);
        let x = local - cell as f64 * h;
        let s = cumulative[cell]
            + partial_cell_arc(
                speeds[2 * cell],
                speeds[2 * cell + 1],
                speeds[2 * cell + 2],
                h,
                x,
            );
        let mapped = rescale * s;
        match topology {
            CanonicalChartTopology::Circle { period } => mapped.rem_euclid(*period),
            CanonicalChartTopology::Interval => mapped.clamp(0.0, span),
        }
    };
    let new_row_coords = Array1::from_iter(row_coords.iter().map(|&t| map_coord(t)));
    // Recompose on the cell nodes only (every other quadrature row).
    let mut node_new_coords = Array2::<f64>::zeros((cells + 1, 1));
    let mut old_phi = Array2::<f64>::zeros((cells + 1, m));
    for j in 0..=cells {
        node_new_coords[[j, 0]] = map_coord(lo + j as f64 * h);
        for bm in 0..m {
            old_phi[[j, bm]] = grid_phi_all[[2 * j, bm]];
        }
    }
    let Some(recomposition) =
        recompose_decoder_exact_ls(evaluator, decoder, old_phi.view(), node_new_coords.view())?
    else {
        return Ok(None);
    };
    Ok(Some(UnitSpeedReparameterization {
        new_row_coords,
        new_decoder: recomposition.new_decoder,
        decoder_transport: recomposition.transport,
        total_arc_length: total,
        recomposition_residual: recomposition.recomposition_residual,
    }))
}
/// Output of `recompose_decoder_exact_ls`.
pub(crate) struct DecoderRecomposition {
/// Least-squares solution mapping the new basis values onto the old ones.
pub transport: Array2<f64>,
/// `transport` multiplied by the original decoder.
pub new_decoder: Array2<f64>,
/// Largest absolute fit difference divided by the largest absolute fit value.
pub recomposition_residual: f64,
}
/// Re-expresses `decoder` in new coordinates and audits the result.
///
/// The basis is evaluated at `new_coords`, a transport matrix is obtained from
/// `solve_design_least_squares(new_phi, old_phi)`, and the new decoder is the
/// transport times the old decoder. The old fit (`old_phi * decoder`) and the
/// new fit (`new_phi * new_decoder`) are then compared entry by entry.
///
/// Returns `Ok(None)` when the fits cannot be compared (zero or non-finite
/// scale) or when the relative mismatch exceeds `CHART_RECOMPOSITION_REL_TOL`.
///
/// # Errors
/// Propagates evaluator and solver failures, and reports evaluator output whose
/// shape does not match the decoder and coordinate grid.
pub(crate) fn recompose_decoder_exact_ls(
    evaluator: &dyn SaeBasisEvaluator,
    decoder: ArrayView2<'_, f64>,
    old_phi: ArrayView2<'_, f64>,
    new_coords: ArrayView2<'_, f64>,
) -> Result<Option<DecoderRecomposition>, String> {
    let m = decoder.nrows();
    let (new_phi, new_jet) = evaluator.evaluate(new_coords)?;
    let basis_shape_ok = new_phi.ncols() == m && new_phi.nrows() == old_phi.nrows();
    let jet_shape_ok = new_jet.dim() == (new_coords.nrows(), m, new_coords.ncols());
    if !(basis_shape_ok && jet_shape_ok) {
        return Err(format!(
            "sae_chart_canonicalization: evaluator returned basis {:?} / jet {:?} at the canonical grid; expected ({}, {m}) with latent_dim {}",
            new_phi.dim(),
            new_jet.dim(),
            old_phi.nrows(),
            new_coords.ncols()
        ));
    }
    let transport = solve_design_least_squares(new_phi.view(), old_phi)?;
    let new_decoder = fast_ab(&transport, &decoder);
    let old_fit = fast_ab(&old_phi, &decoder);
    let new_fit = fast_ab(&new_phi, &new_decoder);
    // Single pass: running max of |value| over both fits and of |difference|.
    let (fit_scale, max_abs) = old_fit.iter().zip(new_fit.iter()).fold(
        (0.0_f64, 0.0_f64),
        |(scale, gap), (a, b)| {
            (
                scale.max(a.abs()).max(b.abs()),
                gap.max((a - b).abs()),
            )
        },
    );
    let comparable = fit_scale.is_finite() && fit_scale > 0.0 && max_abs.is_finite();
    if !comparable {
        return Ok(None);
    }
    let recomposition_residual = max_abs / fit_scale;
    if recomposition_residual > CHART_RECOMPOSITION_REL_TOL {
        return Ok(None);
    }
    Ok(Some(DecoderRecomposition {
        transport,
        new_decoder,
        recomposition_residual,
    }))
}
/// Largest absolute integer frequency per axis used by `TorusFlowBasis::new`.
pub const TORUS_FLOW_MAX_HARMONIC: i32 = 2;
/// Shared lower bound on the flow Jacobian determinant; candidates at or below it are rejected.
pub const SAE_FLOW_DIFFEO_MIN_DET: f64 = 0.1;
/// Determinant bound used by the torus flow (alias of `SAE_FLOW_DIFFEO_MIN_DET`).
pub const TORUS_FLOW_DIFFEO_MIN_DET: f64 = SAE_FLOW_DIFFEO_MIN_DET;
/// Nodes per axis of the grid on which `TorusFlowBasis::min_jacobian_det_on_grid` samples the determinant.
pub const TORUS_FLOW_GUARD_NODES_PER_AXIS: usize = 64;
/// Upper bound of the outer iteration loop in `minimize_isometry_defect_flow`.
pub const TORUS_FLOW_GN_MAX_ITERS: usize = 80;
/// Maximum number of rejected damped steps tried within one outer iteration.
pub const TORUS_FLOW_GN_MAX_REJECTS: usize = 12;
/// Minimum nodes per axis of the grid on which the decoder is recomposed.
pub const TORUS_TRANSPORT_MIN_NODES_PER_AXIS: usize = 48;
/// Identifies one coefficient of the torus flow, in the order produced by
/// `TorusFlowBasis::mode_layout`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TorusFlowModeKey {
/// Output coordinate (0 or 1) that the mode displaces.
pub component: usize,
/// Integer frequency pair of the mode.
pub freq: (i32, i32),
/// `true` for the cosine mode, `false` for the sine mode.
pub is_cos: bool,
}
/// One flow mode evaluated at a point.
#[derive(Debug, Clone, Copy)]
pub struct FlowModeSample {
/// Output coordinate (0 or 1) that the mode displaces.
pub component: usize,
/// Mode value at the point.
pub value: f64,
/// Partial derivatives of the mode with respect to the two input coordinates.
pub grad: [f64; 2],
}
/// Trigonometric displacement basis on a square torus of side `period`.
///
/// The flow is `t -> t + sum_k theta_k * mode_k(t)`, wrapped back into
/// `[0, period)` on each axis.
#[derive(Debug, Clone)]
pub struct TorusFlowBasis {
    pub period: f64,
    freqs: Vec<(i32, i32)>,
}

impl TorusFlowBasis {
    /// Builds the basis over the half-plane of frequencies `(a, b)` with
    /// `a > 0`, or `a == 0` and `b > 0`, each bounded by `TORUS_FLOW_MAX_HARMONIC`.
    ///
    /// # Errors
    /// Fails when `period` is not finite and positive.
    pub fn new(period: f64) -> Result<Self, String> {
        let period_ok = period.is_finite() && period > 0.0;
        if !period_ok {
            return Err(format!(
                "TorusFlowBasis: period must be finite and positive; got {period}"
            ));
        }
        let h = TORUS_FLOW_MAX_HARMONIC;
        let freqs = (-h..=h)
            .flat_map(|a| (-h..=h).map(move |b| (a, b)))
            .filter(|&(a, b)| a > 0 || (a == 0 && b > 0))
            .collect();
        Ok(Self { period, freqs })
    }

    /// Number of coefficients: two components times a sine and cosine per frequency.
    pub fn dim(&self) -> usize {
        self.freqs.len() * 4
    }

    /// Keys of all modes, in the same order as `mode_samples`.
    pub fn mode_layout(&self) -> Vec<TorusFlowModeKey> {
        let mut layout = Vec::with_capacity(self.dim());
        for component in 0..2 {
            for &freq in &self.freqs {
                // Sine first, then cosine.
                for is_cos in [false, true] {
                    layout.push(TorusFlowModeKey {
                        component,
                        freq,
                        is_cos,
                    });
                }
            }
        }
        layout
    }

    /// Evaluates every mode and its gradient at `t`.
    pub fn mode_samples(&self, t: [f64; 2]) -> Vec<FlowModeSample> {
        let tau = std::f64::consts::TAU;
        // Per-frequency angular rates and trig values, shared by both components.
        let waves: Vec<(f64, f64, f64, f64)> = self
            .freqs
            .iter()
            .map(|&(a, b)| {
                let w0 = tau * a as f64 / self.period;
                let w1 = tau * b as f64 / self.period;
                let angle = w0 * t[0] + w1 * t[1];
                (w0, w1, angle.sin(), angle.cos())
            })
            .collect();
        let mut samples = Vec::with_capacity(self.dim());
        for component in 0..2 {
            for &(w0, w1, s, c) in &waves {
                samples.push(FlowModeSample {
                    component,
                    value: s,
                    grad: [w0 * c, w1 * c],
                });
                samples.push(FlowModeSample {
                    component,
                    value: c,
                    grad: [-w0 * s, -w1 * s],
                });
            }
        }
        samples
    }

    /// Applies the flow with coefficients `theta` to `t` and wraps the result.
    ///
    /// # Panics
    /// Panics when `theta.len() != self.dim()`.
    pub fn map_point(&self, theta: &[f64], t: [f64; 2]) -> [f64; 2] {
        assert_eq!(theta.len(), self.dim(), "TorusFlowBasis: theta length");
        let mut shifted = t;
        for (coef, sample) in theta.iter().zip(self.mode_samples(t)) {
            shifted[sample.component] += coef * sample.value;
        }
        shifted.map(|v| v.rem_euclid(self.period))
    }

    /// Jacobian of the unwrapped flow at `t`: identity plus the weighted mode gradients.
    ///
    /// # Panics
    /// Panics when `theta.len() != self.dim()`.
    pub fn flow_jacobian(&self, theta: &[f64], t: [f64; 2]) -> [[f64; 2]; 2] {
        assert_eq!(theta.len(), self.dim(), "TorusFlowBasis: theta length");
        let mut jac = [[1.0, 0.0], [0.0, 1.0]];
        for (coef, sample) in theta.iter().zip(self.mode_samples(t)) {
            let row = &mut jac[sample.component];
            row[0] += coef * sample.grad[0];
            row[1] += coef * sample.grad[1];
        }
        jac
    }

    /// Smallest Jacobian determinant over a uniform
    /// `TORUS_FLOW_GUARD_NODES_PER_AXIS`-squared grid covering one period.
    pub fn min_jacobian_det_on_grid(&self, theta: &[f64]) -> f64 {
        let nodes = TORUS_FLOW_GUARD_NODES_PER_AXIS;
        let node_coord = |k: usize| self.period * k as f64 / nodes as f64;
        let mut smallest = f64::INFINITY;
        for i in 0..nodes {
            for j in 0..nodes {
                let [[j00, j01], [j10, j11]] =
                    self.flow_jacobian(theta, [node_coord(i), node_coord(j)]);
                smallest = smallest.min(j00 * j11 - j01 * j10);
            }
        }
        smallest
    }
}
/// Result of `torus_isometry_flow_reparameterization`.
#[derive(Debug, Clone)]
pub struct TorusIsometryFlowReparameterization {
/// Row coordinates after applying the fitted flow (wrapped to the period).
pub new_row_coords: Array2<f64>,
/// Decoder re-expressed in the new coordinates.
pub new_decoder: Array2<f64>,
/// Least-squares transport matrix from the decoder recomposition.
pub decoder_transport: Array2<f64>,
/// Fitted flow coefficients, ordered as `TorusFlowBasis::mode_layout`.
pub flow_theta: Vec<f64>,
/// Defect at zero flow (identity map).
pub defect_initial: f64,
/// Defect at the accepted flow coefficients.
pub defect_final: f64,
/// Profiled scale factor at the accepted coefficients.
pub profiled_metric_scale: f64,
/// Smallest flow Jacobian determinant found on the guard grid.
pub min_flow_jacobian_det: f64,
/// Relative sup-norm mismatch reported by the decoder recomposition.
pub recomposition_residual: f64,
}
/// Objective value and the intermediate quantities reused by the optimizer.
struct FlowObjectiveState {
// Sum over rows of the squared Frobenius norm of `A^T A - scale * ghat`.
defect: f64,
// Scale that minimizes the defect for the current `a_rows` (closed form).
scale: f64,
// Per-row 2x2 matrix flattened row-major as `[a00, a01, a10, a11]`.
a_rows: Vec<[f64; 4]>,
}
/// Evaluates the isometry defect of a flow with coefficients `theta`.
///
/// For every row a 2x2 matrix `A` is built from `row_base` plus the
/// theta-weighted mode gradients, and its Gram matrix `A^T A` is compared with
/// `scale * ghat`. `scale` is profiled out in closed form as
/// `<A^T A, ghat> / ghat_norm_sq`, summed over all rows.
///
/// Returns `None` when the profiled scale is not finite and positive or when
/// the defect is not finite.
fn evaluate_flow_defect(
    theta: &[f64],
    row_modes: &[Vec<FlowModeSample>],
    row_base: &[[f64; 4]],
    ghat: &[[f64; 3]],
    ghat_norm_sq: f64,
) -> Option<FlowObjectiveState> {
    /// Entries `(m00, m11, m01)` of `A^T A` for `A = [[a0, a1], [a2, a3]]`.
    fn gram_entries(a: &[f64; 4]) -> (f64, f64, f64) {
        (
            a[0] * a[0] + a[2] * a[2],
            a[1] * a[1] + a[3] * a[3],
            a[0] * a[1] + a[2] * a[3],
        )
    }

    let a_rows: Vec<[f64; 4]> = row_modes
        .iter()
        .zip(row_base.iter())
        .map(|(modes, base)| {
            let mut a = *base;
            for (coef, sample) in theta.iter().zip(modes.iter()) {
                // Component c owns row c of A, i.e. slots 2c and 2c + 1.
                let slot = 2 * sample.component;
                a[slot] += coef * sample.grad[0];
                a[slot + 1] += coef * sample.grad[1];
            }
            a
        })
        .collect();

    // Frobenius inner product with ghat; the off-diagonal entry counts twice.
    let cross = a_rows
        .iter()
        .zip(ghat.iter())
        .fold(0.0_f64, |acc, (a, g)| {
            let (m00, m11, m01) = gram_entries(a);
            acc + (m00 * g[0] + m11 * g[1] + 2.0 * m01 * g[2])
        });
    let scale = cross / ghat_norm_sq;
    if !(scale.is_finite() && scale > 0.0) {
        return None;
    }

    let defect = a_rows
        .iter()
        .zip(ghat.iter())
        .fold(0.0_f64, |acc, (a, g)| {
            let (m00, m11, m01) = gram_entries(a);
            let r00 = m00 - scale * g[0];
            let r11 = m11 - scale * g[1];
            let r01 = m01 - scale * g[2];
            acc + (r00 * r00 + r11 * r11 + 2.0 * r01 * r01)
        });
    if !defect.is_finite() {
        return None;
    }
    Some(FlowObjectiveState {
        defect,
        scale,
        a_rows,
    })
}
/// Outcome of `minimize_isometry_defect_flow`.
struct FlowMinimization {
// Accepted flow coefficients.
theta: Vec<f64>,
// Defect at `theta = 0`.
defect_initial: f64,
// Defect at the accepted `theta`.
defect_final: f64,
// Profiled scale at the accepted `theta`.
profiled_scale: f64,
}
/// Minimizes the isometry defect over the flow coefficients with a damped
/// Gauss-Newton iteration, starting from `theta = 0`.
///
/// * `row_modes`, `row_base`, `ghat`, `ghat_norm_sq` — forwarded to `evaluate_flow_defect`.
/// * `q` — number of flow coefficients.
/// * `min_det`, `min_det_on_grid` — a candidate whose grid-minimum Jacobian
///   determinant is at or below `min_det` is treated as folded and rejected.
///
/// Returns `None` when the initial objective cannot be evaluated, when the
/// initial defect is not strictly positive, when no step is ever accepted, or
/// when the final defect is not strictly below the initial one.
fn minimize_isometry_defect_flow(
row_modes: &[Vec<FlowModeSample>],
row_base: &[[f64; 4]],
ghat: &[[f64; 3]],
ghat_norm_sq: f64,
q: usize,
min_det: f64,
min_det_on_grid: &dyn Fn(&[f64]) -> f64,
) -> Option<FlowMinimization> {
let n = row_modes.len();
let mut theta = vec![0.0_f64; q];
let mut state = evaluate_flow_defect(&theta, row_modes, row_base, ghat, ghat_norm_sq)?;
let defect_initial = state.defect;
// `!(x > 0.0)` also rejects NaN.
if !(defect_initial > 0.0) {
return None;
}
let sqrt2 = std::f64::consts::SQRT_2;
// Damping parameter: divided by 10 after an accepted step, multiplied by 10 after a rejection.
let mut lambda = 1.0e-4_f64;
let mut any_accepted = false;
for iteration in 0..TORUS_FLOW_GN_MAX_ITERS {
// NOTE(review): this guard exits before doing any work on the last pass, so at
// most TORUS_FLOW_GN_MAX_ITERS - 1 steps are attempted — confirm this is intended.
if iteration + 1 == TORUS_FLOW_GN_MAX_ITERS {
break;
}
// Three residual rows per data row: (m00, m11, sqrt2 * m01) minus scale * ghat.
let mut jmat = Array2::<f64>::zeros((3 * n, q));
let mut rcol = Array2::<f64>::zeros((3 * n, 1));
for (i, (a, g)) in state.a_rows.iter().zip(ghat.iter()).enumerate() {
let m00 = a[0] * a[0] + a[2] * a[2];
let m11 = a[1] * a[1] + a[3] * a[3];
let m01 = a[0] * a[1] + a[2] * a[3];
rcol[[3 * i, 0]] = m00 - state.scale * g[0];
rcol[[3 * i + 1, 0]] = m11 - state.scale * g[1];
// The sqrt2 weight makes the squared residual equal the Frobenius defect.
rcol[[3 * i + 2, 0]] = sqrt2 * (m01 - state.scale * g[2]);
for (k, sample) in row_modes[i].iter().enumerate() {
// Derivatives of the Gram entries with respect to theta[k]; the scale is held fixed.
let ac0 = a[2 * sample.component];
let ac1 = a[2 * sample.component + 1];
let s00 = 2.0 * sample.grad[0] * ac0;
let s11 = 2.0 * sample.grad[1] * ac1;
let s01 = sample.grad[0] * ac1 + sample.grad[1] * ac0;
jmat[[3 * i, k]] = s00;
jmat[[3 * i + 1, k]] = s11;
jmat[[3 * i + 2, k]] = sqrt2 * s01;
}
}
let jtj = fast_ata(&jmat);
let jtr = fast_atb(&jmat, &rcol);
let mut rejects = 0usize;
let mut accepted_step = false;
let mut converged = false;
let mut step_norm_sq = 0.0_f64;
// Retry with stronger damping until a step is accepted or the reject budget runs out.
while rejects < TORUS_FLOW_GN_MAX_REJECTS {
let mut damped = jtj.clone();
for d in 0..q {
damped[[d, d]] += lambda * (1.0 + jtj[[d, d]]);
}
let factor = match damped.cholesky(FaerSide::Lower) {
Ok(factor) => factor,
Err(_) => {
// A failed factorization counts as a rejection.
lambda *= 10.0;
rejects += 1;
continue;
}
};
let mut neg_jtr = jtr.clone();
neg_jtr.mapv_inplace(|v| -v);
let delta = factor.solve_mat(&neg_jtr);
let mut candidate = theta.clone();
step_norm_sq = 0.0;
for k in 0..q {
candidate[k] += delta[[k, 0]];
step_norm_sq += delta[[k, 0]] * delta[[k, 0]];
}
// Folded candidates are never evaluated.
let folded = min_det_on_grid(&candidate) <= min_det;
let candidate_state = if folded {
None
} else {
evaluate_flow_defect(&candidate, row_modes, row_base, ghat, ghat_norm_sq)
};
match candidate_state {
// Accept only a strict decrease of the defect.
Some(next) if next.defect < state.defect => {
let improvement = state.defect - next.defect;
theta = candidate;
state = next;
any_accepted = true;
accepted_step = true;
lambda = (lambda / 10.0).max(1.0e-12);
if improvement <= 1.0e-14 * (1.0 + state.defect) {
converged = true;
}
break;
}
Some(..) | None => {
lambda *= 10.0;
rejects += 1;
}
}
}
// Reject budget exhausted: stop and keep the last accepted theta.
if !accepted_step {
break;
}
if converged {
break;
}
// Stop when the accepted step is negligible relative to theta.
let theta_norm_sq: f64 = theta.iter().map(|v| v * v).sum();
if step_norm_sq <= 1.0e-24 * (1.0 + theta_norm_sq) {
break;
}
}
if !any_accepted || !(state.defect < defect_initial) {
return None;
}
Some(FlowMinimization {
theta,
defect_initial,
defect_final: state.defect,
profiled_scale: state.scale,
})
}
/// Computes, for every row of a 2-D chart, the metric `[g00, g11, g01]` formed
/// from the two decoded tangents, together with `g_bar`, the geometric mean of
/// `sqrt(det g)` over all rows.
///
/// Returns `Ok(None)` when any row's determinant is not finite and positive, or
/// when `g_bar` is not finite and positive.
///
/// # Errors
/// Propagates evaluator failures and reports evaluator output whose basis width
/// or jet shape is not `(n, m, 2)`; `label` prefixes the message.
fn extract_pullback_metric_d2(
    label: &str,
    evaluator: &dyn SaeBasisEvaluator,
    decoder: ArrayView2<'_, f64>,
    row_coords: ArrayView2<'_, f64>,
) -> Result<Option<(Vec<[f64; 3]>, f64)>, String> {
    let n = row_coords.nrows();
    let m = decoder.nrows();
    let p = decoder.ncols();
    let (row_phi, row_jet) = evaluator.evaluate(row_coords)?;
    let shapes_ok = row_phi.ncols() == m && row_jet.dim() == (n, m, 2);
    if !shapes_ok {
        return Err(format!(
            "{label}: evaluator returned basis {:?} / jet {:?}; expected width {m}, latent_dim 2",
            row_phi.dim(),
            row_jet.dim()
        ));
    }
    let mut metrics: Vec<[f64; 3]> = Vec::with_capacity(n);
    let mut half_log_det_total = 0.0_f64;
    // Scratch tangents along each latent axis, reused across rows.
    let mut along0 = vec![0.0_f64; p];
    let mut along1 = vec![0.0_f64; p];
    for row in 0..n {
        along0.fill(0.0);
        along1.fill(0.0);
        for bm in 0..m {
            let d0 = row_jet[[row, bm, 0]];
            let d1 = row_jet[[row, bm, 1]];
            // Basis functions with both derivatives exactly zero are skipped.
            if d0 != 0.0 || d1 != 0.0 {
                for (j, (t0, t1)) in along0.iter_mut().zip(along1.iter_mut()).enumerate() {
                    let b = decoder[[bm, j]];
                    *t0 += d0 * b;
                    *t1 += d1 * b;
                }
            }
        }
        let (g00, g11, g01) = along0.iter().zip(along1.iter()).fold(
            (0.0_f64, 0.0_f64, 0.0_f64),
            |(uu, vv, uv), (&u, &v)| (uu + u * u, vv + v * v, uv + u * v),
        );
        let det = g00 * g11 - g01 * g01;
        if !(det.is_finite() && det > 0.0) {
            return Ok(None);
        }
        // 0.5 * ln(det) is ln(sqrt(det)); averaging these yields a geometric mean.
        half_log_det_total += 0.5 * det.ln();
        metrics.push([g00, g11, g01]);
    }
    let g_bar = (half_log_det_total / n as f64).exp();
    if !(g_bar.is_finite() && g_bar > 0.0) {
        return Ok(None);
    }
    Ok(Some((metrics, g_bar)))
}
/// Divides every metric row by `g_bar` and returns the normalized rows together
/// with their total squared Frobenius norm (the off-diagonal entry counts twice).
///
/// Returns `None` when that norm is not finite and strictly positive.
fn flat_normalized_metric(g_rows: &[[f64; 3]], g_bar: f64) -> Option<(Vec<[f64; 3]>, f64)> {
    let ghat: Vec<[f64; 3]> = g_rows
        .iter()
        .map(|&[g00, g11, g01]| [g00 / g_bar, g11 / g_bar, g01 / g_bar])
        .collect();
    let ghat_norm_sq = ghat.iter().fold(0.0_f64, |acc, h| {
        acc + (h[0] * h[0] + h[1] * h[1] + 2.0 * h[2] * h[2])
    });
    let usable = ghat_norm_sq.is_finite() && ghat_norm_sq > 0.0;
    usable.then_some((ghat, ghat_norm_sq))
}
/// Fits a trigonometric flow on the torus that reduces the isometry defect of
/// the row-wise metric, then re-expresses the decoder in the flowed coordinates.
///
/// Steps: extract the per-row metric, normalize it by `g_bar`, minimize the
/// defect over `TorusFlowBasis` coefficients, check the Jacobian determinant on
/// the guard grid, recompose the decoder on a uniform grid over one period, and
/// finally map the row coordinates.
///
/// Returns `Ok(None)` for empty or non-finite inputs, an unusable metric, a
/// failed minimization, a determinant at or below `TORUS_FLOW_DIFFEO_MIN_DET`,
/// or a rejected recomposition.
///
/// # Errors
/// Returns `Err` when `row_coords` is not `(n, 2)`, when `period` is invalid,
/// on evaluator failure, or when the evaluator output has an unexpected shape.
pub fn torus_isometry_flow_reparameterization(
evaluator: &dyn SaeBasisEvaluator,
decoder: ArrayView2<'_, f64>,
row_coords: ArrayView2<'_, f64>,
period: f64,
) -> Result<Option<TorusIsometryFlowReparameterization>, String> {
let n = row_coords.nrows();
let m = decoder.nrows();
let p = decoder.ncols();
if row_coords.ncols() != 2 {
return Err(format!(
"torus_isometry_flow_reparameterization: expected (n, 2) row coordinates; got {:?}",
row_coords.dim()
));
}
if n == 0 || m == 0 || p == 0 {
return Ok(None);
}
for &t in row_coords.iter() {
if !t.is_finite() {
return Ok(None);
}
}
let Some((g_rows, g_bar)) = extract_pullback_metric_d2(
"torus_isometry_flow_reparameterization",
evaluator,
decoder,
row_coords,
)?
else {
return Ok(None);
};
let Some((ghat, ghat_norm_sq)) = flat_normalized_metric(&g_rows, g_bar) else {
return Ok(None);
};
let basis = TorusFlowBasis::new(period)?;
let q = basis.dim();
// Mode samples depend only on the row coordinates, so they are computed once.
let mut row_modes: Vec<Vec<FlowModeSample>> = Vec::with_capacity(n);
for row in 0..n {
row_modes.push(basis.mode_samples([row_coords[[row, 0]], row_coords[[row, 1]]]));
}
// Identity base matrix per row: zero flow leaves the coordinates unchanged.
let row_base = vec![[1.0_f64, 0.0, 0.0, 1.0]; n];
let Some(minimization) = minimize_isometry_defect_flow(
&row_modes,
&row_base,
&ghat,
ghat_norm_sq,
q,
TORUS_FLOW_DIFFEO_MIN_DET,
&|candidate: &[f64]| basis.min_jacobian_det_on_grid(candidate),
) else {
return Ok(None);
};
let theta = minimization.theta;
let defect_initial = minimization.defect_initial;
// Re-check the accepted coefficients against the determinant bound.
let min_flow_jacobian_det = basis.min_jacobian_det_on_grid(&theta);
if !(min_flow_jacobian_det > TORUS_FLOW_DIFFEO_MIN_DET) {
return Ok(None);
}
// Recomposition grid: at least the configured minimum, and at least 3 * ceil(sqrt(m)) nodes per axis.
let axis_nodes = TORUS_TRANSPORT_MIN_NODES_PER_AXIS.max(3 * (m as f64).sqrt().ceil() as usize);
let grid_rows = axis_nodes * axis_nodes;
let mut grid = Array2::<f64>::zeros((grid_rows, 2));
let mut new_grid = Array2::<f64>::zeros((grid_rows, 2));
for i in 0..axis_nodes {
for j in 0..axis_nodes {
let idx = i * axis_nodes + j;
// Periodic grid: the endpoint `period` is not sampled.
let u = [
period * i as f64 / axis_nodes as f64,
period * j as f64 / axis_nodes as f64,
];
grid[[idx, 0]] = u[0];
grid[[idx, 1]] = u[1];
let mapped = basis.map_point(&theta, u);
new_grid[[idx, 0]] = mapped[0];
new_grid[[idx, 1]] = mapped[1];
}
}
let (grid_phi, grid_jet) = evaluator.evaluate(grid.view())?;
if grid_phi.ncols() != m || grid_jet.dim() != (grid_rows, m, 2) {
return Err(format!(
"torus_isometry_flow_reparameterization: evaluator returned basis {:?} / jet {:?} on the audit grid; expected width {m}, latent_dim 2",
grid_phi.dim(),
grid_jet.dim()
));
}
let Some(recomposition) =
recompose_decoder_exact_ls(evaluator, decoder, grid_phi.view(), new_grid.view())?
else {
return Ok(None);
};
let mut new_row_coords = Array2::<f64>::zeros((n, 2));
for row in 0..n {
let mapped = basis.map_point(&theta, [row_coords[[row, 0]], row_coords[[row, 1]]]);
new_row_coords[[row, 0]] = mapped[0];
new_row_coords[[row, 1]] = mapped[1];
}
Ok(Some(TorusIsometryFlowReparameterization {
new_row_coords,
new_decoder: recomposition.new_decoder,
decoder_transport: recomposition.transport,
flow_theta: theta,
defect_initial,
defect_final: minimization.defect_final,
profiled_metric_scale: minimization.profiled_scale,
min_flow_jacobian_det,
recomposition_residual: recomposition.recomposition_residual,
}))
}
/// Highest total polynomial degree of the monomials built by `FreePatchFlowBasis::new`.
pub const PATCH_FLOW_MAX_DEGREE: usize = 1;
/// Determinant bound used by the patch flow (alias of `SAE_FLOW_DIFFEO_MIN_DET`).
pub const PATCH_FLOW_DIFFEO_MIN_DET: f64 = SAE_FLOW_DIFFEO_MIN_DET;
/// Nodes per axis of the grid on which `FreePatchFlowBasis::min_jacobian_det_on_grid` samples the determinant.
pub const PATCH_FLOW_GUARD_NODES_PER_AXIS: usize = 48;
/// Minimum nodes per axis of the grid on which the patch decoder is recomposed.
pub const PATCH_TRANSPORT_MIN_NODES_PER_AXIS: usize = 48;
/// Identifies one coefficient of the patch flow, in the order produced by
/// `FreePatchFlowBasis::mode_layout`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PatchFlowModeKey {
/// Output coordinate (0 or 1) that the mode displaces.
pub component: usize,
/// Monomial exponents on the two normalized coordinates.
pub exps: (usize, usize),
}
/// Polynomial displacement basis on a rectangular patch.
///
/// Coordinates are normalized so the patch `[lo, hi]` maps to `[-1, 1]` on each
/// axis; the modes are monomials in the normalized coordinates.
#[derive(Debug, Clone)]
pub struct FreePatchFlowBasis {
    center: [f64; 2],
    inv_half: [f64; 2],
    exps: Vec<(usize, usize)>,
}

impl FreePatchFlowBasis {
    /// Builds the basis for the patch `[lo, hi]` with all monomials of total
    /// degree `1..=PATCH_FLOW_MAX_DEGREE`.
    ///
    /// # Errors
    /// Fails when an axis extent is non-finite or negligibly small relative to
    /// the magnitude of its bounds.
    pub fn new(lo: [f64; 2], hi: [f64; 2]) -> Result<Self, String> {
        let mut center = [0.0_f64; 2];
        let mut inv_half = [0.0_f64; 2];
        for axis in 0..2 {
            let (low, high) = (lo[axis], hi[axis]);
            let span = high - low;
            let scale = low.abs().max(high.abs()).max(1.0);
            let extent_ok = span.is_finite() && span > 1.0e-12 * scale;
            if !extent_ok {
                return Err(format!(
                    "FreePatchFlowBasis: patch axis {axis} has collapsed extent [{}, {}]",
                    lo[axis], hi[axis]
                ));
            }
            center[axis] = 0.5 * (low + high);
            inv_half[axis] = 2.0 / span;
        }
        // For each total degree, exponents run from (total, 0) down to (0, total).
        let exps = (1..=PATCH_FLOW_MAX_DEGREE)
            .flat_map(|total| (0..=total).rev().map(move |a| (a, total - a)))
            .collect();
        Ok(Self {
            center,
            inv_half,
            exps,
        })
    }

    /// Number of coefficients: one per monomial for each of the two components.
    pub fn dim(&self) -> usize {
        self.exps.len() * 2
    }

    /// Keys of all modes, in the same order as `mode_samples`.
    pub fn mode_layout(&self) -> Vec<PatchFlowModeKey> {
        let mut layout = Vec::with_capacity(self.dim());
        for component in 0..2 {
            layout.extend(
                self.exps
                    .iter()
                    .map(|&exps| PatchFlowModeKey { component, exps }),
            );
        }
        layout
    }

    /// Maps patch coordinates to normalized coordinates.
    fn normalize(&self, t: [f64; 2]) -> [f64; 2] {
        let u0 = (t[0] - self.center[0]) * self.inv_half[0];
        let u1 = (t[1] - self.center[1]) * self.inv_half[1];
        [u0, u1]
    }

    /// Evaluates every monomial mode and its gradient (in patch coordinates) at `t`.
    pub fn mode_samples(&self, t: [f64; 2]) -> Vec<FlowModeSample> {
        let u = self.normalize(t);
        // Value and gradient per monomial, shared by both components.
        let monomials: Vec<(f64, [f64; 2])> = self
            .exps
            .iter()
            .map(|&(a, b)| {
                let value = pow_u(u[0], a) * pow_u(u[1], b);
                let du0 = match a {
                    0 => 0.0,
                    _ => a as f64 * pow_u(u[0], a - 1) * pow_u(u[1], b),
                };
                let du1 = match b {
                    0 => 0.0,
                    _ => b as f64 * pow_u(u[0], a) * pow_u(u[1], b - 1),
                };
                // Chain rule from normalized back to patch coordinates.
                (value, [du0 * self.inv_half[0], du1 * self.inv_half[1]])
            })
            .collect();
        let mut samples = Vec::with_capacity(self.dim());
        for component in 0..2 {
            for &(value, grad) in &monomials {
                samples.push(FlowModeSample {
                    component,
                    value,
                    grad,
                });
            }
        }
        samples
    }

    /// Applies the flow with coefficients `theta` to `t`.
    ///
    /// # Panics
    /// Panics when `theta.len() != self.dim()`.
    pub fn map_point(&self, theta: &[f64], t: [f64; 2]) -> [f64; 2] {
        assert_eq!(theta.len(), self.dim(), "FreePatchFlowBasis: theta length");
        let mut shifted = t;
        for (coef, sample) in theta.iter().zip(self.mode_samples(t)) {
            shifted[sample.component] += coef * sample.value;
        }
        shifted
    }

    /// Jacobian of the flow at `t`: identity plus the weighted mode gradients.
    ///
    /// # Panics
    /// Panics when `theta.len() != self.dim()`.
    pub fn flow_jacobian(&self, theta: &[f64], t: [f64; 2]) -> [[f64; 2]; 2] {
        assert_eq!(theta.len(), self.dim(), "FreePatchFlowBasis: theta length");
        let mut jac = [[1.0, 0.0], [0.0, 1.0]];
        for (coef, sample) in theta.iter().zip(self.mode_samples(t)) {
            let row = &mut jac[sample.component];
            row[0] += coef * sample.grad[0];
            row[1] += coef * sample.grad[1];
        }
        jac
    }

    /// Smallest Jacobian determinant over a uniform grid spanning normalized
    /// coordinates `[-1.1, 1.1]` on each axis, i.e. slightly beyond the patch.
    pub fn min_jacobian_det_on_grid(&self, theta: &[f64]) -> f64 {
        let nodes = PATCH_FLOW_GUARD_NODES_PER_AXIS;
        let guard_coord = |k: usize| -1.1 + 2.2 * k as f64 / (nodes - 1) as f64;
        let mut smallest = f64::INFINITY;
        for i in 0..nodes {
            for j in 0..nodes {
                let u0 = guard_coord(i);
                let u1 = guard_coord(j);
                let t = [
                    self.center[0] + u0 / self.inv_half[0],
                    self.center[1] + u1 / self.inv_half[1],
                ];
                let [[j00, j01], [j10, j11]] = self.flow_jacobian(theta, t);
                smallest = smallest.min(j00 * j11 - j01 * j10);
            }
        }
        smallest
    }
}
/// Returns `u` raised to the non-negative integer power `k` by repeated
/// multiplication; `pow_u(u, 0)` is `1.0` for every `u`.
fn pow_u(u: f64, k: usize) -> f64 {
    (0..k).fold(1.0_f64, |product, _| product * u)
}
/// Result of `patch_isometry_flow_reparameterization`.
#[derive(Debug, Clone)]
pub struct PatchIsometryFlowReparameterization {
/// Row coordinates after applying the fitted flow.
pub new_row_coords: Array2<f64>,
/// Decoder re-expressed in the new coordinates.
pub new_decoder: Array2<f64>,
/// Least-squares transport matrix from the decoder recomposition.
pub decoder_transport: Array2<f64>,
/// Fitted flow coefficients, ordered as `FreePatchFlowBasis::mode_layout`.
pub flow_theta: Vec<f64>,
/// Defect at zero flow (identity map).
pub defect_initial: f64,
/// Defect at the accepted flow coefficients.
pub defect_final: f64,
/// Profiled scale factor at the accepted coefficients.
pub profiled_metric_scale: f64,
/// Smallest flow Jacobian determinant found on the guard grid.
pub min_flow_jacobian_det: f64,
/// Relative sup-norm mismatch reported by the decoder recomposition.
pub recomposition_residual: f64,
}
/// Fits a polynomial flow on the bounding box of the row coordinates that
/// reduces the isometry defect of the row-wise metric, then re-expresses the
/// decoder in the flowed coordinates.
///
/// Returns `Ok(None)` for empty or non-finite inputs, an unusable metric, a
/// collapsed bounding box, a failed minimization, a determinant at or below
/// `PATCH_FLOW_DIFFEO_MIN_DET`, or a rejected recomposition.
///
/// # Errors
/// Returns `Err` when `row_coords` is not `(n, 2)`, on evaluator failure, or
/// when the evaluator output has an unexpected shape.
pub fn patch_isometry_flow_reparameterization(
evaluator: &dyn SaeBasisEvaluator,
decoder: ArrayView2<'_, f64>,
row_coords: ArrayView2<'_, f64>,
) -> Result<Option<PatchIsometryFlowReparameterization>, String> {
let n = row_coords.nrows();
let m = decoder.nrows();
let p = decoder.ncols();
if row_coords.ncols() != 2 {
return Err(format!(
"patch_isometry_flow_reparameterization: expected (n, 2) row coordinates; got {:?}",
row_coords.dim()
));
}
if n == 0 || m == 0 || p == 0 {
return Ok(None);
}
// Bounding box of the row coordinates; also rejects non-finite entries.
let mut lo = [f64::INFINITY; 2];
let mut hi = [f64::NEG_INFINITY; 2];
for row in 0..n {
for axis in 0..2 {
let t = row_coords[[row, axis]];
if !t.is_finite() {
return Ok(None);
}
lo[axis] = lo[axis].min(t);
hi[axis] = hi[axis].max(t);
}
}
let Some((g_rows, g_bar)) = extract_pullback_metric_d2(
"patch_isometry_flow_reparameterization",
evaluator,
decoder,
row_coords,
)?
else {
return Ok(None);
};
let Some((ghat, ghat_norm_sq)) = flat_normalized_metric(&g_rows, g_bar) else {
return Ok(None);
};
// A collapsed patch is treated as "nothing to do" rather than as an error.
let basis = match FreePatchFlowBasis::new(lo, hi) {
Ok(basis) => basis,
Err(_) => return Ok(None),
};
let q = basis.dim();
let mut row_modes: Vec<Vec<FlowModeSample>> = Vec::with_capacity(n);
for row in 0..n {
row_modes.push(basis.mode_samples([row_coords[[row, 0]], row_coords[[row, 1]]]));
}
// Identity base matrix per row: zero flow leaves the coordinates unchanged.
let row_base = vec![[1.0_f64, 0.0, 0.0, 1.0]; n];
let Some(minimization) = minimize_isometry_defect_flow(
&row_modes,
&row_base,
&ghat,
ghat_norm_sq,
q,
PATCH_FLOW_DIFFEO_MIN_DET,
&|candidate: &[f64]| basis.min_jacobian_det_on_grid(candidate),
) else {
return Ok(None);
};
let theta = minimization.theta;
let defect_initial = minimization.defect_initial;
// Re-check the accepted coefficients against the determinant bound.
let min_flow_jacobian_det = basis.min_jacobian_det_on_grid(&theta);
if !(min_flow_jacobian_det > PATCH_FLOW_DIFFEO_MIN_DET) {
return Ok(None);
}
// Recomposition grid: at least the configured minimum, and at least 3 * ceil(sqrt(m)) nodes per axis.
let axis_nodes = PATCH_TRANSPORT_MIN_NODES_PER_AXIS.max(3 * (m as f64).sqrt().ceil() as usize);
let grid_rows = axis_nodes * axis_nodes;
let mut grid = Array2::<f64>::zeros((grid_rows, 2));
let mut new_grid = Array2::<f64>::zeros((grid_rows, 2));
for i in 0..axis_nodes {
for j in 0..axis_nodes {
let idx = i * axis_nodes + j;
// Closed grid: both `lo` and `hi` are sampled on each axis.
let u = [
lo[0] + (hi[0] - lo[0]) * i as f64 / (axis_nodes - 1) as f64,
lo[1] + (hi[1] - lo[1]) * j as f64 / (axis_nodes - 1) as f64,
];
grid[[idx, 0]] = u[0];
grid[[idx, 1]] = u[1];
let mapped = basis.map_point(&theta, u);
new_grid[[idx, 0]] = mapped[0];
new_grid[[idx, 1]] = mapped[1];
}
}
let (grid_phi, grid_jet) = evaluator.evaluate(grid.view())?;
if grid_phi.ncols() != m || grid_jet.dim() != (grid_rows, m, 2) {
return Err(format!(
"patch_isometry_flow_reparameterization: evaluator returned basis {:?} / jet {:?} on the audit grid; expected width {m}, latent_dim 2",
grid_phi.dim(),
grid_jet.dim()
));
}
let Some(recomposition) =
recompose_decoder_exact_ls(evaluator, decoder, grid_phi.view(), new_grid.view())?
else {
return Ok(None);
};
let mut new_row_coords = Array2::<f64>::zeros((n, 2));
for row in 0..n {
let mapped = basis.map_point(&theta, [row_coords[[row, 0]], row_coords[[row, 1]]]);
new_row_coords[[row, 0]] = mapped[0];
new_row_coords[[row, 1]] = mapped[1];
}
Ok(Some(PatchIsometryFlowReparameterization {
new_row_coords,
new_decoder: recomposition.new_decoder,
decoder_transport: recomposition.transport,
flow_theta: theta,
defect_initial,
defect_final: minimization.defect_final,
profiled_metric_scale: minimization.profiled_scale,
min_flow_jacobian_det,
recomposition_residual: recomposition.recomposition_residual,
}))
}
/// Determinant bound used by the sphere flow (alias of `SAE_FLOW_DIFFEO_MIN_DET`).
pub const SPHERE_FLOW_DIFFEO_MIN_DET: f64 = SAE_FLOW_DIFFEO_MIN_DET;
/// Row latitudes must lie strictly inside `(-(pi/2 - margin), pi/2 - margin)`
/// for the sphere flow to be attempted.
pub const SPHERE_FLOW_POLE_MARGIN: f64 = 0.20;
/// Axis label of one sphere boost mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SphereBoostAxis {
    Z,
    X,
    Y,
}

/// Three-mode displacement basis in `(lat, lon)` coordinates, one mode per axis
/// in the order given by `mode_layout`.
#[derive(Debug, Clone)]
pub struct SphereBoostFlowBasis;

impl SphereBoostFlowBasis {
    /// Number of flow coefficients.
    pub fn dim(&self) -> usize {
        3
    }

    /// Axis of each coefficient, in coefficient order.
    pub fn mode_layout(&self) -> [SphereBoostAxis; 3] {
        [SphereBoostAxis::Z, SphereBoostAxis::X, SphereBoostAxis::Y]
    }

    /// `(d_lat, d_lon)` displacement of each mode at `t = [lat, lon]`.
    fn mode_displacements(t: [f64; 2]) -> [[f64; 2]; 3] {
        let [lat, lon] = t;
        let sin_lat = lat.sin();
        let cos_lat = lat.cos();
        let cos_lon = lon.cos();
        let sin_lon = lon.sin();
        let along_z = [cos_lat, 0.0];
        let along_x = [sin_lat * cos_lon, -sin_lon / cos_lat];
        let along_y = [sin_lat * sin_lon, cos_lon / cos_lat];
        [along_z, along_x, along_y]
    }

    /// Partial derivatives of `mode_displacements` with respect to `(lat, lon)`:
    /// entry `[k][r][c]` is the derivative of component `r` of mode `k` along
    /// coordinate `c`.
    fn mode_jacobians(t: [f64; 2]) -> [[[f64; 2]; 2]; 3] {
        let [lat, lon] = t;
        let sin_lat = lat.sin();
        let cos_lat = lat.cos();
        let cos_lon = lon.cos();
        let sin_lon = lon.sin();
        let cos_lat_sq = cos_lat * cos_lat;
        let d_z = [[-sin_lat, 0.0], [0.0, 0.0]];
        let d_x = [
            [cos_lat * cos_lon, -sin_lat * sin_lon],
            [-sin_lon * sin_lat / cos_lat_sq, -cos_lon / cos_lat],
        ];
        let d_y = [
            [cos_lat * sin_lon, sin_lat * cos_lon],
            [cos_lon * sin_lat / cos_lat_sq, -sin_lon / cos_lat],
        ];
        [d_z, d_x, d_y]
    }

    /// Smallest determinant of `I + sum_k theta[k] * mode_jacobian_k` over a
    /// 48 x 48 grid: latitudes span `[lat_lo, lat_hi]` inclusively, longitudes
    /// cover one full turn starting at `-pi`.
    fn min_jacobian_det_on_band(theta: &[f64], lat_lo: f64, lat_hi: f64) -> f64 {
        let nodes = 48usize;
        let mut smallest = f64::INFINITY;
        for i in 0..nodes {
            let lat = lat_lo + (lat_hi - lat_lo) * i as f64 / (nodes - 1) as f64;
            for j in 0..nodes {
                let lon = -std::f64::consts::PI + std::f64::consts::TAU * j as f64 / nodes as f64;
                let mode_jacs = Self::mode_jacobians([lat, lon]);
                let mut a = [[1.0_f64, 0.0], [0.0, 1.0]];
                for k in 0..mode_jacs.len() {
                    for r in 0..2 {
                        for c in 0..2 {
                            a[r][c] += theta[k] * mode_jacs[k][r][c];
                        }
                    }
                }
                smallest = smallest.min(a[0][0] * a[1][1] - a[0][1] * a[1][0]);
            }
        }
        smallest
    }

    /// Applies the flow `t + sum_k theta[k] * displacement_k(t)`; no wrapping is done.
    fn map_point(theta: &[f64], t: [f64; 2]) -> [f64; 2] {
        let displacements = Self::mode_displacements(t);
        let mut moved = t;
        for k in 0..displacements.len() {
            for axis in 0..2 {
                moved[axis] += theta[k] * displacements[k][axis];
            }
        }
        moved
    }
}
/// Result of `sphere_isometry_flow_reparameterization`.
#[derive(Debug, Clone)]
pub struct SphereIsometryFlowReparameterization {
/// Row `(lat, lon)` coordinates after applying the fitted flow.
pub new_row_coords: Array2<f64>,
/// Decoder re-expressed in the new coordinates.
pub new_decoder: Array2<f64>,
/// Least-squares transport matrix from the decoder recomposition.
pub decoder_transport: Array2<f64>,
/// Fitted flow coefficients, ordered as `SphereBoostFlowBasis::mode_layout`.
pub flow_theta: Vec<f64>,
/// Defect reported by the minimizer for the starting point.
pub defect_initial: f64,
/// Defect reported by the minimizer for the accepted coefficients.
pub defect_final: f64,
/// Smallest flow Jacobian determinant found on the latitude band grid.
pub min_flow_jacobian_det: f64,
/// Relative sup-norm mismatch reported by the decoder recomposition.
pub recomposition_residual: f64,
}
pub fn sphere_isometry_flow_reparameterization(
evaluator: &dyn SaeBasisEvaluator,
decoder: ArrayView2<'_, f64>,
row_coords: ArrayView2<'_, f64>,
) -> Result<Option<SphereIsometryFlowReparameterization>, String> {
let n = row_coords.nrows();
let m = decoder.nrows();
let p = decoder.ncols();
if row_coords.ncols() != 2 {
return Err(format!(
"sphere_isometry_flow_reparameterization: expected (n, 2) row coordinates; got {:?}",
row_coords.dim()
));
}
if n == 0 || m == 0 || p == 0 {
return Ok(None);
}
for &t in row_coords.iter() {
if !t.is_finite() {
return Ok(None);
}
}
let mut lat_lo = f64::INFINITY;
let mut lat_hi = f64::NEG_INFINITY;
for row in 0..n {
let lat = row_coords[[row, 0]];
lat_lo = lat_lo.min(lat);
lat_hi = lat_hi.max(lat);
}
let pole = std::f64::consts::FRAC_PI_2 - SPHERE_FLOW_POLE_MARGIN;
if !(lat_lo > -pole && lat_hi < pole) {
return Ok(None);
}
let Some((g_rows, g_bar)) = extract_pullback_metric_d2(
"sphere_isometry_flow_reparameterization",
evaluator,
decoder,
row_coords,
)?
else {
return Ok(None);
};
let mut ghat: Vec<[f64; 3]> = Vec::with_capacity(n);
let mut ghat_norm_sq = 0.0_f64;
for g in &g_rows {
let h = [g[0] / g_bar, g[1] / g_bar, g[2] / g_bar];
ghat_norm_sq += h[0] * h[0] + h[1] * h[1] + 2.0 * h[2] * h[2];
ghat.push(h);
}
if !(ghat_norm_sq.is_finite() && ghat_norm_sq > 0.0) {
return Ok(None);
}
let q = 3usize;
let Some(minimization) =
sphere_minimize_boost_defect(&ghat, ghat_norm_sq, row_coords, q, lat_lo, lat_hi)
else {
return Ok(None);
};
let theta = minimization.theta;
let defect_initial = minimization.defect_initial;
let min_flow_jacobian_det =
SphereBoostFlowBasis::min_jacobian_det_on_band(&theta, lat_lo, lat_hi);
if !(min_flow_jacobian_det > SPHERE_FLOW_DIFFEO_MIN_DET) {
return Ok(None);
}
let axis_nodes = TORUS_TRANSPORT_MIN_NODES_PER_AXIS.max(3 * (m as f64).sqrt().ceil() as usize);
let grid_rows = axis_nodes * axis_nodes;
let mut grid = Array2::<f64>::zeros((grid_rows, 2));
let mut new_grid = Array2::<f64>::zeros((grid_rows, 2));
for i in 0..axis_nodes {
for j in 0..axis_nodes {
let idx = i * axis_nodes + j;
let lat = lat_lo + (lat_hi - lat_lo) * i as f64 / (axis_nodes - 1) as f64;
let lon = -std::f64::consts::PI + std::f64::consts::TAU * j as f64 / axis_nodes as f64;
grid[[idx, 0]] = lat;
grid[[idx, 1]] = lon;
let mapped = SphereBoostFlowBasis::map_point(&theta, [lat, lon]);
new_grid[[idx, 0]] = mapped[0];
new_grid[[idx, 1]] = mapped[1];
}
}
let (grid_phi, grid_jet) = evaluator.evaluate(grid.view())?;
if grid_phi.ncols() != m || grid_jet.dim() != (grid_rows, m, 2) {
return Err(format!(
"sphere_isometry_flow_reparameterization: evaluator returned basis {:?} / jet {:?} on the audit grid; expected width {m}, latent_dim 2",
grid_phi.dim(),
grid_jet.dim()
));
}
let Some(recomposition) =
recompose_decoder_exact_ls(evaluator, decoder, grid_phi.view(), new_grid.view())?
else {
return Ok(None);
};
let mut new_row_coords = Array2::<f64>::zeros((n, 2));
for row in 0..n {
let mapped =
SphereBoostFlowBasis::map_point(&theta, [row_coords[[row, 0]], row_coords[[row, 1]]]);
new_row_coords[[row, 0]] = mapped[0];
new_row_coords[[row, 1]] = mapped[1];
}
Ok(Some(SphereIsometryFlowReparameterization {
new_row_coords,
new_decoder: recomposition.new_decoder,
decoder_transport: recomposition.transport,
flow_theta: theta,
defect_initial,
defect_final: minimization.defect_final,
min_flow_jacobian_det,
recomposition_residual: recomposition.recomposition_residual,
}))
}
/// Result of `sphere_minimize_boost_defect`: the fitted flow coefficients together
/// with the defect before and after the fit.
struct SphereFlowMinimization {
/// Flow coefficients after the last accepted step (length `q` of the minimizer).
theta: Vec<f64>,
/// Defect evaluated at the all-zero starting coefficients.
defect_initial: f64,
/// Defect at the returned `theta`.
defect_final: f64,
}
/// Evaluates the boost-flow objective at `theta`.
///
/// For every row of `row_coords` the flow Jacobian `Dφ = I + Σₖ θₖ · Jₖ(t)` is
/// assembled from `SphereBoostFlowBasis::mode_jacobians`, the point is pushed
/// through `SphereBoostFlowBasis::map_point`, and the second row of `Dφ` is
/// weighted by the cosine of the mapped first coordinate. The resulting 2×2
/// matrix `A` (stored row-major as `[a00, a01, a10, a11]`) yields the symmetric
/// product `M = AᵀA`, which is compared against the matching entry of `ghat`
/// (`[g00, g11, g01]`).
///
/// The scale is the least-squares fit `Σ⟨M, ĝ⟩ / ghat_norm_sq` (off-diagonal
/// entries counted twice) and the defect is `Σ‖M − scale·ĝ‖²` in the same norm.
///
/// Returns `None` when a mapped cosine is non-finite or not above the floor,
/// when the fitted scale is non-finite or not positive, or when the defect is
/// non-finite. `theta` is indexed once per mode Jacobian and therefore must be
/// at least that long.
fn sphere_eval_boost_defect(
    theta: &[f64],
    row_coords: ArrayView2<'_, f64>,
    ghat: &[[f64; 3]],
    ghat_norm_sq: f64,
) -> Option<FlowObjectiveState> {
    const SPHERE_EVAL_COS_FLOOR: f64 = 1.0e-6;

    /// Entries `(m00, m11, m01)` of `AᵀA` for a row-major 2×2 matrix `A`.
    fn gram(a: &[f64; 4]) -> [f64; 3] {
        [
            a[0] * a[0] + a[2] * a[2],
            a[1] * a[1] + a[3] * a[3],
            a[0] * a[1] + a[2] * a[3],
        ]
    }

    let sample_count = row_coords.nrows();
    let mut a_rows: Vec<[f64; 4]> = Vec::with_capacity(sample_count);
    for row in 0..sample_count {
        let point = [row_coords[[row, 0]], row_coords[[row, 1]]];

        // Flow Jacobian: identity plus the theta-weighted mode Jacobians.
        let mode_jacobians = SphereBoostFlowBasis::mode_jacobians(point);
        let mut flow_jacobian = [[1.0_f64, 0.0], [0.0, 1.0]];
        for (k, mode) in mode_jacobians.iter().enumerate() {
            let weight = theta[k];
            for r in 0..2 {
                for c in 0..2 {
                    flow_jacobian[r][c] += weight * mode[r][c];
                }
            }
        }

        // Refuse rows whose image sits at (or numerically on) a pole.
        let image = SphereBoostFlowBasis::map_point(theta, point);
        let cos_lat_image = image[0].cos();
        if !cos_lat_image.is_finite() || cos_lat_image <= SPHERE_EVAL_COS_FLOOR {
            return None;
        }

        a_rows.push([
            flow_jacobian[0][0],
            flow_jacobian[0][1],
            cos_lat_image * flow_jacobian[1][0],
            cos_lat_image * flow_jacobian[1][1],
        ]);
    }

    // Least-squares scale aligning the flow metric with the normalized target.
    let cross = a_rows
        .iter()
        .zip(ghat.iter())
        .fold(0.0_f64, |acc, (a, g)| {
            let [m00, m11, m01] = gram(a);
            acc + (m00 * g[0] + m11 * g[1] + 2.0 * m01 * g[2])
        });
    let scale = cross / ghat_norm_sq;
    if !scale.is_finite() || scale <= 0.0 {
        return None;
    }

    // Squared residual of the scaled fit, off-diagonal counted twice.
    let defect = a_rows
        .iter()
        .zip(ghat.iter())
        .fold(0.0_f64, |acc, (a, g)| {
            let [m00, m11, m01] = gram(a);
            let r00 = m00 - scale * g[0];
            let r11 = m11 - scale * g[1];
            let r01 = m01 - scale * g[2];
            acc + (r00 * r00 + r11 * r11 + 2.0 * r01 * r01)
        });
    if !defect.is_finite() {
        return None;
    }

    Some(FlowObjectiveState {
        defect,
        scale,
        a_rows,
    })
}
/// Minimizes the boost-flow defect over the flow coefficients `theta` with a
/// damped Gauss–Newton iteration, starting from `theta = 0`.
///
/// * `ghat` / `ghat_norm_sq` – per-row target metric entries and their summed
///   squared norm, forwarded unchanged to `sphere_eval_boost_defect`.
/// * `row_coords` – sample coordinates; three residual rows are built per sample.
/// * `q` – number of flow coefficients (length of `theta`, columns of the Jacobian).
/// * `lat_lo`, `lat_hi` – bounds forwarded to
///   `SphereBoostFlowBasis::min_jacobian_det_on_band` for the fold check.
///
/// Returns `None` when the initial state cannot be evaluated, when the initial
/// defect is not strictly positive, or when no step was ever accepted. Otherwise
/// returns the last accepted `theta` with the initial and final defect.
fn sphere_minimize_boost_defect(
ghat: &[[f64; 3]],
ghat_norm_sq: f64,
row_coords: ArrayView2<'_, f64>,
q: usize,
lat_lo: f64,
lat_hi: f64,
) -> Option<SphereFlowMinimization> {
let n = row_coords.nrows();
let mut theta = vec![0.0_f64; q];
let mut state = sphere_eval_boost_defect(&theta, row_coords, ghat, ghat_norm_sq)?;
let defect_initial = state.defect;
// Nothing to improve (or a NaN defect): report "no canonicalization".
if !(defect_initial > 0.0) {
return None;
}
// The off-diagonal residual is scaled by sqrt(2) so that the squared residual
// norm matches the defect's `2 * r01^2` weighting.
let sqrt2 = std::f64::consts::SQRT_2;
// `fd_h` is the central finite-difference step; `lambda` is the damping
// parameter, carried across outer iterations.
let fd_h = 1.0e-6_f64; let mut lambda = 1.0e-4_f64;
let mut any_accepted = false;
for iteration in 0..TORUS_FLOW_GN_MAX_ITERS {
// NOTE(review): this exits on the final index before doing any work, so at most
// `TORUS_FLOW_GN_MAX_ITERS - 1` step attempts are made — confirm this is intended.
if iteration + 1 == TORUS_FLOW_GN_MAX_ITERS {
break;
}
let mut jmat = Array2::<f64>::zeros((3 * n, q));
let mut rcol = Array2::<f64>::zeros((3 * n, 1));
let scale = state.scale;
// Residual vector at the current state: three rows (00, 11, sqrt2 * 01) per sample.
for (i, (a, g)) in state.a_rows.iter().zip(ghat.iter()).enumerate() {
let m00 = a[0] * a[0] + a[2] * a[2];
let m11 = a[1] * a[1] + a[3] * a[3];
let m01 = a[0] * a[1] + a[2] * a[3];
rcol[[3 * i, 0]] = m00 - scale * g[0];
rcol[[3 * i + 1, 0]] = m11 - scale * g[1];
rcol[[3 * i + 2, 0]] = sqrt2 * (m01 - scale * g[2]);
}
// Jacobian columns by central differences in each coefficient. The current
// `scale` is used on both sides, so the `scale * ghat` terms cancel in the
// difference and only the metric entries are differentiated.
for k in 0..q {
let mut tp = theta.clone();
let mut tm = theta.clone();
tp[k] += fd_h; tm[k] -= fd_h; let sp = sphere_eval_boost_defect(&tp, row_coords, ghat, ghat_norm_sq);
let sm = sphere_eval_boost_defect(&tm, row_coords, ghat, ghat_norm_sq);
// If either probe cannot be evaluated, stop here: keep the progress made so
// far if any step was accepted, otherwise report failure.
let (Some(sp), Some(sm)) = (sp, sm) else {
return if any_accepted {
Some(SphereFlowMinimization {
theta,
defect_initial,
defect_final: state.defect,
})
} else {
None
};
};
for (i, (ap, am)) in sp.a_rows.iter().zip(sm.a_rows.iter()).enumerate() {
let mp00 = ap[0] * ap[0] + ap[2] * ap[2] - scale * ghat[i][0];
let mp11 = ap[1] * ap[1] + ap[3] * ap[3] - scale * ghat[i][1];
let mp01 = ap[0] * ap[1] + ap[2] * ap[3] - scale * ghat[i][2];
let mm00 = am[0] * am[0] + am[2] * am[2] - scale * ghat[i][0];
let mm11 = am[1] * am[1] + am[3] * am[3] - scale * ghat[i][1];
let mm01 = am[0] * am[1] + am[2] * am[3] - scale * ghat[i][2];
jmat[[3 * i, k]] = (mp00 - mm00) / (2.0 * fd_h); jmat[[3 * i + 1, k]] = (mp11 - mm11) / (2.0 * fd_h); jmat[[3 * i + 2, k]] = sqrt2 * (mp01 - mm01) / (2.0 * fd_h); }
}
// Normal equations: JᵀJ and Jᵀr.
let jtj = fast_ata(&jmat);
let jtr = fast_atb(&jmat, &rcol);
let mut rejects = 0usize;
let mut accepted_step = false;
let mut converged = false;
let mut step_norm_sq = 0.0_f64;
// Damping loop: retry with a 10x larger `lambda` until a step is accepted or
// the reject budget is exhausted.
while rejects < TORUS_FLOW_GN_MAX_REJECTS {
let mut damped = jtj.clone();
// Diagonal damping `lambda * (1 + diag(JᵀJ))`.
for d in 0..q {
damped[[d, d]] += lambda * (1.0 + jtj[[d, d]]);
}
// A failed Cholesky factorization counts as a rejected attempt.
let factor = match damped.cholesky(FaerSide::Lower) {
Ok(factor) => factor,
Err(_) => {
lambda *= 10.0;
rejects += 1;
continue;
}
};
// Solve (JᵀJ + damping) · delta = -Jᵀr.
let mut neg_jtr = jtr.clone();
neg_jtr.mapv_inplace(|v| -v);
let delta = factor.solve_mat(&neg_jtr);
let mut candidate = theta.clone();
step_norm_sq = 0.0;
for k in 0..q {
candidate[k] += delta[[k, 0]];
step_norm_sq += delta[[k, 0]] * delta[[k, 0]];
}
// A candidate whose minimum Jacobian determinant on the band is at or below
// the threshold is treated as folded and is not evaluated.
let folded = SphereBoostFlowBasis::min_jacobian_det_on_band(&candidate, lat_lo, lat_hi)
<= SPHERE_FLOW_DIFFEO_MIN_DET;
let candidate_state = if folded {
None
} else {
sphere_eval_boost_defect(&candidate, row_coords, ghat, ghat_norm_sq)
};
match candidate_state {
// Accept only a strict decrease of the defect.
Some(next) if next.defect < state.defect => {
let improvement = state.defect - next.defect;
theta = candidate;
state = next;
any_accepted = true;
accepted_step = true;
// Relax the damping after a success, floored at 1e-12.
lambda = (lambda / 10.0).max(1.0e-12);
// Negligible relative improvement: treat as converged.
if improvement <= 1.0e-14 * (1.0 + state.defect) {
converged = true;
}
break;
}
// Folded, unevaluable, or non-improving candidate: increase damping and retry.
Some(..) | None => {
lambda *= 10.0;
rejects += 1;
}
}
}
// Stop when the reject budget ran out without an accepted step, or on convergence.
if !accepted_step || converged {
break;
}
// Stop when the accepted step is negligible relative to `theta`.
let theta_norm_sq: f64 = theta.iter().map(|v| v * v).sum();
if step_norm_sq <= 1.0e-24 * (1.0 + theta_norm_sq) {
break;
}
}
// Only report a result that strictly improved on the starting defect.
if !any_accepted || !(state.defect < defect_initial) {
return None;
}
Some(SphereFlowMinimization {
theta,
defect_initial,
defect_final: state.defect,
})
}
/// Measures how far the chart's pullback metric is from a scaled round-sphere
/// metric `c · diag(1, cos²(lat))` at the fitted rows.
///
/// The per-row metric entries come from `extract_pullback_metric_d2` and are
/// normalized by the `g_bar` it returns. The scale `c` is the least-squares fit
/// of the normalized metrics onto the reference, and the returned defect is the
/// summed squared residual with the off-diagonal entry counted twice.
///
/// The first column of `row_coords` is used as the latitude.
///
/// # Errors
/// Returns `Err` when `row_coords` does not have exactly two columns, or when
/// `extract_pullback_metric_d2` fails.
///
/// # Returns
/// `Ok(None)` when there is nothing to measure or the measurement is not
/// trustworthy: empty inputs, non-finite coordinates, no pullback metric, a row
/// with `cos²(lat)` at or below the pole floor, or a non-finite / non-positive
/// scale or non-finite defect. Otherwise `Ok(Some(defect))`.
pub fn sphere_chart_isometry_defect(
    evaluator: &dyn SaeBasisEvaluator,
    decoder: ArrayView2<'_, f64>,
    row_coords: ArrayView2<'_, f64>,
) -> Result<Option<f64>, String> {
    const POLE_COS2_FLOOR: f64 = 1e-12;

    let n = row_coords.nrows();
    let m = decoder.nrows();
    let p = decoder.ncols();
    if row_coords.ncols() != 2 {
        return Err(format!(
            "sphere_chart_isometry_defect: expected (n, 2) row coordinates; got {:?}",
            row_coords.dim()
        ));
    }
    if n == 0 || m == 0 || p == 0 {
        return Ok(None);
    }
    if row_coords.iter().any(|coord| !coord.is_finite()) {
        return Ok(None);
    }

    let Some((g_rows, g_bar)) = extract_pullback_metric_d2(
        "sphere_chart_isometry_defect",
        evaluator,
        decoder,
        row_coords,
    )?
    else {
        return Ok(None);
    };

    // Pair every normalized fitted metric with its round reference while
    // accumulating the two sums needed for the least-squares scale.
    let mut pairs: Vec<([f64; 3], [f64; 3])> = Vec::with_capacity(n);
    let mut reference_norm_sq = 0.0_f64;
    let mut inner = 0.0_f64;
    for (row, g) in g_rows.iter().enumerate() {
        let cos_lat = row_coords[[row, 0]].cos();
        let cos_sq = cos_lat * cos_lat;
        // The reference metric degenerates at the poles; refuse such rows.
        if !cos_sq.is_finite() || cos_sq <= POLE_COS2_FLOOR {
            return Ok(None);
        }
        let fitted = [g[0] / g_bar, g[1] / g_bar, g[2] / g_bar];
        let round = [1.0_f64, cos_sq, 0.0_f64];
        inner += fitted[0] * round[0] + fitted[1] * round[1] + 2.0 * fitted[2] * round[2];
        reference_norm_sq += round[0] * round[0] + round[1] * round[1] + 2.0 * round[2] * round[2];
        pairs.push((fitted, round));
    }
    if !reference_norm_sq.is_finite() || reference_norm_sq <= 0.0 {
        return Ok(None);
    }

    let c = inner / reference_norm_sq;
    if !c.is_finite() || c <= 0.0 {
        return Ok(None);
    }

    let defect = pairs.iter().fold(0.0_f64, |acc, (fitted, round)| {
        let r00 = fitted[0] - c * round[0];
        let r11 = fitted[1] - c * round[1];
        let r01 = fitted[2] - c * round[2];
        acc + (r00 * r00 + r11 * r11 + 2.0 * r01 * r01)
    });
    if defect.is_finite() {
        Ok(Some(defect))
    } else {
        Ok(None)
    }
}
/// Total turning of the decoded 1-D curve over the fitted coordinate range.
///
/// With `γ'` and `γ''` the decoder-weighted first and second basis derivatives,
/// the integrand is `|γ' ∧ γ''| / |γ'|²`, integrated over `[min, max]` of
/// `row_coords` with composite Simpson quadrature on
/// `TURNING_QUADRATURE_CELLS` cells. Wedge values within a roundoff floor of
/// zero are clamped to exactly zero so a straight image reports zero turning.
///
/// # Errors
/// Returns `Err` when the evaluator fails or returns a jet / second jet whose
/// shape does not match the quadrature grid and the decoder's row count.
///
/// # Returns
/// `Ok(None)` when the decoder or coordinates are empty, a coordinate is
/// non-finite, the coordinate range has zero width, the evaluator provides no
/// second jet, the tangent vanishes at a quadrature node, or the integrand or
/// the integral is non-finite. Otherwise `Ok(Some(turning))`.
pub fn d1_atom_fitted_turning(
    evaluator: &dyn SaeBasisEvaluator,
    decoder: ArrayView2<'_, f64>,
    row_coords: ArrayView1<'_, f64>,
) -> Result<Option<f64>, String> {
    let (m, p) = decoder.dim();
    if m == 0 || p == 0 || row_coords.is_empty() {
        return Ok(None);
    }
    if row_coords.iter().any(|t| !t.is_finite()) {
        return Ok(None);
    }
    let (lo, hi) = row_coords
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &t| {
            (lo.min(t), hi.max(t))
        });
    if hi <= lo {
        return Ok(None);
    }

    // Two sub-intervals per Simpson cell; `h` is the node spacing.
    let cells = TURNING_QUADRATURE_CELLS;
    let nodes = 2 * cells + 1;
    let h = (hi - lo) / (nodes - 1) as f64;
    let grid = Array2::from_shape_fn((nodes, 1), |(i, _)| lo + h * i as f64);

    let (_phi, jet) = evaluator.evaluate(grid.view())?;
    if jet.dim() != (nodes, m, 1) {
        return Err(format!(
            "d1_atom_fitted_turning: evaluator returned jet {:?}; expected ({nodes}, {m}, 1)",
            jet.dim()
        ));
    }
    let hess = match evaluator.second_jet_dyn(grid.view()) {
        Some(result) => result?,
        None => return Ok(None),
    };
    if hess.dim() != (nodes, m, 1, 1) {
        return Err(format!(
            "d1_atom_fitted_turning: second_jet returned {:?}; expected ({nodes}, {m}, 1, 1)",
            hess.dim()
        ));
    }

    let mut integrand: Vec<f64> = Vec::with_capacity(nodes);
    let mut velocity = vec![0.0_f64; p];
    let mut acceleration = vec![0.0_f64; p];
    for node in 0..nodes {
        velocity.fill(0.0);
        acceleration.fill(0.0);
        for bm in 0..m {
            let d1 = jet[[node, bm, 0]];
            let d2 = hess[[node, bm, 0, 0]];
            if d1 == 0.0 && d2 == 0.0 {
                continue;
            }
            for (j, (v, a)) in velocity
                .iter_mut()
                .zip(acceleration.iter_mut())
                .enumerate()
            {
                let coeff = decoder[[bm, j]];
                *v += d1 * coeff;
                *a += d2 * coeff;
            }
        }

        let (n1, n2, dot) = velocity.iter().zip(acceleration.iter()).fold(
            (0.0_f64, 0.0_f64, 0.0_f64),
            |(n1, n2, dot), (&v, &a)| (n1 + v * v, n2 + a * a, dot + v * a),
        );
        // A vanishing (or NaN) tangent leaves the turning rate undefined.
        if !(n1 > 0.0) {
            return Ok(None);
        }

        // |γ' ∧ γ''|² via Lagrange's identity, clamped against cancellation noise.
        let raw_wedge_sq = n1 * n2 - dot * dot;
        let roundoff_floor = 64.0 * f64::EPSILON * (n1 * n2).abs().max(dot.abs() * dot.abs());
        let wedge_sq = if raw_wedge_sq <= roundoff_floor {
            0.0
        } else {
            raw_wedge_sq
        };
        let rate = wedge_sq.sqrt() / n1;
        if !rate.is_finite() {
            return Ok(None);
        }
        integrand.push(rate);
    }

    // Composite Simpson: each cell spans three consecutive nodes.
    let theta = integrand
        .windows(3)
        .step_by(2)
        .fold(0.0_f64, |acc, cell| {
            acc + h / 3.0 * (cell[0] + 4.0 * cell[1] + cell[2])
        });
    if theta.is_finite() && theta >= 0.0 {
        Ok(Some(theta))
    } else {
        Ok(None)
    }
}
#[cfg(test)]
mod patch_flow_tests {
    use super::*;
    use ndarray::{Array2, Array3, ArrayView2};

    /// Affine mock basis `[1, t0, t1]` on a 2-D patch: constant first
    /// derivatives and identically zero higher jets.
    #[derive(Debug)]
    struct MockPatchEvaluator;

    impl SaeBasisEvaluator for MockPatchEvaluator {
        fn evaluate(
            &self,
            coords: ArrayView2<'_, f64>,
        ) -> Result<(Array2<f64>, Array3<f64>), String> {
            let rows = coords.nrows();
            let mut phi = Array2::<f64>::zeros((rows, 3));
            let mut jet = Array3::<f64>::zeros((rows, 3, 2));
            for (row, point) in coords.outer_iter().enumerate() {
                phi[[row, 0]] = 1.0;
                phi[[row, 1]] = point[0];
                phi[[row, 2]] = point[1];
                jet[[row, 1, 0]] = 1.0;
                jet[[row, 2, 1]] = 1.0;
            }
            Ok((phi, jet))
        }

        fn second_jet_dyn(
            &self,
            coords: ArrayView2<'_, f64>,
        ) -> Option<Result<ndarray::Array4<f64>, String>> {
            let cols = coords.ncols();
            if cols != 2 {
                return Some(Err(format!(
                    "MockPatchEvaluator::second_jet_dyn: expected 2 cols, got {}",
                    cols
                )));
            }
            let shape = (coords.nrows(), 3, 2, 2);
            Some(Ok(ndarray::Array4::<f64>::zeros(shape)))
        }

        fn third_jet_dyn(
            &self,
            coords: ArrayView2<'_, f64>,
        ) -> Option<Result<ndarray::Array5<f64>, String>> {
            let cols = coords.ncols();
            if cols != 2 {
                return Some(Err(format!(
                    "MockPatchEvaluator::third_jet_dyn: expected 2 cols, got {}",
                    cols
                )));
            }
            let shape = (coords.nrows(), 3, 2, 2, 2);
            Some(Ok(ndarray::Array5::<f64>::zeros(shape)))
        }
    }

    /// Decoder for the affine basis whose decoded image is `t ↦ m · t`: the
    /// row of basis function `t_k` holds column `k` of `m`.
    fn warp_decoder(m: [[f64; 2]; 2]) -> Array2<f64> {
        let mut decoder = Array2::<f64>::zeros((3, 2));
        for input_axis in 0..2 {
            for output_axis in 0..2 {
                decoder[[1 + input_axis, output_axis]] = m[output_axis][input_axis];
            }
        }
        decoder
    }

    /// A 9×9 lattice on the unit square, first axis varying slowest.
    fn patch_coords() -> Array2<f64> {
        let g = 9usize;
        Array2::from_shape_fn((g * g, 2), |(row, axis)| {
            let index = if axis == 0 { row / g } else { row % g };
            index as f64 / (g - 1) as f64
        })
    }

    /// Closed-form flat defect of a constant metric `[g00, g11, g01]` repeated
    /// over `n` rows, normalized by `sqrt(det g)`.
    fn flat_defect_of_constant_metric(g: [f64; 3], n: usize) -> f64 {
        let g_bar = (g[0] * g[1] - g[2] * g[2]).sqrt();
        let h = g.map(|entry| entry / g_bar);
        let c = 0.5 * (h[0] + h[1]);
        let r00 = h[0] - c;
        let r11 = h[1] - c;
        let r01 = h[2];
        n as f64 * (r00 * r00 + r11 * r11 + 2.0 * r01 * r01)
    }

    #[test]
    fn planted_warped_patch_recovers_uniform_speed_coords() {
        let m = [[1.6, 0.5], [0.0, 0.8]];
        let decoder = warp_decoder(m);
        let coords = patch_coords();

        // Pullback metric of the planted linear warp: dot products of the columns of `m`.
        let column_dot = |a: usize, b: usize| m[0][a] * m[0][b] + m[1][a] * m[1][b];
        let planted_metric = [column_dot(0, 0), column_dot(1, 1), column_dot(0, 1)];
        let defect_initial = flat_defect_of_constant_metric(planted_metric, coords.nrows());
        assert!(
            defect_initial > 1e-2,
            "the planted anisotropic warp must start with a sizeable defect; got {defect_initial:.3e}"
        );

        let repar = patch_isometry_flow_reparameterization(
            &MockPatchEvaluator,
            decoder.view(),
            coords.view(),
        )
        .expect("patch reparameterization must evaluate")
        .expect("a warped patch with a global flow basis must canonicalize");

        assert!(
            repar.defect_final <= 0.10 * defect_initial,
            "canonicalization must drive the anisotropy defect to within 10% of the optimum; \
             initial {defect_initial:.3e}, final {:.3e}",
            repar.defect_final
        );
        assert!(
            repar.defect_final < repar.defect_initial,
            "the pass must report a strict improvement; initial {:.3e}, final {:.3e}",
            repar.defect_initial,
            repar.defect_final
        );
        assert!(
            repar.min_flow_jacobian_det > PATCH_FLOW_DIFFEO_MIN_DET,
            "the canonical flow must be fold-free; min det {:.3e}",
            repar.min_flow_jacobian_det
        );
        assert!(
            repar.recomposition_residual <= CHART_RECOMPOSITION_REL_TOL,
            "the decoded image must be reproduced within the recomposition tolerance; got {:.3e}",
            repar.recomposition_residual
        );
    }

    #[test]
    fn already_uniform_patch_is_left_as_fitted() {
        let identity = warp_decoder([[1.0, 0.0], [0.0, 1.0]]);
        let lattice = patch_coords();
        let out = patch_isometry_flow_reparameterization(
            &MockPatchEvaluator,
            identity.view(),
            lattice.view(),
        )
        .expect("patch reparameterization must evaluate");
        assert!(
            out.is_none(),
            "an already-uniform patch chart must be left as fitted (honest skip), got Some"
        );
    }

    #[test]
    fn collapsed_patch_axis_is_refused() {
        let decoder = warp_decoder([[1.3, 0.0], [0.0, 0.9]]);
        // Pin the second coordinate so the samples lie on a single line.
        let mut collapsed = patch_coords();
        collapsed.column_mut(1).fill(0.5);
        let out = patch_isometry_flow_reparameterization(
            &MockPatchEvaluator,
            decoder.view(),
            collapsed.view(),
        )
        .expect("patch reparameterization must evaluate");
        assert!(
            out.is_none(),
            "a patch collapsed along one axis must be refused (None), got Some"
        );
    }

    #[test]
    fn free_patch_flow_basis_layout_and_jacobian_at_identity() {
        let basis = FreePatchFlowBasis::new([0.0, 0.0], [1.0, 1.0]).expect("patch basis");
        assert_eq!(basis.dim(), 4);

        let layout = basis.mode_layout();
        assert_eq!(layout.len(), 4);

        // With all coefficients zero the flow is the identity map.
        let zero_theta = vec![0.0_f64; basis.dim()];
        let det = basis.min_jacobian_det_on_grid(&zero_theta);
        assert!(
            (det - 1.0).abs() < 1e-12,
            "identity flow has det Dφ ≡ 1; got {det}"
        );

        assert_eq!(layout[0].component, 0);
        assert_eq!(layout[0].exps, (1, 0));
        assert_eq!(layout[1].exps, (0, 1));
    }

    #[test]
    fn free_patch_mode_gradients_match_finite_difference() {
        let basis = FreePatchFlowBasis::new([-0.5, 0.2], [1.5, 2.2]).expect("patch basis");
        let eps = 1e-6;
        let t = [0.3, 1.1];
        let at_t = basis.mode_samples(t);
        let shifted_t0 = basis.mode_samples([t[0] + eps, t[1]]);
        let shifted_t1 = basis.mode_samples([t[0], t[1] + eps]);
        for k in 0..at_t.len() {
            // Forward differences along each latent axis.
            let fd0 = (shifted_t0[k].value - at_t[k].value) / eps;
            let fd1 = (shifted_t1[k].value - at_t[k].value) / eps;
            assert!(
                (fd0 - at_t[k].grad[0]).abs() < 1e-4,
                "mode {k} ∂/∂t₀ FD {fd0} vs analytic {}",
                at_t[k].grad[0]
            );
            assert!(
                (fd1 - at_t[k].grad[1]).abs() < 1e-4,
                "mode {k} ∂/∂t₁ FD {fd1} vs analytic {}",
                at_t[k].grad[1]
            );
        }
    }
}
#[cfg(test)]
mod sphere_defect_tests {
    use super::*;
    use ndarray::{Array2, Array3, ArrayView2};

    /// Two-function mock whose first-derivative jet is `diag(1, warp · cos(lat))`;
    /// with an identity decoder the pullback metric is `diag(1, warp² cos²(lat))`.
    /// It provides no second or third jet.
    #[derive(Debug)]
    struct MockSphereEvaluator {
        warp: f64,
    }

    impl SaeBasisEvaluator for MockSphereEvaluator {
        fn evaluate(
            &self,
            coords: ArrayView2<'_, f64>,
        ) -> Result<(Array2<f64>, Array3<f64>), String> {
            let rows = coords.nrows();
            let phi = Array2::<f64>::zeros((rows, 2));
            let mut jet = Array3::<f64>::zeros((rows, 2, 2));
            for row in 0..rows {
                jet[[row, 0, 0]] = 1.0;
                jet[[row, 1, 1]] = self.warp * coords[[row, 0]].cos();
            }
            Ok((phi, jet))
        }

        fn second_jet_dyn(
            &self,
            coords: ArrayView2<'_, f64>,
        ) -> Option<Result<ndarray::Array4<f64>, String>> {
            match coords.ncols() {
                2 => None,
                cols => Some(Err(format!(
                    "MockSphereEvaluator::second_jet_dyn: expected (lat, lon) coords, got {} cols",
                    cols
                ))),
            }
        }

        fn third_jet_dyn(
            &self,
            coords: ArrayView2<'_, f64>,
        ) -> Option<Result<ndarray::Array5<f64>, String>> {
            match coords.ncols() {
                2 => None,
                cols => Some(Err(format!(
                    "MockSphereEvaluator::third_jet_dyn: expected (lat, lon) coords, got {} cols",
                    cols
                ))),
            }
        }
    }

    /// `(lat, lon)` rows with the given latitudes and longitudes `0.1 · row`.
    fn coords(lats: &[f64]) -> Array2<f64> {
        Array2::from_shape_fn((lats.len(), 2), |(i, axis)| {
            if axis == 0 {
                lats[i]
            } else {
                0.1 * i as f64
            }
        })
    }

    /// Defect of the mock chart with the given `warp` on a fixed set of
    /// mid-latitude rows, using an identity decoder.
    fn defect_on_mid_latitudes(warp: f64) -> Result<Option<f64>, String> {
        let evaluator = MockSphereEvaluator { warp };
        let decoder = Array2::<f64>::eye(2);
        let rows = coords(&[-0.6, -0.2, 0.0, 0.3, 0.7]);
        sphere_chart_isometry_defect(&evaluator, decoder.view(), rows.view())
    }

    #[test]
    fn round_isometric_chart_has_zero_defect() {
        let defect = defect_on_mid_latitudes(1.0)
            .expect("defect must evaluate")
            .expect("non-degenerate round chart must return Some");
        assert!(
            defect < 1e-10,
            "a chart whose pullback metric is exactly diag(1, cos²lat) is round-isometric; \
             defect should be ~0, got {defect:.3e}"
        );
    }

    #[test]
    fn warped_chart_has_large_defect() {
        let defect = defect_on_mid_latitudes(2.5)
            .expect("defect must evaluate")
            .expect("non-degenerate warped chart must return Some");
        assert!(
            defect > 1e-2,
            "an anisotropically warped chart must register a sizeable defect, got {defect:.3e}"
        );
    }

    #[test]
    fn pole_singularity_is_refused_not_fabricated() {
        let evaluator = MockSphereEvaluator { warp: 1.0 };
        let decoder = Array2::<f64>::eye(2);
        // Two regular rows followed by one row sitting exactly on the pole.
        let regular = coords(&[0.0, 0.3]);
        let mut with_pole = Array2::<f64>::zeros((3, 2));
        with_pole
            .slice_mut(ndarray::s![0..2, ..])
            .assign(&regular);
        with_pole[[2, 0]] = std::f64::consts::FRAC_PI_2;
        let out = sphere_chart_isometry_defect(&evaluator, decoder.view(), with_pole.view())
            .expect("defect must evaluate");
        assert!(
            out.is_none(),
            "a pole-singular chart row must be refused, got {out:?}"
        );
    }

    #[test]
    fn sphere_boost_mode_jacobians_match_finite_difference() {
        let eps = 1e-6;
        for t in [[0.3, 0.7], [-0.6, -1.2]] {
            let jac = SphereBoostFlowBasis::mode_jacobians(t);
            // Displacements at the four central-difference stencil points.
            let lat_plus = SphereBoostFlowBasis::mode_displacements([t[0] + eps, t[1]]);
            let lat_minus = SphereBoostFlowBasis::mode_displacements([t[0] - eps, t[1]]);
            let lon_plus = SphereBoostFlowBasis::mode_displacements([t[0], t[1] + eps]);
            let lon_minus = SphereBoostFlowBasis::mode_displacements([t[0], t[1] - eps]);
            for k in 0..3 {
                for comp in 0..2 {
                    let fd_dlat = (lat_plus[k][comp] - lat_minus[k][comp]) / (2.0 * eps);
                    let fd_dlon = (lon_plus[k][comp] - lon_minus[k][comp]) / (2.0 * eps);
                    assert!(
                        (fd_dlat - jac[k][comp][0]).abs() < 1e-5,
                        "boost {k} comp {comp} ∂/∂lat FD {fd_dlat} vs analytic {} at {t:?}",
                        jac[k][comp][0]
                    );
                    assert!(
                        (fd_dlon - jac[k][comp][1]).abs() < 1e-5,
                        "boost {k} comp {comp} ∂/∂lon FD {fd_dlon} vs analytic {} at {t:?}",
                        jac[k][comp][1]
                    );
                }
            }
        }
    }

    #[test]
    fn sphere_boost_layout_and_zonal_is_pole_free() {
        let basis = SphereBoostFlowBasis;
        assert_eq!(basis.dim(), 3);
        assert_eq!(
            basis.mode_layout(),
            [SphereBoostAxis::Z, SphereBoostAxis::X, SphereBoostAxis::Y]
        );
        // The first (Z) mode must stay finite in latitude and never move longitude,
        // even close to the poles.
        for lat in [-1.4, -0.3, 0.0, 0.9, 1.4] {
            let disp = SphereBoostFlowBasis::mode_displacements([lat, 0.5]);
            assert!(disp[0][0].is_finite() && disp[0][1] == 0.0);
        }
    }
}
#[cfg(test)]
mod turning_tests {
    use super::*;
    use crate::terms::sae::basis::PeriodicHarmonicEvaluator;
    use ndarray::{Array1, Array2};
    use std::f64::consts::TAU;

    /// A `(3, 2)` decoder that is zero except for the listed `(row, col, value)` entries.
    fn sparse_decoder(entries: &[(usize, usize, f64)]) -> Array2<f64> {
        let mut decoder = Array2::<f64>::zeros((3, 2));
        for &(row, col, value) in entries {
            decoder[[row, col]] = value;
        }
        decoder
    }

    /// `steps + 1` evenly spaced coordinates `start + span · i / steps`.
    fn even_coords(start: f64, span: f64, steps: usize) -> Array1<f64> {
        Array1::from_iter((0..=steps).map(|i| start + span * i as f64 / steps as f64))
    }

    /// Decoder used by the circle tests: basis row 2 feeds output 0 and
    /// basis row 1 feeds output 1.
    fn circle_decoder() -> Array2<f64> {
        sparse_decoder(&[(2, 0, 1.0), (1, 1, 1.0)])
    }

    #[test]
    fn full_circle_turning_is_two_pi() {
        let evaluator = PeriodicHarmonicEvaluator::new(3).expect("3-basis circle");
        let decoder = circle_decoder();
        let coords = even_coords(0.0, 1.0, 50);
        let theta = d1_atom_fitted_turning(&evaluator, decoder.view(), coords.view())
            .expect("turning must evaluate")
            .expect("a non-degenerate circle must return Some");
        assert!(
            (theta - TAU).abs() < 1e-6,
            "a full unit circle has total turning 2π; got {theta:.9}"
        );
    }

    #[test]
    fn half_circle_turning_is_pi() {
        let evaluator = PeriodicHarmonicEvaluator::new(3).expect("3-basis circle");
        let decoder = circle_decoder();
        let coords = even_coords(0.0, 0.5, 25);
        let theta = d1_atom_fitted_turning(&evaluator, decoder.view(), coords.view())
            .expect("turning must evaluate")
            .expect("a non-degenerate half-circle must return Some");
        assert!(
            (theta - std::f64::consts::PI).abs() < 1e-6,
            "a half circle turns through π; got {theta:.9}"
        );
    }

    #[test]
    fn straight_line_image_has_zero_turning() {
        let evaluator = PeriodicHarmonicEvaluator::new(3).expect("3-basis circle");
        // A single active basis row: both outputs are multiples of the same function.
        let decoder = sparse_decoder(&[(2, 0, 1.0), (2, 1, 2.0)]);
        let coords = even_coords(0.05, 0.15, 20);
        let theta = d1_atom_fitted_turning(&evaluator, decoder.view(), coords.view())
            .expect("turning must evaluate")
            .expect("a non-degenerate segment must return Some");
        assert!(
            theta < 1e-9,
            "a straight-line image has zero turning (the linear-tail signature); got {theta:.3e}"
        );
    }
}