use faer::Side as FaerSide;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::linalg::faer_ndarray::{FaerCholesky, fast_ab, fast_ata, fast_atb};
use crate::terms::sae_manifold::{SaeBasisEvaluator, solve_design_least_squares};
/// Number of equal-width cells in the quadrature grid that
/// `unit_speed_reparameterization` lays over the chart domain. The evaluator is
/// sampled at `2 * ARC_LENGTH_GRID_CELLS + 1` points (cell edges plus midpoints).
pub const ARC_LENGTH_GRID_CELLS: usize = 2048;
/// Largest relative mismatch (max absolute difference divided by the largest
/// absolute fit value) that `recompose_decoder_exact_ls` still accepts.
pub const CHART_RECOMPOSITION_REL_TOL: f64 = 1.0e-9;
/// Topology of the one-dimensional chart handled by `unit_speed_reparameterization`.
#[derive(Debug, Clone, PartialEq)]
pub enum CanonicalChartTopology {
/// Periodic chart. Coordinates are wrapped with `rem_euclid(period)` and the
/// quadrature domain is `[0, period]`; `period` must be finite and positive.
Circle { period: f64 },
/// Non-periodic chart. The quadrature domain is `[min, max]` of the supplied
/// row coordinates and the new coordinates are clamped to `[0, 1]`.
Interval,
}
/// Result of `unit_speed_reparameterization`.
#[derive(Debug, Clone)]
pub struct UnitSpeedReparameterization {
/// Input row coordinates pushed through the rescaled cumulative arc-length map.
pub new_row_coords: Array1<f64>,
/// Decoder after recomposition (`decoder_transport` times the input decoder).
pub new_decoder: Array2<f64>,
/// Least-squares transport matrix produced by `recompose_decoder_exact_ls`.
pub decoder_transport: Array2<f64>,
/// Simpson-rule estimate of the curve length over the whole quadrature domain.
pub total_arc_length: f64,
/// Relative mismatch between the old and recomposed fits on the grid nodes.
pub recomposition_residual: f64,
}
/// Computes the speed of the decoded curve at every row of a 1-D latent jet.
///
/// For each row the tangent vector is `sum_b jet[row, b, 0] * decoder[b, ..]`
/// and the returned speed is its Euclidean norm. Basis functions whose
/// derivative is exactly zero are skipped.
///
/// # Errors
/// Returns `Err` when the jet's latent dimension is not 1 or when the jet's
/// basis width differs from the number of decoder rows.
fn curve_speeds(
    jet: &ndarray::Array3<f64>,
    decoder: ArrayView2<'_, f64>,
) -> Result<Vec<f64>, String> {
    let (n_rows, m, d) = jet.dim();
    if d != 1 {
        return Err(format!(
            "sae_chart_canonicalization: expected a 1-D latent jet, got latent_dim {d}"
        ));
    }
    if decoder.nrows() != m {
        return Err(format!(
            "sae_chart_canonicalization: jet basis width {m} != decoder rows {}",
            decoder.nrows()
        ));
    }
    // One scratch tangent buffer, reset and reused for every row.
    let mut tangent = vec![0.0_f64; decoder.ncols()];
    let speeds: Vec<f64> = (0..n_rows)
        .map(|row| {
            tangent.fill(0.0);
            for bm in 0..m {
                let dphi = jet[[row, bm, 0]];
                if dphi != 0.0 {
                    let weights = decoder.row(bm);
                    for (slot, &weight) in tangent.iter_mut().zip(weights.iter()) {
                        *slot += dphi * weight;
                    }
                }
            }
            tangent.iter().map(|v| v * v).sum::<f64>().sqrt()
        })
        .collect();
    Ok(speeds)
}
/// Integrates, from `0` to `x`, the quadratic that interpolates the speed
/// samples `f0` at `0`, `fm` at `h / 2` and `f1` at `h`.
///
/// At `x == h` this reproduces the Simpson cell value `h * (f0 + 4 fm + f1) / 6`.
/// Returns `0.0` when the cell width `h` is zero or negative.
fn partial_cell_arc(f0: f64, fm: f64, f1: f64, h: f64, x: f64) -> f64 {
    if h <= 0.0 {
        0.0
    } else {
        // Interpolant: f(u) = f0 + linear * u + quadratic * u^2.
        let quadratic = (2.0 * f0 - 4.0 * fm + 2.0 * f1) / (h * h);
        let linear = (-3.0 * f0 + 4.0 * fm - f1) / h;
        let x_sq = x * x;
        quadratic * x_sq * x / 3.0 + linear * x_sq / 2.0 + f0 * x
    }
}
/// Reparameterizes a one-dimensional chart so that the decoded curve is
/// traversed at constant speed, then recomposes the decoder for the new
/// coordinates.
///
/// The curve speed is sampled on a uniform grid of `ARC_LENGTH_GRID_CELLS`
/// cells (edges and midpoints), integrated cell by cell with Simpson's rule,
/// and the cumulative arc length is rescaled so the whole domain maps onto
/// `[0, period)` for `Circle` or `[0, 1]` for `Interval`.
///
/// # Returns
/// * `Ok(Some(..))` with the new coordinates, decoder and diagnostics.
/// * `Ok(None)` when the problem is degenerate: empty inputs, non-finite
///   coordinates, a collapsed interval, non-finite speeds, a zero or non-finite
///   total length, or a recomposition that is rejected.
///
/// # Errors
/// Returns `Err` for an invalid circle period, when the evaluator fails, or
/// when the evaluator output does not have the expected shape.
pub fn unit_speed_reparameterization(
    evaluator: &dyn SaeBasisEvaluator,
    decoder: ArrayView2<'_, f64>,
    row_coords: ArrayView1<'_, f64>,
    topology: &CanonicalChartTopology,
) -> Result<Option<UnitSpeedReparameterization>, String> {
    let n = row_coords.len();
    let m = decoder.nrows();
    let p = decoder.ncols();
    if n == 0 || m == 0 || p == 0 {
        return Ok(None);
    }
    if row_coords.iter().any(|t| !t.is_finite()) {
        return Ok(None);
    }
    // `lo..hi` is the old-coordinate domain; `span` is the length of the new one.
    let (lo, hi, span) = match topology {
        CanonicalChartTopology::Circle { period } => {
            if !(period.is_finite() && *period > 0.0) {
                return Err(format!(
                    "sae_chart_canonicalization: circle period must be finite and positive; got {period}"
                ));
            }
            (0.0, *period, *period)
        }
        CanonicalChartTopology::Interval => {
            let mut t_min = f64::INFINITY;
            let mut t_max = f64::NEG_INFINITY;
            for &t in row_coords.iter() {
                t_min = t_min.min(t);
                t_max = t_max.max(t);
            }
            // Reject intervals that are collapsed relative to their magnitude.
            let scale = t_min.abs().max(t_max.abs()).max(1.0);
            if !(t_max - t_min > 1.0e-12 * scale) {
                return Ok(None);
            }
            (t_min, t_max, 1.0)
        }
    };
    let cells = ARC_LENGTH_GRID_CELLS;
    let h = (hi - lo) / cells as f64;
    // Even rows hold cell edges, odd rows hold cell midpoints.
    let quad_rows = 2 * cells + 1;
    let mut quad_coords = Array2::<f64>::zeros((quad_rows, 1));
    for j in 0..=cells {
        quad_coords[[2 * j, 0]] = lo + j as f64 * h;
        if j < cells {
            quad_coords[[2 * j + 1, 0]] = lo + (j as f64 + 0.5) * h;
        }
    }
    let (grid_phi_all, grid_jet_all) = evaluator.evaluate(quad_coords.view())?;
    if grid_phi_all.ncols() != m {
        return Err(format!(
            "sae_chart_canonicalization: evaluator basis width {} != decoder rows {m}",
            grid_phi_all.ncols()
        ));
    }
    // The indexing below assumes one output row per quadrature point; report a
    // short or long evaluator result as an error instead of panicking.
    if grid_phi_all.nrows() != quad_rows || grid_jet_all.dim().0 != quad_rows {
        return Err(format!(
            "sae_chart_canonicalization: evaluator returned basis {:?} / jet {:?} on the arc-length grid; expected {quad_rows} rows",
            grid_phi_all.dim(),
            grid_jet_all.dim()
        ));
    }
    let speeds = curve_speeds(&grid_jet_all, decoder)?;
    if speeds.iter().any(|s| !s.is_finite()) {
        return Ok(None);
    }
    // Simpson's rule per cell, accumulated into arc length at each cell edge.
    let mut cumulative = vec![0.0_f64; cells + 1];
    for j in 0..cells {
        let f0 = speeds[2 * j];
        let fm = speeds[2 * j + 1];
        let f1 = speeds[2 * j + 2];
        cumulative[j + 1] = cumulative[j] + h * (f0 + 4.0 * fm + f1) / 6.0;
    }
    let total = cumulative[cells];
    if !(total.is_finite() && total > 0.0) {
        return Ok(None);
    }
    let rescale = span / total;
    // Maps an old coordinate to its rescaled arc-length coordinate.
    let map_coord = |t: f64| -> f64 {
        let local = match topology {
            CanonicalChartTopology::Circle { period } => (t - lo).rem_euclid(*period),
            CanonicalChartTopology::Interval => (t - lo).clamp(0.0, hi - lo),
        };
        // Clamp the cell index so the right end of the domain lands in the last cell.
        let cell = ((local / h).floor() as usize).min(cells - 1);
        let x = local - cell as f64 * h;
        let s = cumulative[cell]
            + partial_cell_arc(
                speeds[2 * cell],
                speeds[2 * cell + 1],
                speeds[2 * cell + 2],
                h,
                x,
            );
        let mapped = rescale * s;
        match topology {
            CanonicalChartTopology::Circle { period } => mapped.rem_euclid(*period),
            CanonicalChartTopology::Interval => mapped.clamp(0.0, span),
        }
    };
    let new_row_coords = Array1::from_iter(row_coords.iter().map(|&t| map_coord(t)));
    // Recomposition uses the cell edges only: old basis values there, paired
    // with the new coordinates of the same points.
    let mut node_new_coords = Array2::<f64>::zeros((cells + 1, 1));
    let mut old_phi = Array2::<f64>::zeros((cells + 1, m));
    for j in 0..=cells {
        node_new_coords[[j, 0]] = map_coord(lo + j as f64 * h);
        for bm in 0..m {
            old_phi[[j, bm]] = grid_phi_all[[2 * j, bm]];
        }
    }
    let Some(recomposition) =
        recompose_decoder_exact_ls(evaluator, decoder, old_phi.view(), node_new_coords.view())?
    else {
        return Ok(None);
    };
    Ok(Some(UnitSpeedReparameterization {
        new_row_coords,
        new_decoder: recomposition.new_decoder,
        decoder_transport: recomposition.transport,
        total_arc_length: total,
        recomposition_residual: recomposition.recomposition_residual,
    }))
}
/// Output of `recompose_decoder_exact_ls`.
pub(crate) struct DecoderRecomposition {
/// Least-squares solution returned by `solve_design_least_squares` for the
/// new basis values against the old basis values.
pub transport: Array2<f64>,
/// `transport` multiplied by the input decoder.
pub new_decoder: Array2<f64>,
/// Max absolute difference between old and new fits, divided by the largest
/// absolute fit value.
pub recomposition_residual: f64,
}
/// Re-expresses `decoder` for a reparameterized chart and audits the result.
///
/// The basis is evaluated at `new_coords`, a transport matrix is obtained from
/// `solve_design_least_squares(new_phi, old_phi)`, and the new decoder is
/// `transport * decoder`. The old fit `old_phi * decoder` and the new fit
/// `new_phi * new_decoder` are then compared entry by entry.
///
/// # Returns
/// * `Ok(Some(..))` when the relative mismatch is at most
///   `CHART_RECOMPOSITION_REL_TOL`.
/// * `Ok(None)` when the mismatch is larger, when the fits are identically
///   zero, or when either fit contains a non-finite entry.
///
/// # Errors
/// Returns `Err` when the evaluator or the least-squares solve fails, or when
/// the evaluator output does not match the expected shapes.
pub(crate) fn recompose_decoder_exact_ls(
    evaluator: &dyn SaeBasisEvaluator,
    decoder: ArrayView2<'_, f64>,
    old_phi: ArrayView2<'_, f64>,
    new_coords: ArrayView2<'_, f64>,
) -> Result<Option<DecoderRecomposition>, String> {
    let m = decoder.nrows();
    let (new_phi, new_jet) = evaluator.evaluate(new_coords)?;
    if new_phi.ncols() != m
        || new_phi.nrows() != old_phi.nrows()
        || new_jet.dim() != (new_coords.nrows(), m, new_coords.ncols())
    {
        return Err(format!(
            "sae_chart_canonicalization: evaluator returned basis {:?} / jet {:?} at the canonical grid; expected ({}, {m}) with latent_dim {}",
            new_phi.dim(),
            new_jet.dim(),
            old_phi.nrows(),
            new_coords.ncols()
        ));
    }
    let transport = solve_design_least_squares(new_phi.view(), old_phi)?;
    let new_decoder = fast_ab(&transport, &decoder);
    let old_fit = fast_ab(&old_phi, &decoder);
    let new_fit = fast_ab(&new_phi, &new_decoder);
    let mut fit_scale = 0.0_f64;
    let mut max_abs = 0.0_f64;
    for (a, b) in old_fit.iter().zip(new_fit.iter()) {
        // `f64::max` returns the other operand when one is NaN, so a NaN fit
        // entry would otherwise vanish from both running maxima and the
        // recomposition would be accepted with a residual of zero.
        if !(a.is_finite() && b.is_finite()) {
            return Ok(None);
        }
        fit_scale = fit_scale.max(a.abs()).max(b.abs());
        max_abs = max_abs.max((a - b).abs());
    }
    if !(fit_scale.is_finite() && fit_scale > 0.0 && max_abs.is_finite()) {
        return Ok(None);
    }
    let recomposition_residual = max_abs / fit_scale;
    if recomposition_residual > CHART_RECOMPOSITION_REL_TOL {
        return Ok(None);
    }
    Ok(Some(DecoderRecomposition {
        transport,
        new_decoder,
        recomposition_residual,
    }))
}
/// Largest absolute frequency index per axis enumerated by `TorusFlowBasis::new`.
pub const TORUS_FLOW_MAX_HARMONIC: i32 = 2;
/// Jacobian-determinant floor: a flow whose minimum determinant on the guard
/// grid is not strictly above this value is rejected.
pub const TORUS_FLOW_DIFFEO_MIN_DET: f64 = 0.1;
/// Nodes per axis of the grid scanned by `TorusFlowBasis::min_jacobian_det_on_grid`.
pub const TORUS_FLOW_GUARD_NODES_PER_AXIS: usize = 64;
/// Upper bound of the outer iteration range in
/// `torus_isometry_flow_reparameterization`.
pub const TORUS_FLOW_GN_MAX_ITERS: usize = 80;
/// Number of rejected damped steps tolerated within one outer iteration.
pub const TORUS_FLOW_GN_MAX_REJECTS: usize = 12;
/// Minimum nodes per axis of the grid used to recompose the decoder after the flow.
pub const TORUS_TRANSPORT_MIN_NODES_PER_AXIS: usize = 48;
/// Identifies one coefficient slot of a `TorusFlowBasis`, in the order produced
/// by `TorusFlowBasis::mode_layout`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TorusFlowModeKey {
/// Output coordinate (0 or 1) that the mode displaces.
pub component: usize,
/// Integer frequency pair `(a, b)` of the mode.
pub freq: (i32, i32),
/// `true` for the cosine mode, `false` for the sine mode.
pub is_cos: bool,
}
/// Value and gradient of one flow mode at a single point, as produced by
/// `TorusFlowBasis::mode_samples`.
#[derive(Debug, Clone, Copy)]
pub struct TorusFlowModeSample {
/// Output coordinate (0 or 1) that the mode displaces.
pub component: usize,
/// Mode value (a sine or cosine) at the sample point.
pub value: f64,
/// Partial derivatives of `value` with respect to the two input coordinates.
pub grad: [f64; 2],
}
/// Trigonometric displacement basis on a square torus of side `period`.
#[derive(Debug, Clone)]
pub struct TorusFlowBasis {
/// Period shared by both coordinates; validated as finite and positive in `new`.
pub period: f64,
/// Frequency pairs `(a, b)` enumerated by `new`.
freqs: Vec<(i32, i32)>,
}
impl TorusFlowBasis {
    /// Builds the basis for a torus of side `period`.
    ///
    /// Frequencies `(a, b)` range over `-H..=H` per axis with
    /// `H = TORUS_FLOW_MAX_HARMONIC`, keeping only the half-plane
    /// `a > 0 || (a == 0 && b > 0)`; this drops `(0, 0)` and one of each
    /// `±(a, b)` pair.
    ///
    /// # Errors
    /// Returns `Err` when `period` is not finite and strictly positive.
    pub fn new(period: f64) -> Result<Self, String> {
        if !(period.is_finite() && period > 0.0) {
            return Err(format!(
                "TorusFlowBasis: period must be finite and positive; got {period}"
            ));
        }
        let h = TORUS_FLOW_MAX_HARMONIC;
        let freqs = (-h..=h)
            .flat_map(|a| (-h..=h).map(move |b| (a, b)))
            .filter(|&(a, b)| a > 0 || (a == 0 && b > 0))
            .collect();
        Ok(Self { period, freqs })
    }

    /// Number of coefficients: two components times sine and cosine per frequency.
    pub fn dim(&self) -> usize {
        4 * self.freqs.len()
    }

    /// Lists the coefficient slots in storage order: component-major, then
    /// frequency, with the sine slot before the cosine slot.
    pub fn mode_layout(&self) -> Vec<TorusFlowModeKey> {
        let mut keys = Vec::with_capacity(self.dim());
        for component in 0..2 {
            for &freq in &self.freqs {
                for &is_cos in &[false, true] {
                    keys.push(TorusFlowModeKey {
                        component,
                        freq,
                        is_cos,
                    });
                }
            }
        }
        keys
    }

    /// Evaluates every mode and its gradient at `t`, in `mode_layout` order.
    pub fn mode_samples(&self, t: [f64; 2]) -> Vec<TorusFlowModeSample> {
        let tau = std::f64::consts::TAU;
        let mut out = Vec::with_capacity(self.dim());
        for component in 0..2 {
            for &(a, b) in &self.freqs {
                let w0 = tau * a as f64 / self.period;
                let w1 = tau * b as f64 / self.period;
                let angle = w0 * t[0] + w1 * t[1];
                let s = angle.sin();
                let c = angle.cos();
                // Sine mode first, then cosine, matching `mode_layout`.
                out.push(TorusFlowModeSample {
                    component,
                    value: s,
                    grad: [w0 * c, w1 * c],
                });
                out.push(TorusFlowModeSample {
                    component,
                    value: c,
                    grad: [-w0 * s, -w1 * s],
                });
            }
        }
        out
    }

    /// Applies the flow `t + sum_k theta[k] * mode_k(t)` and wraps both
    /// coordinates with `rem_euclid(period)`.
    ///
    /// # Panics
    /// Panics when `theta.len() != self.dim()`.
    pub fn map_point(&self, theta: &[f64], t: [f64; 2]) -> [f64; 2] {
        assert_eq!(theta.len(), self.dim(), "TorusFlowBasis: theta length");
        let mut out = t;
        for (coef, sample) in theta.iter().zip(self.mode_samples(t)) {
            out[sample.component] += coef * sample.value;
        }
        [
            out[0].rem_euclid(self.period),
            out[1].rem_euclid(self.period),
        ]
    }

    /// Jacobian of the unwrapped flow at `t`: identity plus the weighted mode
    /// gradients, indexed as `[component][input axis]`.
    ///
    /// # Panics
    /// Panics when `theta.len() != self.dim()`.
    pub fn flow_jacobian(&self, theta: &[f64], t: [f64; 2]) -> [[f64; 2]; 2] {
        assert_eq!(theta.len(), self.dim(), "TorusFlowBasis: theta length");
        let mut jac = [[1.0, 0.0], [0.0, 1.0]];
        for (coef, sample) in theta.iter().zip(self.mode_samples(t)) {
            jac[sample.component][0] += coef * sample.grad[0];
            jac[sample.component][1] += coef * sample.grad[1];
        }
        jac
    }

    /// Smallest Jacobian determinant over a uniform
    /// `TORUS_FLOW_GUARD_NODES_PER_AXIS`-per-axis grid covering one period.
    ///
    /// Returns NaN as soon as any node yields a NaN determinant. `f64::min`
    /// would otherwise drop NaN operands, so an all-NaN scan would report
    /// `+inf`. Callers should test the result with `>` so that NaN fails.
    ///
    /// # Panics
    /// Panics when `theta.len() != self.dim()`.
    pub fn min_jacobian_det_on_grid(&self, theta: &[f64]) -> f64 {
        let nodes = TORUS_FLOW_GUARD_NODES_PER_AXIS;
        let mut min_det = f64::INFINITY;
        for i in 0..nodes {
            for j in 0..nodes {
                let t = [
                    self.period * i as f64 / nodes as f64,
                    self.period * j as f64 / nodes as f64,
                ];
                let jac = self.flow_jacobian(theta, t);
                let det = jac[0][0] * jac[1][1] - jac[0][1] * jac[1][0];
                if det.is_nan() {
                    return f64::NAN;
                }
                min_det = min_det.min(det);
            }
        }
        min_det
    }
}
/// Result of `torus_isometry_flow_reparameterization`.
#[derive(Debug, Clone)]
pub struct TorusIsometryFlowReparameterization {
/// Input row coordinates mapped through the fitted flow (wrapped to one period).
pub new_row_coords: Array2<f64>,
/// Decoder after recomposition (`decoder_transport` times the input decoder).
pub new_decoder: Array2<f64>,
/// Least-squares transport matrix produced by `recompose_decoder_exact_ls`.
pub decoder_transport: Array2<f64>,
/// Fitted flow coefficients, in `TorusFlowBasis::mode_layout` order.
pub flow_theta: Vec<f64>,
/// Defect of the identity flow (all coefficients zero).
pub defect_initial: f64,
/// Defect of the accepted flow; strictly smaller than `defect_initial`.
pub defect_final: f64,
/// Profiled scale factor of the accepted flow.
pub profiled_metric_scale: f64,
/// Minimum flow Jacobian determinant on the guard grid for the accepted flow.
pub min_flow_jacobian_det: f64,
/// Relative mismatch between the old and recomposed fits on the transport grid.
pub recomposition_residual: f64,
}
/// Objective evaluation returned by `evaluate_flow_defect` for one `theta`.
struct FlowObjectiveState {
// Sum over rows of the squared mismatch between the flow Gram matrix and the
// scaled normalized metric (off-diagonal entry counted twice).
defect: f64,
// Least-squares optimal scale applied to the normalized metric.
scale: f64,
// Per-row flow Jacobian flattened as [A00, A01, A10, A11].
a_rows: Vec<[f64; 4]>,
}
/// Evaluates the isometry defect of the flow with coefficients `theta`.
///
/// For every row the flow Jacobian `A` (identity plus weighted mode gradients)
/// is assembled and its Gram matrix `AᵀA` is compared with `scale * ghat`,
/// where `scale` is the least-squares optimum `<AᵀA, ghat> / ghat_norm_sq`.
/// Each `ghat` row is stored as `[g00, g11, g01]`.
///
/// Returns `None` when the profiled scale is not finite and strictly positive,
/// or when the accumulated defect is not finite.
fn evaluate_flow_defect(
    theta: &[f64],
    row_modes: &[Vec<TorusFlowModeSample>],
    ghat: &[[f64; 3]],
    ghat_norm_sq: f64,
) -> Option<FlowObjectiveState> {
    // Gram entries (m00, m11, m01) of AᵀA for A flattened as [A00, A01, A10, A11].
    let gram = |a: &[f64; 4]| -> (f64, f64, f64) {
        (
            a[0] * a[0] + a[2] * a[2],
            a[1] * a[1] + a[3] * a[3],
            a[0] * a[1] + a[2] * a[3],
        )
    };

    let a_rows: Vec<[f64; 4]> = row_modes
        .iter()
        .map(|modes| {
            let mut jac = [1.0_f64, 0.0, 0.0, 1.0];
            for (coef, sample) in theta.iter().zip(modes.iter()) {
                let base = 2 * sample.component;
                jac[base] += coef * sample.grad[0];
                jac[base + 1] += coef * sample.grad[1];
            }
            jac
        })
        .collect();

    // Frobenius inner product of the Gram matrices with the normalized metric.
    let cross = a_rows
        .iter()
        .zip(ghat.iter())
        .fold(0.0_f64, |acc, (a, g)| {
            let (m00, m11, m01) = gram(a);
            acc + (m00 * g[0] + m11 * g[1] + 2.0 * m01 * g[2])
        });
    let scale = cross / ghat_norm_sq;
    if !(scale.is_finite() && scale > 0.0) {
        return None;
    }

    let defect = a_rows
        .iter()
        .zip(ghat.iter())
        .fold(0.0_f64, |acc, (a, g)| {
            let (m00, m11, m01) = gram(a);
            let r00 = m00 - scale * g[0];
            let r11 = m11 - scale * g[1];
            let r01 = m01 - scale * g[2];
            acc + (r00 * r00 + r11 * r11 + 2.0 * r01 * r01)
        });
    if defect.is_finite() {
        Some(FlowObjectiveState {
            defect,
            scale,
            a_rows,
        })
    } else {
        None
    }
}
/// Fits a trigonometric flow of the 2-D torus chart that reduces the mismatch
/// between the flow's Gram matrix and the (normalized, rescaled) pulled-back
/// decoder metric, then recomposes the decoder for the new coordinates.
///
/// Steps, in order:
/// 1. Evaluate the basis jet at the rows and build the per-row metric
///    `[g00, g11, g01]` from the decoded tangents.
/// 2. Normalize the metric by the geometric mean of `sqrt(det g)`.
/// 3. Minimize the defect from `evaluate_flow_defect` with damped
///    Gauss-Newton steps; a step is accepted only if it strictly lowers the
///    defect and keeps the flow Jacobian determinant above
///    `TORUS_FLOW_DIFFEO_MIN_DET` on the guard grid.
/// 4. Recompose the decoder on a uniform transport grid with
///    `recompose_decoder_exact_ls` and map the row coordinates.
///
/// # Returns
/// * `Ok(Some(..))` with the mapped coordinates, new decoder and diagnostics.
/// * `Ok(None)` when inputs are empty or non-finite, the metric is degenerate,
///   the initial defect is not positive, no step is accepted, the final flow
///   fails the determinant guard, or the recomposition is rejected.
///
/// # Errors
/// Returns `Err` when `row_coords` does not have two columns, when `period` is
/// rejected by `TorusFlowBasis::new`, when the evaluator fails, or when the
/// evaluator output has an unexpected shape.
pub fn torus_isometry_flow_reparameterization(
evaluator: &dyn SaeBasisEvaluator,
decoder: ArrayView2<'_, f64>,
row_coords: ArrayView2<'_, f64>,
period: f64,
) -> Result<Option<TorusIsometryFlowReparameterization>, String> {
let n = row_coords.nrows();
let m = decoder.nrows();
let p = decoder.ncols();
if row_coords.ncols() != 2 {
return Err(format!(
"torus_isometry_flow_reparameterization: expected (n, 2) row coordinates; got {:?}",
row_coords.dim()
));
}
if n == 0 || m == 0 || p == 0 {
return Ok(None);
}
for &t in row_coords.iter() {
if !t.is_finite() {
return Ok(None);
}
}
let (row_phi, row_jet) = evaluator.evaluate(row_coords)?;
if row_phi.ncols() != m || row_jet.dim() != (n, m, 2) {
return Err(format!(
"torus_isometry_flow_reparameterization: evaluator returned basis {:?} / jet {:?}; expected width {m}, latent_dim 2",
row_phi.dim(),
row_jet.dim()
));
}
// Per-row metric entries [g00, g11, g01] built from the two decoded tangents.
let mut g_rows: Vec<[f64; 3]> = Vec::with_capacity(n);
let mut log_det_sum = 0.0_f64;
let mut tangent0 = vec![0.0_f64; p];
let mut tangent1 = vec![0.0_f64; p];
for row in 0..n {
for slot in tangent0.iter_mut() {
*slot = 0.0;
}
for slot in tangent1.iter_mut() {
*slot = 0.0;
}
for bm in 0..m {
let d0 = row_jet[[row, bm, 0]];
let d1 = row_jet[[row, bm, 1]];
// Skip basis functions whose both partial derivatives are exactly zero.
if d0 == 0.0 && d1 == 0.0 {
continue;
}
for j in 0..p {
let b = decoder[[bm, j]];
tangent0[j] += d0 * b;
tangent1[j] += d1 * b;
}
}
let mut g00 = 0.0_f64;
let mut g11 = 0.0_f64;
let mut g01 = 0.0_f64;
for j in 0..p {
g00 += tangent0[j] * tangent0[j];
g11 += tangent1[j] * tangent1[j];
g01 += tangent0[j] * tangent1[j];
}
// A non-positive or non-finite determinant means the metric is degenerate here.
let det = g00 * g11 - g01 * g01;
if !(det.is_finite() && det > 0.0) {
return Ok(None);
}
// Accumulates ln(sqrt(det)) for the geometric mean below.
log_det_sum += 0.5 * det.ln();
g_rows.push([g00, g11, g01]);
}
// Geometric mean of sqrt(det g) over the rows; used to normalize the metric.
let g_bar = (log_det_sum / n as f64).exp();
if !(g_bar.is_finite() && g_bar > 0.0) {
return Ok(None);
}
let mut ghat: Vec<[f64; 3]> = Vec::with_capacity(n);
let mut ghat_norm_sq = 0.0_f64;
for g in &g_rows {
let h = [g[0] / g_bar, g[1] / g_bar, g[2] / g_bar];
// Squared Frobenius norm of the symmetric 2x2 matrix (off-diagonal counted twice).
ghat_norm_sq += h[0] * h[0] + h[1] * h[1] + 2.0 * h[2] * h[2];
ghat.push(h);
}
if !(ghat_norm_sq.is_finite() && ghat_norm_sq > 0.0) {
return Ok(None);
}
let basis = TorusFlowBasis::new(period)?;
let q = basis.dim();
// Mode values/gradients depend only on the row coordinates, so sample them once.
let mut row_modes: Vec<Vec<TorusFlowModeSample>> = Vec::with_capacity(n);
for row in 0..n {
row_modes.push(basis.mode_samples([row_coords[[row, 0]], row_coords[[row, 1]]]));
}
// Start from the identity flow (all coefficients zero).
let mut theta = vec![0.0_f64; q];
let Some(mut state) = evaluate_flow_defect(&theta, &row_modes, &ghat, ghat_norm_sq) else {
return Ok(None);
};
let defect_initial = state.defect;
if !(defect_initial > 0.0) {
return Ok(None);
}
let sqrt2 = std::f64::consts::SQRT_2;
// Damping parameter: multiplied by 10 on every rejected step, divided by 10
// (floored at 1e-12) on every accepted one.
let mut lambda = 1.0e-4_f64;
let mut any_accepted = false;
for iteration in 0..TORUS_FLOW_GN_MAX_ITERS {
// NOTE(review): this guard turns the final index into a no-op, so at most
// TORUS_FLOW_GN_MAX_ITERS - 1 steps are attempted — confirm this is intended.
if iteration + 1 == TORUS_FLOW_GN_MAX_ITERS {
break;
}
// Residual vector and its Jacobian with respect to theta, three rows per data
// row: (00, 11, sqrt2 * 01). The profiled scale is held fixed in the Jacobian.
let mut jmat = Array2::<f64>::zeros((3 * n, q));
let mut rcol = Array2::<f64>::zeros((3 * n, 1));
for (i, (a, g)) in state.a_rows.iter().zip(ghat.iter()).enumerate() {
let m00 = a[0] * a[0] + a[2] * a[2];
let m11 = a[1] * a[1] + a[3] * a[3];
let m01 = a[0] * a[1] + a[2] * a[3];
rcol[[3 * i, 0]] = m00 - state.scale * g[0];
rcol[[3 * i + 1, 0]] = m11 - state.scale * g[1];
rcol[[3 * i + 2, 0]] = sqrt2 * (m01 - state.scale * g[2]);
for (k, sample) in row_modes[i].iter().enumerate() {
// Mode k only perturbs row `sample.component` of the flow Jacobian.
let ac0 = a[2 * sample.component];
let ac1 = a[2 * sample.component + 1];
let s00 = 2.0 * sample.grad[0] * ac0;
let s11 = 2.0 * sample.grad[1] * ac1;
let s01 = sample.grad[0] * ac1 + sample.grad[1] * ac0;
jmat[[3 * i, k]] = s00;
jmat[[3 * i + 1, k]] = s11;
jmat[[3 * i + 2, k]] = sqrt2 * s01;
}
}
let jtj = fast_ata(&jmat);
let jtr = fast_atb(&jmat, &rcol);
let mut rejects = 0usize;
let mut accepted_step = false;
let mut converged = false;
let mut step_norm_sq = 0.0_f64;
// Inner loop: raise the damping until a step is accepted or the reject
// budget is exhausted.
while rejects < TORUS_FLOW_GN_MAX_REJECTS {
let mut damped = jtj.clone();
for d in 0..q {
damped[[d, d]] += lambda * (1.0 + jtj[[d, d]]);
}
// A failed factorization counts as a rejection.
let factor = match damped.cholesky(FaerSide::Lower) {
Ok(factor) => factor,
Err(_) => {
lambda *= 10.0;
rejects += 1;
continue;
}
};
let mut neg_jtr = jtr.clone();
neg_jtr.mapv_inplace(|v| -v);
let delta = factor.solve_mat(&neg_jtr);
let mut candidate = theta.clone();
step_norm_sq = 0.0;
for k in 0..q {
candidate[k] += delta[[k, 0]];
step_norm_sq += delta[[k, 0]] * delta[[k, 0]];
}
// Candidates whose guard-grid determinant is at or below the floor are
// rejected without evaluating the defect.
let folded = basis.min_jacobian_det_on_grid(&candidate) <= TORUS_FLOW_DIFFEO_MIN_DET;
let candidate_state = if folded {
None
} else {
evaluate_flow_defect(&candidate, &row_modes, &ghat, ghat_norm_sq)
};
match candidate_state {
// Accept only strict decreases of the defect.
Some(next) if next.defect < state.defect => {
let improvement = state.defect - next.defect;
theta = candidate;
state = next;
any_accepted = true;
accepted_step = true;
lambda = (lambda / 10.0).max(1.0e-12);
// Negligible improvement relative to the new defect ends the optimization.
if improvement <= 1.0e-14 * (1.0 + state.defect) {
converged = true;
}
break;
}
Some(..) | None => {
lambda *= 10.0;
rejects += 1;
}
}
}
if !accepted_step {
break;
}
if converged {
break;
}
// Also stop when the accepted step is negligible relative to theta.
let theta_norm_sq: f64 = theta.iter().map(|v| v * v).sum();
if step_norm_sq <= 1.0e-24 * (1.0 + theta_norm_sq) {
break;
}
}
if !any_accepted || !(state.defect < defect_initial) {
return Ok(None);
}
let min_flow_jacobian_det = basis.min_jacobian_det_on_grid(&theta);
if !(min_flow_jacobian_det > TORUS_FLOW_DIFFEO_MIN_DET) {
return Ok(None);
}
// Transport grid: at least TORUS_TRANSPORT_MIN_NODES_PER_AXIS nodes per axis,
// growing with the square root of the basis width.
let axis_nodes = TORUS_TRANSPORT_MIN_NODES_PER_AXIS.max(3 * (m as f64).sqrt().ceil() as usize);
let grid_rows = axis_nodes * axis_nodes;
let mut grid = Array2::<f64>::zeros((grid_rows, 2));
let mut new_grid = Array2::<f64>::zeros((grid_rows, 2));
for i in 0..axis_nodes {
for j in 0..axis_nodes {
let idx = i * axis_nodes + j;
let u = [
period * i as f64 / axis_nodes as f64,
period * j as f64 / axis_nodes as f64,
];
grid[[idx, 0]] = u[0];
grid[[idx, 1]] = u[1];
let mapped = basis.map_point(&theta, u);
new_grid[[idx, 0]] = mapped[0];
new_grid[[idx, 1]] = mapped[1];
}
}
let (grid_phi, grid_jet) = evaluator.evaluate(grid.view())?;
if grid_phi.ncols() != m || grid_jet.dim() != (grid_rows, m, 2) {
return Err(format!(
"torus_isometry_flow_reparameterization: evaluator returned basis {:?} / jet {:?} on the audit grid; expected width {m}, latent_dim 2",
grid_phi.dim(),
grid_jet.dim()
));
}
// Old basis values on the grid are matched against the basis evaluated at the
// flowed grid points.
let Some(recomposition) =
recompose_decoder_exact_ls(evaluator, decoder, grid_phi.view(), new_grid.view())?
else {
return Ok(None);
};
let mut new_row_coords = Array2::<f64>::zeros((n, 2));
for row in 0..n {
let mapped = basis.map_point(&theta, [row_coords[[row, 0]], row_coords[[row, 1]]]);
new_row_coords[[row, 0]] = mapped[0];
new_row_coords[[row, 1]] = mapped[1];
}
Ok(Some(TorusIsometryFlowReparameterization {
new_row_coords,
new_decoder: recomposition.new_decoder,
decoder_transport: recomposition.transport,
flow_theta: theta,
defect_initial,
defect_final: state.defect,
profiled_metric_scale: state.scale,
min_flow_jacobian_det,
recomposition_residual: recomposition.recomposition_residual,
}))
}