use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use super::{SaeBasisEvaluator, SaeManifoldAtom, SaeManifoldTerm};
/// Frobenius norm `‖B‖_F = sqrt(Σ b_ij²)` of a decoder coefficient matrix.
///
/// Entries are squared and accumulated directly (no rescaling), so a `NaN`
/// entry yields `NaN` and very large entries can overflow to `inf`.
pub fn decoder_frobenius_norm(decoder: ArrayView2<'_, f64>) -> f64 {
    let sum_of_squares = decoder
        .iter()
        .fold(0.0_f64, |acc, &entry| acc + entry * entry);
    sum_of_squares.sqrt()
}
/// Rescales `atom`'s decoder to unit Frobenius norm, moving the removed scale
/// into the atom's log-amplitude via `absorb_decoder_norm_into_log_amplitude`.
///
/// The module test `retract_decoder_unit_frobenius_is_image_frozen` checks that
/// the decoded curve `exp(log_amplitude) · Φ · B` is unchanged by this step.
///
/// Returns `true` when a retraction was applied. Returns `false` when the
/// decoder norm is zero or non-finite (nothing meaningful to normalize), or is
/// already 1 within a small rounding tolerance.
pub fn retract_decoder_unit_frobenius(atom: &mut SaeManifoldAtom) -> bool {
    let norm = decoder_frobenius_norm(atom.decoder_coefficients.view());
    if !(norm.is_finite() && norm > 0.0) {
        return false;
    }
    // Recomputing ‖B‖_F of an already-normalized decoder carries rounding error
    // that grows with the number of summed entries. A fixed one-ulp window
    // (`f64::EPSILON`) can therefore make the retraction re-fire on a decoder
    // that is already unit-norm. Widen the window with sqrt(#entries) so that a
    // second call is a no-op.
    let entries = atom.decoder_coefficients.len().max(1) as f64;
    let tol = 4.0 * f64::EPSILON * entries.sqrt();
    if (norm - 1.0).abs() <= tol {
        return false;
    }
    // NOTE(review): the argument is presumably a lower floor on the absorbed
    // norm; confirm against `absorb_decoder_norm_into_log_amplitude`.
    atom.absorb_decoder_norm_into_log_amplitude(f64::MIN_POSITIVE);
    true
}
/// Projects `ambient_grad` onto the tangent space at `decoder` of the sphere of
/// constant Frobenius norm.
///
/// Computes `G − (⟨G, B⟩_F / ⟨B, B⟩_F) · B`. The result is Frobenius-orthogonal
/// to `decoder`, so a step along it leaves `‖B‖_F` unchanged to first order.
/// If `decoder` has zero (or `NaN`) squared norm, `ambient_grad` is returned
/// unchanged.
///
/// # Panics
///
/// Panics if `decoder` and `ambient_grad` have different shapes.
pub fn unit_frobenius_tangent_projection(
    decoder: ArrayView2<'_, f64>,
    ambient_grad: ArrayView2<'_, f64>,
) -> Array2<f64> {
    // The element-wise pairing below requires identical shapes. Without this
    // check, `zip` would silently truncate and produce a wrong projection.
    assert_eq!(
        decoder.dim(),
        ambient_grad.dim(),
        "unit_frobenius_tangent_projection: decoder and gradient shapes differ"
    );
    let bb: f64 = decoder.iter().map(|v| v * v).sum();
    let mut out = ambient_grad.to_owned();
    if !(bb > 0.0) {
        // Degenerate (zero or NaN) decoder: there is no radial direction to remove.
        return out;
    }
    let gb: f64 = ambient_grad
        .iter()
        .zip(decoder.iter())
        .map(|(g, b)| g * b)
        .sum();
    let coeff = gb / bb;
    out.zip_mut_with(&decoder, |o, &b| *o -= coeff * b);
    out
}
/// Value, gradient and Hessian of the Hoyer-ratio energy over log-amplitudes.
#[derive(Debug, Clone)]
pub struct LogAmplitudeHoyerEnergy {
    /// `λ · ‖a‖₁ / ‖a‖₂` with `a_k = exp(s_k)` (0 for degenerate inputs).
    pub value: f64,
    /// Gradient of `value` with respect to `s` (length `k`).
    pub grad: Array1<f64>,
    /// Hessian of `value` with respect to `s` (`k × k`, symmetric).
    pub hess: Array2<f64>,
}

/// Evaluates `λ · ‖exp(s)‖₁ / ‖exp(s)‖₂` together with its gradient and Hessian in `s`.
///
/// The ratio is invariant to a common shift of `s`, so amplitudes are computed
/// as `exp(s_k − max s)` to avoid overflow. With `u = a / ‖a‖₂` and `r = ‖a‖₁ / ‖a‖₂`:
/// - `grad_i = λ u_i (1 − r u_i)`
/// - `hess_ij = λ [δ_ij u_i (1 − 2 r u_i) − u_i u_j (u_i + u_j) + 3 r u_i² u_j²]`
///
/// Fewer than two entries, a non-finite maximum, or vanishing norms produce an
/// all-zero result of the right shape.
pub fn log_amplitude_hoyer_energy(s: ArrayView1<'_, f64>, lambda: f64) -> LogAmplitudeHoyerEnergy {
    let k = s.len();
    let zero = || LogAmplitudeHoyerEnergy {
        value: 0.0,
        grad: Array1::<f64>::zeros(k),
        hess: Array2::<f64>::zeros((k, k)),
    };
    if k < 2 {
        return zero();
    }
    let peak = s.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if !peak.is_finite() {
        return zero();
    }
    // Shift by the maximum before exponentiating; the ratio is unaffected.
    let amps: Vec<f64> = s.iter().map(|&x| (x - peak).exp()).collect();
    let l1 = amps.iter().sum::<f64>();
    let l2 = amps.iter().map(|v| v * v).sum::<f64>().sqrt();
    if !(l2 > 0.0 && l1 > 0.0) {
        return zero();
    }
    let r = l1 / l2;
    let u: Vec<f64> = amps.iter().map(|v| v / l2).collect();
    let grad = Array1::from_shape_fn(k, |i| lambda * u[i] * (1.0 - r * u[i]));
    let hess = Array2::from_shape_fn((k, k), |(i, j)| {
        let (ui, uj) = (u[i], u[j]);
        let diag = if i == j {
            ui * (1.0 - 2.0 * r * ui)
        } else {
            0.0
        };
        let cross = -ui * uj * (uj + ui) + 3.0 * r * ui * ui * uj * uj;
        lambda * (diag + cross)
    });
    LogAmplitudeHoyerEnergy {
        value: lambda * r,
        grad,
        hess,
    }
}
/// Samples a decoded 1-D curve: `exp(log_amplitude) · Φ(coords) · decoder`.
///
/// `coords` are fed to `evaluator` as an `(n, 1)` matrix. The resulting basis
/// rows are multiplied by `decoder`, giving one output row per coordinate.
///
/// # Errors
///
/// Propagates any error from `evaluator.evaluate`. Also fails when the basis
/// width differs from the number of decoder rows.
pub fn sample_decoded_curve(
    evaluator: &dyn SaeBasisEvaluator,
    decoder: ArrayView2<'_, f64>,
    log_amplitude: f64,
    coords: ArrayView1<'_, f64>,
) -> Result<Array2<f64>, String> {
    // A curve chart has a single intrinsic coordinate, so the input is one column.
    let column = Array2::from_shape_fn((coords.len(), 1), |(row, _)| coords[row]);
    let (basis, _jet) = evaluator.evaluate(column.view())?;
    let (width, rows) = (basis.ncols(), decoder.nrows());
    if width != rows {
        return Err(format!(
            "sample_decoded_curve: basis width {} != decoder rows {}",
            width, rows
        ));
    }
    let mut points = basis.dot(&decoder);
    // exp(0) = 1, so the unscaled case skips the extra pass over the output.
    if log_amplitude != 0.0 {
        let scale = log_amplitude.exp();
        points.mapv_inplace(|v| v * scale);
    }
    Ok(points)
}
/// Affine map between two curve charts, `coord_a ≈ slope · coord_b + offset`,
/// together with the diagnostics used to judge whether the curves coincide.
#[derive(Debug, Clone)]
pub struct AffineChartTransition {
    /// Fitted slope of the chart map.
    pub slope: f64,
    /// Fitted intercept of the chart map.
    pub offset: f64,
    /// Root-mean-square residual of the affine coordinate fit.
    pub coord_residual: f64,
    /// Relative output-space mismatch between the curves (smaller is closer).
    pub geometric_residual: f64,
}
impl AffineChartTransition {
    /// Returns `true` when the two charts look like the same manifold.
    ///
    /// All of the following must hold:
    /// - `coord_scale > 0`;
    /// - `||slope| − 1| ≤ rel_tol` (an isometric reparametrization, possibly reversed);
    /// - `coord_residual ≤ rel_tol · coord_scale`;
    /// - `geometric_residual ≤ rel_tol`.
    ///
    /// Any `NaN` among the inputs or diagnostics makes the test fail.
    pub fn same_manifold(&self, coord_scale: f64, rel_tol: f64) -> bool {
        let slope_deviation = (self.slope.abs() - 1.0).abs();
        let coord_budget = rel_tol * coord_scale;
        coord_scale > 0.0
            && slope_deviation <= rel_tol
            && self.coord_residual <= coord_budget
            && self.geometric_residual <= rel_tol
    }
}
/// Estimates the affine reparametrization mapping curve B's chart onto curve A's.
///
/// Each sample of B is matched to its nearest sample of A (Euclidean distance in
/// output space), giving pairs `(coord_b, coord_a)`. After sorting the pairs by
/// `coord_b`, an ordinary least-squares line `coord_a ≈ slope · coord_b + offset`
/// is fitted to them. When `period_a` is `Some(p)` with `p > 0`, the matched A
/// coordinates are first unwrapped so that consecutive values never jump by more
/// than `p / 2`.
///
/// The returned diagnostics are:
/// - `coord_residual`: RMS residual of the line fit.
/// - `geometric_residual`: mean nearest-neighbour distance from B to A, divided by
///   A's RMS radius about its centroid. It is `INFINITY` if that radius is zero.
///
/// # Errors
///
/// Returns `Err` when any of the following holds:
/// - the output dimensions differ;
/// - point and coordinate counts disagree;
/// - either curve has fewer than two samples;
/// - any chart coordinate is non-finite;
/// - B's coordinates have zero spread.
pub fn affine_chart_transition(
    points_a: ArrayView2<'_, f64>,
    coords_a: ArrayView1<'_, f64>,
    points_b: ArrayView2<'_, f64>,
    coords_b: ArrayView1<'_, f64>,
    period_a: Option<f64>,
) -> Result<AffineChartTransition, String> {
    let (na, p) = points_a.dim();
    let (nb, pb) = points_b.dim();
    if p != pb {
        return Err(format!(
            "affine_chart_transition: output dims differ (a: {p}, b: {pb})"
        ));
    }
    if na != coords_a.len() || nb != coords_b.len() {
        return Err(format!(
            "affine_chart_transition: point/coord length mismatch (a: {na} vs {}, b: {nb} vs {})",
            coords_a.len(),
            coords_b.len()
        ));
    }
    if na < 2 || nb < 2 {
        return Err("affine_chart_transition: need at least two samples per curve".into());
    }
    // Non-finite coordinates break the sort's total order and the periodic
    // unwrap, and make the fit meaningless, so reject them up front.
    if !coords_a.iter().chain(coords_b.iter()).all(|c| c.is_finite()) {
        return Err("affine_chart_transition: chart coordinates must be finite".into());
    }

    let curve_scale = rms_radius(points_a);

    // Pair every B sample with the chart coordinate of its nearest A sample.
    let mut pairs: Vec<(f64, f64)> = Vec::with_capacity(nb);
    let mut dist_sum = 0.0_f64;
    for (row_b, &coord_b) in points_b.outer_iter().zip(coords_b.iter()) {
        let (nearest_a, dist_sq) = nearest_row(points_a, row_b);
        dist_sum += dist_sq.sqrt();
        pairs.push((coord_b, coords_a[nearest_a]));
    }
    let geometric_residual = if curve_scale > 0.0 {
        (dist_sum / nb as f64) / curve_scale
    } else {
        f64::INFINITY
    };

    // Order the pairs along B so that consecutive entries are neighbours on the
    // curve. All coordinates are finite here, so `total_cmp` is a valid total order.
    pairs.sort_by(|l, r| l.0.total_cmp(&r.0));
    let (xs, mut ys): (Vec<f64>, Vec<f64>) = pairs.into_iter().unzip();
    if let Some(period) = period_a.filter(|&pp| pp > 0.0) {
        unwrap_periodic(&mut ys, period);
    }

    let (slope, offset, coord_residual) = fit_line(&xs, &ys).ok_or_else(|| {
        "affine_chart_transition: curve B coordinate has zero spread; slope undefined"
            .to_string()
    })?;
    Ok(AffineChartTransition {
        slope,
        offset,
        coord_residual,
        geometric_residual,
    })
}

/// Root-mean-square distance of the rows of `points` from their centroid.
/// Expects at least one row (the caller guarantees two).
fn rms_radius(points: ArrayView2<'_, f64>) -> f64 {
    let (n, p) = points.dim();
    let mut centroid = vec![0.0_f64; p];
    for row in points.outer_iter() {
        for (c, &v) in centroid.iter_mut().zip(row.iter()) {
            *c += v;
        }
    }
    for c in centroid.iter_mut() {
        *c /= n as f64;
    }
    let mut scale_sq = 0.0_f64;
    for row in points.outer_iter() {
        for (&v, &c) in row.iter().zip(centroid.iter()) {
            let d = v - c;
            scale_sq += d * d;
        }
    }
    (scale_sq / n as f64).sqrt()
}

/// Index of the row of `reference` nearest to `query`, with its squared distance.
///
/// Ties keep the earliest row. Returns `(0, INFINITY)` when no distance compares
/// below infinity (e.g. all distances are `NaN`).
fn nearest_row(reference: ArrayView2<'_, f64>, query: ArrayView1<'_, f64>) -> (usize, f64) {
    let mut best = (0usize, f64::INFINITY);
    for (idx, row) in reference.outer_iter().enumerate() {
        let dist_sq = row
            .iter()
            .zip(query.iter())
            .fold(0.0_f64, |acc, (&a, &q)| {
                let diff = q - a;
                acc + diff * diff
            });
        if dist_sq < best.1 {
            best = (idx, dist_sq);
        }
    }
    best
}

/// Shifts each value by a whole number of periods so that no two consecutive
/// values differ by more than half a period (cumulative phase unwrapping).
fn unwrap_periodic(values: &mut [f64], period: f64) {
    let half = 0.5 * period;
    let mut prev = match values.first() {
        Some(&v) => v,
        None => return,
    };
    for v in values.iter_mut().skip(1) {
        let d = *v - prev;
        if d > half || d < -half {
            // Closed form of repeatedly adding/subtracting `period`. It costs
            // O(1) no matter how many periods the jump spans.
            *v -= (d / period).round() * period;
        }
        prev = *v;
    }
}

/// Ordinary least-squares line `y ≈ slope · x + offset` through paired samples.
///
/// Returns `(slope, offset, rms_residual)`. Returns `None` when `xs` is empty or
/// has no spread, since the slope is then undefined.
fn fit_line(xs: &[f64], ys: &[f64]) -> Option<(f64, f64, f64)> {
    let m = xs.len() as f64;
    let mean_x = xs.iter().sum::<f64>() / m;
    let mean_y = ys.iter().sum::<f64>() / m;
    let (sxx, sxy) = xs
        .iter()
        .zip(ys)
        .fold((0.0_f64, 0.0_f64), |(sxx, sxy), (&x, &y)| {
            let dx = x - mean_x;
            (sxx + dx * dx, sxy + dx * (y - mean_y))
        });
    if sxx.is_nan() || sxx <= 0.0 {
        return None;
    }
    let slope = sxy / sxx;
    let offset = mean_y - slope * mean_x;
    let resid_sq: f64 = xs
        .iter()
        .zip(ys)
        .map(|(&x, &y)| {
            let e = y - (slope * x + offset);
            e * e
        })
        .sum();
    Some((slope, offset, (resid_sq / m).sqrt()))
}
impl SaeManifoldTerm {
pub fn retract_decoder_gauge_in_loop(&mut self) -> usize {
let mut retracted = 0usize;
for atom in self.atoms.iter_mut() {
if retract_decoder_unit_frobenius(atom) {
retracted += 1;
}
}
retracted
}
}
// Unit tests for the decoder-gauge, Hoyer-energy and chart-transition helpers.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array3, array};
// Minimal 1-D basis evaluator used by the tests: phi(t) = [1, t].
// The jet holds [0, 1] for every sample, matching d(phi)/dt for this basis.
// Higher jets are not provided (None), but a coordinate matrix that is not
// one column wide is reported as an error.
#[derive(Debug)]
struct AffineLineEvaluator;
impl SaeBasisEvaluator for AffineLineEvaluator {
fn evaluate(
&self,
coords: ArrayView2<'_, f64>,
) -> Result<(Array2<f64>, Array3<f64>), String> {
let n = coords.nrows();
let mut phi = Array2::<f64>::zeros((n, 2));
let mut jet = Array3::<f64>::zeros((n, 2, 1));
for i in 0..n {
let t = coords[[i, 0]];
phi[[i, 0]] = 1.0;
phi[[i, 1]] = t;
jet[[i, 0, 0]] = 0.0;
jet[[i, 1, 0]] = 1.0;
}
Ok((phi, jet))
}
fn second_jet_dyn(
&self,
coords: ArrayView2<'_, f64>,
) -> Option<Result<ndarray::Array4<f64>, String>> {
if coords.ncols() != 1 {
return Some(Err(format!(
"AffineLineEvaluator::second_jet_dyn: d = 1 evaluator got {} coords",
coords.ncols()
)));
}
None
}
fn third_jet_dyn(
&self,
coords: ArrayView2<'_, f64>,
) -> Option<Result<ndarray::Array5<f64>, String>> {
if coords.ncols() != 1 {
return Some(Err(format!(
"AffineLineEvaluator::third_jet_dyn: d = 1 evaluator got {} coords",
coords.ncols()
)));
}
None
}
}
// A gradient parallel to B must be projected to zero. A gradient that is
// Frobenius-orthogonal to B (here <tangent, b> = 0) must pass through unchanged.
#[test]
fn unit_frobenius_tangent_projection_kills_radial_component() {
let b = array![[0.6_f64, 0.0], [0.0, 0.8]]; let radial = b.mapv(|v| 2.5 * v);
let proj = unit_frobenius_tangent_projection(b.view(), radial.view());
let worst = proj.iter().fold(0.0_f64, |a, &v| a.max(v.abs()));
assert!(
worst < 1e-12,
"radial gradient must project to 0, got {worst}"
);
let tangent = array![[0.0_f64, 1.0], [-1.0, 0.0]]; let proj_t = unit_frobenius_tangent_projection(b.view(), tangent.view());
let drift = proj_t
.iter()
.zip(tangent.iter())
.map(|(a, c)| (a - c).abs())
.fold(0.0_f64, f64::max);
assert!(
drift < 1e-12,
"tangent gradient must pass through, drift {drift}"
);
}
// Central differences with step h: the gradient is checked against `value`,
// then the Hessian against `grad`. Finally the energy must be unchanged when
// every log-amplitude is shifted by the same constant.
#[test]
fn hoyer_energy_gradient_and_hessian_match_fd() {
let s = array![0.3_f64, -0.7, 1.1, 0.05];
let lambda = 1.7_f64;
let base = log_amplitude_hoyer_energy(s.view(), lambda);
let h = 1e-6_f64;
let k = s.len();
for i in 0..k {
let mut sp = s.clone();
sp[i] += h;
let mut sm = s.clone();
sm[i] -= h;
let vp = log_amplitude_hoyer_energy(sp.view(), lambda).value;
let vm = log_amplitude_hoyer_energy(sm.view(), lambda).value;
let fd = (vp - vm) / (2.0 * h);
assert!(
(base.grad[i] - fd).abs() <= 1e-6 * (1.0 + fd.abs()),
"grad[{i}] {} != FD {fd}",
base.grad[i]
);
}
for i in 0..k {
let mut sp = s.clone();
sp[i] += h;
let mut sm = s.clone();
sm[i] -= h;
let gp = log_amplitude_hoyer_energy(sp.view(), lambda).grad;
let gm = log_amplitude_hoyer_energy(sm.view(), lambda).grad;
for j in 0..k {
let fd = (gp[j] - gm[j]) / (2.0 * h);
assert!(
(base.hess[[j, i]] - fd).abs() <= 1e-5 * (1.0 + fd.abs()),
"hess[{j},{i}] {} != FD {fd}",
base.hess[[j, i]]
);
}
}
let shifted = s.mapv(|v| v + 3.4);
let e_shift = log_amplitude_hoyer_energy(shifted.view(), lambda).value;
assert!(
(e_shift - base.value).abs() <= 1e-9 * (1.0 + base.value.abs()),
"Hoyer energy must be invariant to a common amplitude shift"
);
}
// One dominant amplitude must score below K equal amplitudes. The equal
// case has L1/L2 = K / sqrt(K) = sqrt(K) (here K = 4, so 2).
#[test]
fn hoyer_energy_prefers_sparse_over_dense() {
let sparse = array![2.0_f64, -3.0, -3.0, -3.0];
let dense = array![0.0_f64, 0.0, 0.0, 0.0];
let es = log_amplitude_hoyer_energy(sparse.view(), 1.0).value;
let ed = log_amplitude_hoyer_energy(dense.view(), 1.0).value;
assert!(es < ed, "sparse energy {es} must be below dense {ed}");
assert!(
(ed - (4.0_f64).sqrt()).abs() < 1e-9,
"dense ratio must be √K"
);
}
// Retracting a non-unit decoder must:
// - pin ‖B‖_F to 1;
// - leave the sampled curve exp(log_amplitude)·Φ·B unchanged;
// - report no work on a second call.
#[test]
fn retract_decoder_unit_frobenius_is_image_frozen() {
let coords = array![[0.0_f64], [0.25], [0.5], [0.75], [1.0]];
let ev = AffineLineEvaluator;
let (phi, jet) = ev.evaluate(coords.view()).unwrap();
let decoder = array![[2.0_f64, -1.0], [3.0, 0.5]]; let atom = SaeManifoldAtom::new(
"line",
super::super::SaeAtomBasisKind::Linear,
1,
phi,
jet,
decoder.clone(),
Array2::<f64>::eye(2),
)
.unwrap()
.with_basis_evaluator(std::sync::Arc::new(AffineLineEvaluator));
let before = sample_decoded_curve(
&ev,
atom.decoder_coefficients.view(),
atom.log_amplitude,
coords.column(0),
)
.unwrap();
let mut atom = atom;
let applied = retract_decoder_unit_frobenius(&mut atom);
assert!(applied, "a non-unit decoder must be retracted");
let norm = decoder_frobenius_norm(atom.decoder_coefficients.view());
assert!(
(norm - 1.0).abs() < 1e-12,
"‖B‖_F must be pinned to 1, got {norm}"
);
let after = sample_decoded_curve(
&ev,
atom.decoder_coefficients.view(),
atom.log_amplitude,
coords.column(0),
)
.unwrap();
let drift = before
.iter()
.zip(after.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0_f64, f64::max);
assert!(
drift < 1e-10,
"retraction must be image-frozen, drift {drift}"
);
assert!(
!retract_decoder_unit_frobenius(&mut atom),
"retraction must be idempotent"
);
}
// A(t) = t·(0.6, 0.8) and B(t) = (1 − t)·(0.6, 0.8) trace the same segment in
// opposite directions. The chart map is therefore t_a = 1 − t_b
// (slope −1, offset 1), and the curve must be flagged as the same manifold.
#[test]
fn affine_transition_detects_same_line_with_reflection_and_offset() {
let ev = AffineLineEvaluator;
let d = array![[0.0_f64, 0.0], [0.6, 0.8]]; let ca = Array1::linspace(0.0, 1.0, 11);
let pts_a = sample_decoded_curve(&ev, d.view(), 0.0, ca.view()).unwrap();
let cb = Array1::linspace(0.0, 1.0, 11);
let db = array![[0.6_f64, 0.8], [-0.6, -0.8]]; let pts_b = sample_decoded_curve(&ev, db.view(), 0.0, cb.view()).unwrap();
let tr = affine_chart_transition(pts_a.view(), ca.view(), pts_b.view(), cb.view(), None)
.unwrap();
assert!(
(tr.slope + 1.0).abs() < 1e-6,
"slope must be -1, got {}",
tr.slope
);
assert!(
(tr.offset - 1.0).abs() < 1e-6,
"offset must be 1, got {}",
tr.offset
);
assert!(
tr.coord_residual < 1e-6,
"coord residual {}",
tr.coord_residual
);
assert!(
tr.geometric_residual < 1e-6,
"geometric residual {}",
tr.geometric_residual
);
assert!(tr.same_manifold(1.0, 1e-3), "must be flagged same-manifold");
}
// A(t) = (t, 0) and B(t) = (t, 5) are parallel segments 5 units apart. That
// is far more than A's RMS radius, so the geometric residual must be large
// and the curves must be rejected.
#[test]
fn affine_transition_rejects_disjoint_curve() {
let ev = AffineLineEvaluator;
let da = array![[0.0_f64, 0.0], [1.0, 0.0]]; let db = array![[0.0_f64, 5.0], [1.0, 0.0]]; let ca = Array1::linspace(0.0, 1.0, 11);
let cb = Array1::linspace(0.0, 1.0, 11);
let pts_a = sample_decoded_curve(&ev, da.view(), 0.0, ca.view()).unwrap();
let pts_b = sample_decoded_curve(&ev, db.view(), 0.0, cb.view()).unwrap();
let tr = affine_chart_transition(pts_a.view(), ca.view(), pts_b.view(), cb.view(), None)
.unwrap();
assert!(
tr.geometric_residual > 1.0,
"disjoint curve must have large geometric residual"
);
assert!(
!tr.same_manifold(1.0, 1e-2),
"disjoint curve must be rejected"
);
}
}