use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use super::manifold::{GEOMETRY_EPS, GeometryError, GeometryResult, RiemannianManifold};
use crate::families::jet_tower::Tower4;
/// `cs_stacks` sums its Taylor series while `|u|` is at most this, and uses closed forms beyond.
const CS_SERIES_U_MAX: f64 = 0.5;
/// `t_stacks` sums its Taylor series while `|w|` is at most this, and uses closed forms beyond.
const T_SERIES_W_MAX: f64 = 0.25;
/// Terms `cs_stacks` adds after each stack entry's leading coefficient.
const CS_SERIES_TERMS: usize = 18;
/// Terms `t_stacks` adds after each stack entry's leading coefficient.
const T_SERIES_TERMS: usize = 48;
/// A Möbius-addition denominator whose magnitude is at or below this is reported as singular.
const MOBIUS_DENOM_EPS: f64 = 1e-14;
/// Returns the value and first four `u`-derivatives of `C(u) = cos(√u)` and
/// `S(u) = sin(√u)/√u`, as `([C, C', C'', C''', C''''], [S, S', S'', S''', S''''])`.
///
/// For `u < 0` the analytic continuations `cosh(√-u)` and `sinh(√-u)/√-u` are used.
/// When `|u| ≤ CS_SERIES_U_MAX` each entry is summed directly from the differentiated Taylor
/// series `C = Σ (-u)^m/(2m)!` and `S = Σ (-u)^m/(2m+1)!`. This avoids dividing by `u` near 0.
/// Otherwise `C` and `S` come from closed forms, and higher orders come from
/// `C' = -S/2` and `2u·S' = C - S`.
pub fn cs_stacks(u: f64) -> ([f64; 5], [f64; 5]) {
    if u.abs() <= CS_SERIES_U_MAX {
        let mut cos_stack = [0.0; 5];
        let mut sinc_stack = [0.0; 5];
        for (order, (c_slot, s_slot)) in cos_stack
            .iter_mut()
            .zip(sinc_stack.iter_mut())
            .enumerate()
        {
            let order_f = order as f64;
            // Leading coefficients (-1)^j j!/(2j)! and (-1)^j j!/(2j+1)!.
            let (mut term_c, mut term_s) = (1..=order).fold((1.0, 1.0), |(tc, ts), f| {
                let ff = f as f64;
                (
                    tc * (-ff / ((2.0 * ff - 1.0) * (2.0 * ff))),
                    ts * (-ff / ((2.0 * ff) * (2.0 * ff + 1.0))),
                )
            });
            let (mut sum_c, mut sum_s) = (term_c, term_s);
            // Each later term is the previous one times a ratio of consecutive coefficients.
            for m in order..order + CS_SERIES_TERMS {
                let mf = m as f64;
                term_c *= -u * (mf + 1.0)
                    / ((mf + 1.0 - order_f) * (2.0 * mf + 1.0) * (2.0 * mf + 2.0));
                term_s *= -u * (mf + 1.0)
                    / ((mf + 1.0 - order_f) * (2.0 * mf + 2.0) * (2.0 * mf + 3.0));
                sum_c += term_c;
                sum_s += term_s;
            }
            *c_slot = sum_c;
            *s_slot = sum_s;
        }
        return (cos_stack, sinc_stack);
    }

    let root = u.abs().sqrt();
    let (c0, s0) = if u > 0.0 {
        (root.cos(), root.sin() / root)
    } else {
        (root.cosh(), root.sinh() / root)
    };
    let mut cos_stack = [c0, 0.0, 0.0, 0.0, 0.0];
    let mut sinc_stack = [s0, 0.0, 0.0, 0.0, 0.0];
    let two_u = 2.0 * u;
    for j in 0..4 {
        // j-th derivative of 2u·S' = C - S, solved for S^(j+1).
        sinc_stack[j + 1] = (cos_stack[j] - (2.0 * j as f64 + 1.0) * sinc_stack[j]) / two_u;
        // C' = -S/2.
        cos_stack[j + 1] = -sinc_stack[j] / 2.0;
    }
    (cos_stack, sinc_stack)
}
/// Returns the value and first four `w`-derivatives of `T(w) = atan(√w)/√w`, as
/// `[T, T', T'', T''', T'''']`. For `w < 0` the continuation `atanh(√-w)/√-w` is used.
///
/// When `|w| ≤ T_SERIES_W_MAX` each entry is summed from the differentiated series
/// `T = Σ (-w)^m/(2m+1)`. Otherwise `T` is evaluated in closed form, and higher orders follow
/// from differentiating `2w·T' = 1/(1+w) - T`.
pub fn t_stacks(w: f64) -> [f64; 5] {
    if w.abs() <= T_SERIES_W_MAX {
        let mut stack = [0.0; 5];
        for (order, entry) in stack.iter_mut().enumerate() {
            let order_f = order as f64;
            // Leading coefficient (-1)^j j!/(2j+1).
            let mut term = (1..=order).fold(1.0, |acc, f| {
                let ff = f as f64;
                acc * (-ff * (2.0 * ff - 1.0) / (2.0 * ff + 1.0))
            });
            let mut total = term;
            for m in order..order + T_SERIES_TERMS {
                let mf = m as f64;
                term *= -w * (mf + 1.0) * (2.0 * mf + 1.0)
                    / ((mf + 1.0 - order_f) * (2.0 * mf + 3.0));
                total += term;
            }
            *entry = total;
        }
        return stack;
    }

    let root = w.abs().sqrt();
    let t0 = if w > 0.0 {
        root.atan() / root
    } else {
        root.atanh() / root
    };
    let mut stack = [t0, 0.0, 0.0, 0.0, 0.0];
    let one_plus_w = 1.0 + w;
    // r_deriv holds the j-th derivative of R(w) = 1/(1+w).
    let mut r_deriv = 1.0 / one_plus_w;
    for j in 0..4 {
        stack[j + 1] = (r_deriv - (2.0 * j as f64 + 1.0) * stack[j]) / (2.0 * w);
        r_deriv *= -((j + 1) as f64) / one_plus_w;
    }
    stack
}
/// A `dim`-dimensional manifold of constant sectional curvature `kappa`, represented in
/// κ-stereographic coordinates.
#[derive(Clone, Debug)]
pub struct ConstantCurvature {
/// Sectional curvature κ. Every chart map is parameterized by it, and
/// `sectional_curvature` returns it unchanged.
pub kappa: f64,
/// Number of chart coordinates. Points and tangent vectors passed to the manifold methods
/// are length-checked against it.
pub dim: usize,
}
impl ConstantCurvature {
    /// Creates a `dim`-dimensional manifold of constant sectional curvature `kappa` in the
    /// κ-stereographic chart.
    pub fn new(dim: usize, kappa: f64) -> Self {
        Self { kappa, dim }
    }

    /// Returns `DimensionMismatch` tagged with `context` unless `got == self.dim`.
    fn check_len(&self, context: &'static str, got: usize) -> GeometryResult<()> {
        if got != self.dim {
            return Err(GeometryError::DimensionMismatch {
                context,
                expected: self.dim,
                got,
            });
        }
        Ok(())
    }

    /// Chart gauge `g_x = 1 + κ‖x‖²`.
    ///
    /// Returns `InvalidPoint` when the gauge is not above `GEOMETRY_EPS`, i.e. when `x` lies
    /// outside the κ-stereographic chart. For κ ≥ 0 the gauge is at least 1.
    fn chart_gauge(&self, x: ArrayView1<'_, f64>) -> GeometryResult<f64> {
        let gauge = 1.0 + self.kappa * x.dot(&x);
        if gauge <= GEOMETRY_EPS {
            return Err(GeometryError::InvalidPoint(
                "constant-curvature point outside the κ-stereographic chart",
            ));
        }
        Ok(gauge)
    }

    /// Conformal factor `λ_x = 2 / (1 + κ‖x‖²)`; the metric at `x` is `λ_x² I`.
    pub fn conformal_factor(&self, x: ArrayView1<'_, f64>) -> GeometryResult<f64> {
        Ok(2.0 / self.chart_gauge(x)?)
    }

    /// κ-Möbius addition
    /// `x ⊕_κ y = ((1 - 2κ⟨x,y⟩ - κ‖y‖²) x + (1 + κ‖x‖²) y) / (1 - 2κ⟨x,y⟩ + κ²‖x‖²‖y‖²)`.
    ///
    /// # Errors
    /// * `DimensionMismatch` if either operand's length differs from `self.dim`.
    /// * `Singular` if the denominator magnitude is at most `MOBIUS_DENOM_EPS`.
    pub fn mobius_add(
        &self,
        x: ArrayView1<'_, f64>,
        y: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        // Validate up front: ndarray's `dot` panics on operands of different lengths.
        self.check_len("constant-curvature mobius_add x", x.len())?;
        self.check_len("constant-curvature mobius_add y", y.len())?;
        let k = self.kappa;
        let xy = x.dot(&y);
        let xx = x.dot(&x);
        let yy = y.dot(&y);
        let denom = 1.0 - 2.0 * k * xy + k * k * xx * yy;
        if denom.abs() <= MOBIUS_DENOM_EPS {
            return Err(GeometryError::Singular(
                "Möbius addition at the κ>0 antipodal point",
            ));
        }
        let a = 1.0 - 2.0 * k * xy - k * yy;
        let b = 1.0 + k * xx;
        Ok(x.iter()
            .zip(y.iter())
            .map(|(&xi, &yi)| (a * xi + b * yi) / denom)
            .collect())
    }

    /// Geodesic distance `d(x, y) = 2‖w‖·T(κ‖w‖²)` with `w = (-x) ⊕_κ y`, i.e.
    /// `2·tan_κ⁻¹(‖w‖)`. Both points must lie in the chart.
    pub fn distance(&self, x: ArrayView1<'_, f64>, y: ArrayView1<'_, f64>) -> GeometryResult<f64> {
        self.check_len("constant-curvature distance x", x.len())?;
        self.check_len("constant-curvature distance y", y.len())?;
        self.chart_gauge(x)?;
        self.chart_gauge(y)?;
        let neg_x = x.mapv(|v| -v);
        let w = self.mobius_add(neg_x.view(), y)?;
        let nw2 = w.dot(&w);
        Ok(2.0 * nw2.sqrt() * t_stacks(self.kappa * nw2)[0])
    }

    /// `tan_κ(t) = t·S(κt²)/C(κt²)` (`tan(√κ t)/√κ`, continued to κ ≤ 0).
    ///
    /// Returns `Singular` when `|C| ≤ GEOMETRY_EPS`, which is a conjugate point for κ > 0.
    fn tn(&self, t: f64) -> GeometryResult<f64> {
        let (c, s) = cs_stacks(self.kappa * t * t);
        if c[0].abs() <= GEOMETRY_EPS {
            return Err(GeometryError::Singular(
                "constant-curvature exp map at a conjugate point (cos(√κ t) = 0)",
            ));
        }
        Ok(t * s[0] / c[0])
    }

    /// Gyration `gyr[a, b] v = ⊖(a ⊕ b) ⊕ (a ⊕ (b ⊕ v))`.
    fn gyration(
        &self,
        a: ArrayView1<'_, f64>,
        b: ArrayView1<'_, f64>,
        v: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        let bv = self.mobius_add(b, v)?;
        let abv = self.mobius_add(a, bv.view())?;
        let ab = self.mobius_add(a, b)?;
        let neg_ab = ab.mapv(|z| -z);
        self.mobius_add(neg_ab.view(), abv.view())
    }
}
impl RiemannianManifold for ConstantCurvature {
    fn dim(&self) -> usize {
        self.dim
    }

    /// Returns the coordinate frame (the identity matrix) after validating `point`.
    fn tangent_basis(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Array2<f64>> {
        self.check_len("constant-curvature tangent_basis point", point.len())?;
        self.chart_gauge(point)?;
        Ok(Array2::eye(self.dim))
    }

    /// `exp_x(v) = x ⊕_κ (tan_κ(‖v‖/g_x) · v/‖v‖)` with gauge `g_x = 1 + κ‖x‖²`. This is
    /// `tan_κ(λ_x‖v‖/2)` with `λ_x = 2/g_x`. Vectors with `‖v‖ ≤ GEOMETRY_EPS` return `x`.
    fn exp_map(
        &self,
        point: ArrayView1<'_, f64>,
        tangent_vec: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        self.check_len("constant-curvature exp point", point.len())?;
        self.check_len("constant-curvature exp tangent", tangent_vec.len())?;
        let gauge = self.chart_gauge(point)?;
        let n = tangent_vec.dot(&tangent_vec).sqrt();
        if n <= GEOMETRY_EPS {
            return Ok(point.to_owned());
        }
        let t = n / gauge;
        let scale = self.tn(t)? / n;
        let step = tangent_vec.mapv(|z| z * scale);
        self.mobius_add(point, step.view())
    }

    /// `log_x(y) = g_x · T(κ‖w‖²) · w` with `w = (-x) ⊕_κ y`, the inverse of `exp_map`.
    fn log_map(
        &self,
        p_from: ArrayView1<'_, f64>,
        p_to: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        self.check_len("constant-curvature log from", p_from.len())?;
        self.check_len("constant-curvature log to", p_to.len())?;
        let gauge = self.chart_gauge(p_from)?;
        self.chart_gauge(p_to)?;
        let neg_x = p_from.mapv(|v| -v);
        let w = self.mobius_add(neg_x.view(), p_to)?;
        let coeff = gauge * t_stacks(self.kappa * w.dot(&w))[0];
        Ok(w.mapv(|z| z * coeff))
    }

    /// Transports `vec` along the polyline whose rows are `point_along`, one segment at a
    /// time, using `v ↦ (λ_a/λ_b) · gyr[b, -a] v`. Here `λ_a/λ_b = g_b/g_a`.
    ///
    /// Paths with fewer than two rows return `vec` unchanged.
    fn parallel_transport(
        &self,
        point_along: ArrayView2<'_, f64>,
        vec: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        self.check_len("constant-curvature transport vector", vec.len())?;
        if point_along.nrows() < 2 {
            return Ok(vec.to_owned());
        }
        // Every row has the same column count, so one check validates all waypoints.
        self.check_len("constant-curvature transport waypoint", point_along.ncols())?;
        let mut carried = vec.to_owned();
        for seg in 0..point_along.nrows() - 1 {
            let a = point_along.row(seg);
            let b = point_along.row(seg + 1);
            let lam_ratio = self.chart_gauge(b)? / self.chart_gauge(a)?;
            let neg_a = a.mapv(|z| -z);
            carried = self
                .gyration(b, neg_a.view(), carried.view())?
                .mapv(|z| z * lam_ratio);
        }
        Ok(carried)
    }

    /// Conformal metric `λ_x² I`.
    fn metric_tensor(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Array2<f64>> {
        self.check_len("constant-curvature metric point", point.len())?;
        let lam = self.conformal_factor(point)?;
        Ok(Array2::eye(self.dim) * (lam * lam))
    }

    /// Levi-Civita symbols of the conformal metric `e^{2φ} δ` with `e^φ = λ_x`:
    /// `Γ^k_ij = δ_ik ∂_jφ + δ_jk ∂_iφ - δ_ij ∂_kφ`, where `∂_iφ = -κ λ_x x_i`.
    /// Element `k` of the result holds the matrix `Γ^k_{ij}`.
    fn christoffel_symbols(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Vec<Array2<f64>>> {
        self.check_len("constant-curvature Christoffel point", point.len())?;
        let lam = self.conformal_factor(point)?;
        // Gradient of φ = ln λ_x.
        let phi: Vec<f64> = point.iter().map(|&xi| -self.kappa * lam * xi).collect();
        let d = self.dim;
        let mut out = Vec::with_capacity(d);
        for k in 0..d {
            let mut gamma_k = Array2::zeros((d, d));
            for i in 0..d {
                for j in 0..d {
                    let mut val = 0.0;
                    if i == k {
                        val += phi[j];
                    }
                    if j == k {
                        val += phi[i];
                    }
                    if i == j {
                        val -= phi[k];
                    }
                    gamma_k[[i, j]] = val;
                }
            }
            out.push(gamma_k);
        }
        Ok(out)
    }

    /// Sectional curvature is κ for every plane at every chart point.
    fn sectional_curvature(
        &self,
        point: ArrayView1<'_, f64>,
        tangent_pair: (ArrayView1<'_, f64>, ArrayView1<'_, f64>),
    ) -> GeometryResult<f64> {
        self.check_len("constant-curvature sectional point", point.len())?;
        self.check_len("constant-curvature sectional u", tangent_pair.0.len())?;
        self.check_len("constant-curvature sectional v", tangent_pair.1.len())?;
        self.chart_gauge(point)?;
        Ok(self.kappa)
    }

    /// Vector–Jacobian product of `exp_map`: returns `(gᵀ ∂exp/∂x, gᵀ ∂exp/∂v)` for
    /// `g = grad_output`.
    ///
    /// This is a reverse sweep through `exp_x(v) = x ⊕_κ (scale·v)`, with
    /// `scale = tan_κ(t)/‖v‖`, `t = ‖v‖/g_x` and `g_x = 1 + κ‖x‖²`. The sweep is exact for
    /// every κ, including κ = 0, where it reduces to `(g, g)`.
    fn exp_map_vjp(
        &self,
        point: ArrayView1<'_, f64>,
        tangent_vec: ArrayView1<'_, f64>,
        grad_output: ArrayView1<'_, f64>,
    ) -> GeometryResult<(Array1<f64>, Array1<f64>)> {
        self.check_len("constant-curvature exp_map_vjp point", point.len())?;
        self.check_len("constant-curvature exp_map_vjp tangent", tangent_vec.len())?;
        self.check_len(
            "constant-curvature exp_map_vjp grad_output",
            grad_output.len(),
        )?;
        // Validate the base point before any shortcut, so the VJP fails exactly where
        // `exp_map` does.
        let gauge = self.chart_gauge(point)?;
        let n = tangent_vec.dot(&tangent_vec).sqrt();
        if n <= GEOMETRY_EPS {
            // exp_x(0) = x. At step 0, the Jacobian of x ⊕_κ step w.r.t. step is
            // (1 + κ‖x‖²)·I, and step = v/g_x, so both cotangents equal g.
            return Ok((grad_output.to_owned(), grad_output.to_owned()));
        }
        let k = self.kappa;
        let d = point.len();

        // Forward pass.
        let t = n / gauge;
        let tau = self.tn(t)?;
        let scale = tau / n;
        let step = tangent_vec.mapv(|z| z * scale);
        let p = point.dot(&step);
        let xx = point.dot(&point);
        let ss = step.dot(&step);
        let a = 1.0 - 2.0 * k * p - k * ss;
        let b = 1.0 + k * xx;
        let denom = 1.0 - 2.0 * k * p + k * k * xx * ss;
        if denom.abs() <= MOBIUS_DENOM_EPS {
            return Err(GeometryError::Singular(
                "Möbius addition at the κ>0 antipodal point",
            ));
        }
        let y: Array1<f64> = point
            .iter()
            .zip(step.iter())
            .map(|(&xi, &si)| (a * xi + b * si) / denom)
            .collect();

        // Reverse pass through y = (a·x + b·step)/denom.
        let g = grad_output;
        let yx = g.dot(&point);
        let ys = g.dot(&step);
        let yy = g.dot(&y);
        let inv_d = 1.0 / denom;
        let mut x_bar = Array1::zeros(d);
        let mut step_bar = Array1::zeros(d);
        for j in 0..d {
            x_bar[j] = (-2.0 * k * step[j] * yx + a * g[j] + 2.0 * k * point[j] * ys) * inv_d
                - yy * (-2.0 * k * step[j] + 2.0 * k * k * point[j] * ss) * inv_d;
            step_bar[j] = ((-2.0 * k * point[j] - 2.0 * k * step[j]) * yx + b * g[j]) * inv_d
                - yy * (-2.0 * k * point[j] + 2.0 * k * k * xx * step[j]) * inv_d;
        }

        // Reverse pass through step = (tau/n)·v.
        let scale_bar = step_bar.dot(&tangent_vec);
        let mut v_bar = step_bar.mapv(|z| z * scale);
        let tau_bar = scale_bar / n;
        let mut n_bar = -scale_bar * tau / (n * n);
        // d tan_κ(t)/dt = 1 + κ·tan_κ(t)².
        let t_bar = tau_bar * (1.0 + k * tau * tau);
        // t = n / gauge.
        n_bar += t_bar / gauge;
        let gauge_bar = -t_bar * n / (gauge * gauge);
        // gauge = 1 + κ‖x‖² and n = ‖v‖.
        for j in 0..d {
            x_bar[j] += gauge_bar * 2.0 * k * point[j];
            v_bar[j] += n_bar * tangent_vec[j] / n;
        }
        Ok((x_bar, v_bar))
    }
}
/// Jet type for differentiating chart computations with respect to the single variable κ.
/// It is seeded via `KJet::variable(manifold.kappa, 0)`. This file reads only its value
/// `.v`, first κ-derivative `.g[0]` and second κ-derivative `.h[0][0]`.
// NOTE(review): `compose_unary` is fed 5-entry derivative stacks, which suggests `Tower4`
// tracks derivatives up to 4th order. Confirm against `families::jet_tower`.
type KJet = Tower4<1>;
/// Computes `w = (-x) ⊕_κ y` coordinate-wise as κ-jets, with κ carried as the jet variable.
///
/// Returns `Singular` when the Möbius denominator's value magnitude is at most
/// `MOBIUS_DENOM_EPS`.
fn kjet_mobius_w(
    kappa: KJet,
    x: ArrayView1<'_, f64>,
    y: ArrayView1<'_, f64>,
) -> GeometryResult<Vec<KJet>> {
    // The left operand is -x, so its inner product with y flips sign, while ‖-x‖² = ‖x‖².
    let neg_x_dot_y = -x.dot(&y);
    let x_sq = x.dot(&x);
    let y_sq = y.dot(&y);
    let denom = (kappa * kappa) * (x_sq * y_sq) - kappa * (2.0 * neg_x_dot_y) + 1.0;
    if denom.v.abs() <= MOBIUS_DENOM_EPS {
        return Err(GeometryError::Singular(
            "Möbius addition at the κ>0 antipodal point",
        ));
    }
    let coeff_left = -(kappa * (2.0 * neg_x_dot_y + y_sq)) + 1.0;
    let coeff_right = kappa * x_sq + 1.0;
    let inv = denom.recip();
    Ok(x.iter()
        .zip(y.iter())
        .map(|(&xi, &yi)| (coeff_left * (-xi) + coeff_right * yi) * inv)
        .collect())
}
/// Computes the distance `d(x, y)` and its first and second κ-derivatives, returned as
/// `(d, ∂d/∂κ, ∂²d/∂κ²)`.
///
/// Evaluates `d = 2‖w‖·T(κ‖w‖²)` with `w = (-x) ⊕_κ y` in κ-jet arithmetic. Coincident
/// points (`‖w‖² ≤ GEOMETRY_EPS²`) return `(0, 0, 0)`, which avoids differentiating `√·`
/// at zero.
pub fn distance_kappa_jet(
    manifold: &ConstantCurvature,
    x: ArrayView1<'_, f64>,
    y: ArrayView1<'_, f64>,
) -> GeometryResult<(f64, f64, f64)> {
    manifold.check_len("constant-curvature distance-jet x", x.len())?;
    manifold.check_len("constant-curvature distance-jet y", y.len())?;
    manifold.chart_gauge(x)?;
    manifold.chart_gauge(y)?;
    let kappa = KJet::variable(manifold.kappa, 0);
    let w = kjet_mobius_w(kappa, x, y)?;
    let norm_sq = w
        .iter()
        .fold(KJet::constant(0.0), |acc, &wi| acc + wi * wi);
    if norm_sq.v <= GEOMETRY_EPS * GEOMETRY_EPS {
        return Ok((0.0, 0.0, 0.0));
    }
    let arg = kappa * norm_sq;
    let t_of_arg = arg.compose_unary(t_stacks(arg.v));
    let dist = norm_sq.sqrt() * t_of_arg * 2.0;
    Ok((dist.v, dist.g[0], dist.h[0][0]))
}
/// Computes `log_x(y)` and its first and second κ-derivatives, per coordinate, as
/// `(value, ∂/∂κ, ∂²/∂κ²)`.
///
/// Evaluates `(1 + κ‖x‖²)·T(κ‖w‖²)·w` with `w = (-x) ⊕_κ y` in κ-jet arithmetic.
pub fn log_map_kappa_jet(
    manifold: &ConstantCurvature,
    p_from: ArrayView1<'_, f64>,
    p_to: ArrayView1<'_, f64>,
) -> GeometryResult<(Array1<f64>, Array1<f64>, Array1<f64>)> {
    manifold.check_len("constant-curvature log-jet from", p_from.len())?;
    manifold.check_len("constant-curvature log-jet to", p_to.len())?;
    manifold.chart_gauge(p_from)?;
    manifold.chart_gauge(p_to)?;
    let kappa = KJet::variable(manifold.kappa, 0);
    let w = kjet_mobius_w(kappa, p_from, p_to)?;
    let norm_sq = w
        .iter()
        .fold(KJet::constant(0.0), |acc, &wi| acc + wi * wi);
    let arg = kappa * norm_sq;
    let t_of_arg = arg.compose_unary(t_stacks(arg.v));
    // Gauge 1 + κ‖x‖², which equals 2/λ_x.
    let coeff = (kappa * p_from.dot(&p_from) + 1.0) * t_of_arg;
    let dim = p_from.len();
    let mut value = Array1::zeros(dim);
    let mut dk = Array1::zeros(dim);
    let mut dkk = Array1::zeros(dim);
    for (i, &wi) in w.iter().enumerate() {
        let lifted = coeff * wi;
        value[i] = lifted.v;
        dk[i] = lifted.g[0];
        dkk[i] = lifted.h[0][0];
    }
    Ok((value, dk, dkk))
}
/// Computes `exp_x(v)` and its first and second κ-derivatives, per coordinate, as
/// `(value, ∂/∂κ, ∂²/∂κ²)`.
///
/// Mirrors the scalar exp map in κ-jet arithmetic:
/// * `t = ‖v‖/(1 + κ‖x‖²)`
/// * `tan_κ(t) = t·S(κt²)/C(κt²)`
/// * `x ⊕_κ (tan_κ(t)/‖v‖)·v`
///
/// Tangent vectors with `‖v‖ ≤ GEOMETRY_EPS` return `x` with zero derivatives.
pub fn exp_map_kappa_jet(
    manifold: &ConstantCurvature,
    point: ArrayView1<'_, f64>,
    tangent_vec: ArrayView1<'_, f64>,
) -> GeometryResult<(Array1<f64>, Array1<f64>, Array1<f64>)> {
    manifold.check_len("constant-curvature exp-jet point", point.len())?;
    manifold.check_len("constant-curvature exp-jet tangent", tangent_vec.len())?;
    manifold.chart_gauge(point)?;
    let dim = point.len();
    let v_norm = tangent_vec.dot(&tangent_vec).sqrt();
    if v_norm <= GEOMETRY_EPS {
        return Ok((point.to_owned(), Array1::zeros(dim), Array1::zeros(dim)));
    }
    let kappa = KJet::variable(manifold.kappa, 0);
    let x_sq = point.dot(&point);
    let gauge = kappa * x_sq + 1.0;

    // Geodesic parameter t, then tan_κ(t) = t·S(κt²)/C(κt²).
    let t = gauge.recip() * v_norm;
    let arg = kappa * (t * t);
    let (c_stack, s_stack) = cs_stacks(arg.v);
    let cos_part = arg.compose_unary(c_stack);
    if cos_part.v.abs() <= GEOMETRY_EPS {
        return Err(GeometryError::Singular(
            "constant-curvature exp-jet at a conjugate point (cos(√κ t) = 0)",
        ));
    }
    let sinc_part = arg.compose_unary(s_stack);
    let tan_t = t * sinc_part * cos_part.recip();
    let scale = tan_t * (1.0 / v_norm);
    let step: Vec<KJet> = tangent_vec.iter().map(|&vi| scale * vi).collect();

    // Möbius addition x ⊕_κ step, with x a constant (κ-independent) vector.
    let (x_dot_s, s_sq) = step.iter().zip(point.iter()).fold(
        (KJet::constant(0.0), KJet::constant(0.0)),
        |(xs, ss), (&si, &xi)| (xs + si * xi, ss + si * si),
    );
    let two_k_xs = (kappa * 2.0) * x_dot_s;
    let denom = -two_k_xs + (kappa * kappa) * (s_sq * x_sq) + 1.0;
    if denom.v.abs() <= MOBIUS_DENOM_EPS {
        return Err(GeometryError::Singular(
            "Möbius addition at the κ>0 antipodal point",
        ));
    }
    let coeff_x = -two_k_xs + (-(kappa * s_sq)) + 1.0;
    let inv = denom.recip();
    let mut value = Array1::zeros(dim);
    let mut dk = Array1::zeros(dim);
    let mut dkk = Array1::zeros(dim);
    for (i, (&xi, &si)) in point.iter().zip(step.iter()).enumerate() {
        // The step's coefficient 1 + κ‖x‖² is the gauge itself.
        let out_i = (coeff_x * xi + gauge * si) * inv;
        value[i] = out_i.v;
        dk[i] = out_i.g[0];
        dkk[i] = out_i.h[0][0];
    }
    Ok((value, dk, dkk))
}
// Unit tests for the constant-curvature manifold. They combine finite-difference checks of
// the scalar stacks, κ-jets, exp-map VJP and Christoffel symbols with closed-form checks at
// κ ∈ {-1, 0, 1}.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
// Stack entry j+1 must match a central difference of entry j. The sample points straddle
// both series/closed-form switches: |u| = 0.5 for C/S and |w| = 0.25 for T.
#[test]
fn scalar_stacks_are_fd_consistent_across_branches() {
let h = 1e-6;
for &u in &[
-3.0, -0.6, -0.49, -0.2, -1e-9, 0.0, 1e-9, 0.2, 0.49, 0.6, 3.0,
] {
let up = cs_stacks(u + h);
let dn = cs_stacks(u - h);
let at = cs_stacks(u);
for j in 0..4 {
let fd_c = (up.0[j] - dn.0[j]) / (2.0 * h);
let fd_s = (up.1[j] - dn.1[j]) / (2.0 * h);
assert!(
(at.0[j + 1] - fd_c).abs() <= 1e-7 * fd_c.abs().max(1.0),
"C stack order {j} at u={u}: analytic {} fd {}",
at.0[j + 1],
fd_c
);
assert!(
(at.1[j + 1] - fd_s).abs() <= 1e-7 * fd_s.abs().max(1.0),
"S stack order {j} at u={u}: analytic {} fd {}",
at.1[j + 1],
fd_s
);
}
}
for &w in &[-0.6, -0.26, -0.1, -1e-9, 0.0, 1e-9, 0.1, 0.26, 0.8, 4.0] {
let up = t_stacks(w + 1e-7);
let dn = t_stacks(w - 1e-7);
let at = t_stacks(w);
for j in 0..4 {
let fd = (up[j] - dn[j]) / 2e-7;
assert!(
(at[j + 1] - fd).abs() <= 1e-6 * fd.abs().max(1.0),
"T stack order {j} at w={w}: analytic {} fd {}",
at[j + 1],
fd
);
}
}
}
// Closed forms: the radial distance is 2·atanh(r) for κ = -1 and 2·atan(r) for κ = +1.
// For κ = 0, distance is 2‖y - x‖ (conformal factor 2), exp is translation and log is
// the difference.
#[test]
fn classical_members_match_closed_forms() {
let y: ndarray::Array1<f64> = array![0.3, -0.2, 0.1];
let origin: ndarray::Array1<f64> = array![0.0, 0.0, 0.0];
let r: f64 = y.dot(&y).sqrt();
let hyper = ConstantCurvature::new(3, -1.0);
let d = hyper.distance(origin.view(), y.view()).expect("hyper d");
assert!((d - 2.0 * r.atanh()).abs() <= 1e-14, "poincare radial: {d}");
let sphere = ConstantCurvature::new(3, 1.0);
let d = sphere.distance(origin.view(), y.view()).expect("sphere d");
assert!((d - 2.0 * r.atan()).abs() <= 1e-14, "sphere radial: {d}");
let flat = ConstantCurvature::new(3, 0.0);
let x = array![0.4, 0.1, -0.7];
let d = flat.distance(x.view(), y.view()).expect("flat d");
let diff = (&y - &x).dot(&(&y - &x)).sqrt();
assert!((d - 2.0 * diff).abs() <= 1e-14, "flat doubled gauge: {d}");
let v = array![0.2, -0.5, 0.3];
let e = flat.exp_map(x.view(), v.view()).expect("flat exp");
for i in 0..3 {
assert!(
(e[i] - (x[i] + v[i])).abs() <= 1e-14,
"flat exp is translation"
);
}
let l = flat.log_map(x.view(), y.view()).expect("flat log");
for i in 0..3 {
assert!(
(l[i] - (y[i] - x[i])).abs() <= 1e-14,
"flat log is difference"
);
}
}
// For κ of both signs, including ±1e-7 around the flat case, check two identities:
// d(x, exp_x v) = λ_x‖v‖, and log_x(exp_x v) = v.
#[test]
fn exp_log_distance_are_mutually_consistent_across_kappa() {
let x = array![0.25, -0.1];
let v = array![0.15, 0.2];
for &kappa in &[-1.7, -0.6, -1e-7, 0.0, 1e-7, 0.8, 2.3] {
let m = ConstantCurvature::new(2, kappa);
let lam = m.conformal_factor(x.view()).expect("lambda");
let y = m.exp_map(x.view(), v.view()).expect("exp");
let d = m.distance(x.view(), y.view()).expect("dist");
let want = lam * v.dot(&v).sqrt();
assert!(
(d - want).abs() <= 1e-12 * want.max(1.0),
"κ={kappa}: d(x, exp_x v) = {d}, λ_x‖v‖ = {want}"
);
let back = m.log_map(x.view(), y.view()).expect("log");
for i in 0..2 {
assert!(
(back[i] - v[i]).abs() <= 1e-11,
"κ={kappa}: log∘exp ≠ id at [{i}]: {} vs {}",
back[i],
v[i]
);
}
}
}
// Parallel transport along a three-point path must preserve the Riemannian norm λ·‖v‖.
#[test]
fn parallel_transport_preserves_riemannian_norm() {
let path = ndarray::arr2(&[[0.05, 0.1], [0.2, -0.15], [-0.1, 0.25]]);
let v = array![0.3, -0.4];
for &kappa in &[-1.2, 0.0, 1.4] {
let m = ConstantCurvature::new(2, kappa);
let out = m.parallel_transport(path.view(), v.view()).expect("pt");
let lam_a = m.conformal_factor(path.row(0)).expect("λ_a");
let lam_b = m.conformal_factor(path.row(2)).expect("λ_b");
let n_in = lam_a * v.dot(&v).sqrt();
let n_out = lam_b * out.dot(&out).sqrt();
assert!(
(n_in - n_out).abs() <= 1e-11 * n_in.max(1.0),
"κ={kappa}: transport norm {n_out} vs {n_in}"
);
}
}
// The distance, log and exp κ-jets must reproduce the scalar maps' values. Their first and
// second κ-derivatives must match central differences in κ with step h.
#[test]
fn kappa_jets_match_finite_differences() {
let x = array![0.3, -0.15];
let y = array![-0.2, 0.25];
let h = 1e-5;
for &kappa in &[-1.3, -0.5, -1e-6, 0.0, 1e-6, 0.4, 1.6] {
let m = ConstantCurvature::new(2, kappa);
let up = ConstantCurvature::new(2, kappa + h);
let dn = ConstantCurvature::new(2, kappa - h);
let (d, d_k, d_kk) = distance_kappa_jet(&m, x.view(), y.view()).expect("jet");
let d_up = up.distance(x.view(), y.view()).expect("d+");
let d_dn = dn.distance(x.view(), y.view()).expect("d-");
let d_at = m.distance(x.view(), y.view()).expect("d0");
assert!(
(d - d_at).abs() <= 1e-13 * d_at.max(1.0),
"jet value channel"
);
let fd1 = (d_up - d_dn) / (2.0 * h);
let fd2 = (d_up - 2.0 * d_at + d_dn) / (h * h);
assert!(
(d_k - fd1).abs() <= 1e-6 * fd1.abs().max(1.0),
"κ={kappa}: ∂d/∂κ analytic {d_k} fd {fd1}"
);
assert!(
(d_kk - fd2).abs() <= 1e-4 * fd2.abs().max(1.0),
"κ={kappa}: ∂²d/∂κ² analytic {d_kk} fd {fd2}"
);
let (l, l_k, l_kk) = log_map_kappa_jet(&m, x.view(), y.view()).expect("ljet");
let l_up = up.log_map(x.view(), y.view()).expect("l+");
let l_dn = dn.log_map(x.view(), y.view()).expect("l-");
let l_at = m.log_map(x.view(), y.view()).expect("l0");
for i in 0..2 {
assert!((l[i] - l_at[i]).abs() <= 1e-13 * l_at[i].abs().max(1.0));
let fd1 = (l_up[i] - l_dn[i]) / (2.0 * h);
let fd2 = (l_up[i] - 2.0 * l_at[i] + l_dn[i]) / (h * h);
assert!(
(l_k[i] - fd1).abs() <= 1e-6 * fd1.abs().max(1.0),
"κ={kappa}: ∂log/∂κ[{i}] analytic {} fd {fd1}",
l_k[i]
);
assert!(
(l_kk[i] - fd2).abs() <= 1e-4 * fd2.abs().max(1.0),
"κ={kappa}: ∂²log/∂κ²[{i}] analytic {} fd {fd2}",
l_kk[i]
);
}
let v = array![0.12, -0.08];
let (e, e_k, e_kk) = exp_map_kappa_jet(&m, x.view(), v.view()).expect("ejet");
let e_up = up.exp_map(x.view(), v.view()).expect("e+");
let e_dn = dn.exp_map(x.view(), v.view()).expect("e-");
let e_at = m.exp_map(x.view(), v.view()).expect("e0");
for i in 0..2 {
assert!(
(e[i] - e_at[i]).abs() <= 1e-13 * e_at[i].abs().max(1.0),
"κ={kappa}: exp-jet value channel[{i}] {} vs {}",
e[i],
e_at[i]
);
let fd1 = (e_up[i] - e_dn[i]) / (2.0 * h);
let fd2 = (e_up[i] - 2.0 * e_at[i] + e_dn[i]) / (h * h);
assert!(
(e_k[i] - fd1).abs() <= 1e-6 * fd1.abs().max(1.0),
"κ={kappa}: ∂exp/∂κ[{i}] analytic {} fd {fd1}",
e_k[i]
);
assert!(
(e_kk[i] - fd2).abs() <= 1e-4 * fd2.abs().max(1.0),
"κ={kappa}: ∂²exp/∂κ²[{i}] analytic {} fd {fd2}",
e_kk[i]
);
}
}
}
// Each component of the analytic exp-map VJP must match the central difference of
// g·exp(x, v) with respect to x_j and v_j.
#[test]
fn exp_map_vjp_matches_finite_differences() {
let h = 1e-6;
let cases: &[(ndarray::Array1<f64>, ndarray::Array1<f64>, ndarray::Array1<f64>)] = &[
(
array![0.2, -0.1],
array![0.12, 0.08],
array![1.0, -0.5],
),
(
array![-0.15, 0.22],
array![-0.05, 0.11],
array![0.3, 0.7],
),
];
for &kappa in &[-1.3, -0.3, 0.0, 0.4, 1.1] {
let m = ConstantCurvature::new(2, kappa);
for (x, v, g) in cases {
let d = x.len();
let (x_bar, v_bar) = m
.exp_map_vjp(x.view(), v.view(), g.view())
.expect("exp_map_vjp");
for j in 0..d {
let mut xp = x.clone();
xp[j] += h;
let mut xn = x.clone();
xn[j] -= h;
let ep = m.exp_map(xp.view(), v.view()).expect("exp x+");
let en = m.exp_map(xn.view(), v.view()).expect("exp x-");
let xbar_fd = g.dot(&(&ep - &en)) / (2.0 * h);
assert!(
(x_bar[j] - xbar_fd).abs() <= 1e-5 * x_bar[j].abs().max(1.0),
"κ={kappa}: x̄[{j}] analytic {} fd {xbar_fd}",
x_bar[j]
);
let mut vp = v.clone();
vp[j] += h;
let mut vn = v.clone();
vn[j] -= h;
let ep = m.exp_map(x.view(), vp.view()).expect("exp v+");
let en = m.exp_map(x.view(), vn.view()).expect("exp v-");
let vbar_fd = g.dot(&(&ep - &en)) / (2.0 * h);
assert!(
(v_bar[j] - vbar_fd).abs() <= 1e-5 * v_bar[j].abs().max(1.0),
"κ={kappa}: v̄[{j}] analytic {} fd {vbar_fd}",
v_bar[j]
);
}
}
}
}
// The reported sectional curvature is exactly κ.
#[test]
fn sectional_curvature_is_kappa() {
let m = ConstantCurvature::new(3, -0.37);
let p = array![0.1, 0.0, -0.2];
let u = array![1.0, 0.0, 0.0];
let v = array![0.0, 1.0, 0.0];
let k = m
.sectional_curvature(p.view(), (u.view(), v.view()))
.expect("sectional");
assert!((k + 0.37).abs() <= 1e-15);
}
// The analytic Γ must match the Levi-Civita formula
// Γ^k_ij = ½ g^{kk} (∂_i g_jk + ∂_j g_ik - ∂_k g_ij). The metric is diagonal, so
// g^{kk} = 1/λ². The metric derivatives are central differences of `metric_tensor`.
#[test]
fn christoffel_matches_fd_of_metric() {
let d = 2usize;
let x = array![0.22, -0.13];
let h = 1e-5;
for &kappa in &[-1.4, -0.5, 0.0, 0.7, 1.9] {
let m = ConstantCurvature::new(d, kappa);
let lam = m.conformal_factor(x.view()).expect("λ");
let g_inv_diag = 1.0 / (lam * lam);
let mut dg: Vec<Array2<f64>> = Vec::with_capacity(d);
for a in 0..d {
let mut xp = x.clone();
xp[a] += h;
let mut xn = x.clone();
xn[a] -= h;
let gp = m.metric_tensor(xp.view()).expect("g+");
let gn = m.metric_tensor(xn.view()).expect("g-");
dg.push((&gp - &gn).mapv(|v| v / (2.0 * h)));
}
let gamma = m.christoffel_symbols(x.view()).expect("Γ");
for k in 0..d {
for i in 0..d {
for j in 0..d {
let expected =
0.5 * g_inv_diag * (dg[i][[j, k]] + dg[j][[i, k]] - dg[k][[i, j]]);
assert!(
(gamma[k][[i, j]] - expected).abs() <= 1e-6 * expected.abs().max(1.0),
"κ={kappa}: Γ^{k}_{{{i}{j}}} analytic {} vs FD-metric {expected}",
gamma[k][[i, j]]
);
}
}
}
}
}
}