use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::geometry::manifold::{GeometryError, GeometryResult};
/// Relative safety margin kept between a clamped point and the ball boundary:
/// norms are capped at `(1 - BOUNDARY_EPS) / sqrt(-c)` and `atanh` arguments at
/// `1 - BOUNDARY_EPS`.
pub const BOUNDARY_EPS: f64 = 1.0e-5;
/// Threshold below which a norm is treated as zero, also used as a positive
/// floor for denominators that could otherwise reach zero.
pub const ORIGIN_EPS: f64 = 1.0e-15;
/// Validates the curvature `c` of a Poincaré ball and returns `sqrt(-c)`.
///
/// # Errors
/// Returns `GeometryError::InvalidPoint` when `curvature` is NaN, infinite,
/// zero or positive.
fn require_negative_curvature(curvature: f64) -> GeometryResult<f64> {
    // NaN fails the `< 0.0` comparison, so it is rejected along with c >= 0.
    let is_valid = curvature.is_finite() && curvature < 0.0;
    if is_valid {
        Ok((-curvature).sqrt())
    } else {
        Err(GeometryError::InvalidPoint(
            "Poincaré curvature must be a finite c < 0",
        ))
    }
}
/// Checks that two vectors have the same, non-zero number of components.
///
/// # Errors
/// * `GeometryError::DimensionMismatch` when the lengths differ (`expected` is
///   the length of `a`, `got` the length of `b`).
/// * `GeometryError::InvalidPoint` when both vectors are empty.
fn check_same_len(a: ArrayView1<'_, f64>, b: ArrayView1<'_, f64>) -> GeometryResult<()> {
    let (expected, got) = (a.len(), b.len());
    if expected != got {
        return Err(GeometryError::DimensionMismatch {
            context: "Poincaré vector",
            expected,
            got,
        });
    }
    // The lengths are equal at this point, so testing one of them is enough.
    if expected == 0 {
        return Err(GeometryError::InvalidPoint(
            "Poincaré vector must have at least one component",
        ));
    }
    Ok(())
}
/// Euclidean inner product of `a` and `b`, summed over the indices of `a`.
///
/// Indexes `b` with the indices of `a`, so `b` must be at least as long as `a`
/// or the indexing panics.
fn dot(a: ArrayView1<'_, f64>, b: ArrayView1<'_, f64>) -> f64 {
    (0..a.len()).fold(0.0, |acc, i| acc + a[i] * b[i])
}
/// Radially rescales `point` so that it lies strictly inside the Poincaré ball.
///
/// Points whose Euclidean norm exceeds `(1 - BOUNDARY_EPS) / sqrt(-curvature)`
/// are scaled down onto that radius; every other point is returned as a copy.
/// A non-finite norm is left untouched.
///
/// # Errors
/// Returns an error when `curvature` is not a finite negative number.
pub fn project_into_ball(
    point: ArrayView1<'_, f64>,
    curvature: f64,
) -> GeometryResult<Array1<f64>> {
    let sqrt_negc = require_negative_curvature(curvature)?;
    let radius_cap = (1.0 - BOUNDARY_EPS) / sqrt_negc;
    let length = point.iter().map(|x| x * x).sum::<f64>().sqrt();

    let needs_clamp = length.is_finite() && length > radius_cap && length > ORIGIN_EPS;
    if !needs_clamp {
        return Ok(point.to_owned());
    }

    let shrink = radius_cap / length;
    Ok(point.mapv(|x| x * shrink))
}
/// Möbius addition `u ⊕ v` in the Poincaré ball of curvature `c < 0`.
///
/// With `k = -c` this evaluates
/// `((1 + 2k⟨u,v⟩ + k‖v‖²) u + (1 − k‖u‖²) v) / (1 + 2k⟨u,v⟩ + k²‖u‖²‖v‖²)`,
/// with the denominator floored at `ORIGIN_EPS`.
///
/// # Errors
/// Returns an error when the vectors differ in length, are empty, or when
/// `curvature` is not a finite negative number.
pub fn mobius_add(
    u: ArrayView1<'_, f64>,
    v: ArrayView1<'_, f64>,
    curvature: f64,
) -> GeometryResult<Array1<f64>> {
    check_same_len(u, v)?;
    require_negative_curvature(curvature)?;

    let k = -curvature;
    let (u_dot_v, u_sq, v_sq) = (dot(u, v), dot(u, u), dot(v, v));

    // `1 + 2k⟨u,v⟩` appears in both the numerator weight of `u` and the denominator.
    let shared = 1.0 + 2.0 * k * u_dot_v;
    let weight_u = shared + k * v_sq;
    let weight_v = 1.0 - k * u_sq;
    let denom = (shared + k * k * u_sq * v_sq).max(ORIGIN_EPS);

    let sum = (0..u.len()).map(|i| (weight_u * u[i] + weight_v * v[i]) / denom);
    Ok(Array1::from_iter(sum))
}
/// Geodesic distance between two points of the Poincaré ball of curvature `c < 0`.
///
/// Evaluates
/// `acosh(1 + 2·(−c)·‖a − b‖² / ((1 + c‖a‖²)(1 + c‖b‖²))) / sqrt(−c)`.
/// Each conformal factor in the denominator is floored at `ORIGIN_EPS` so that
/// points on or beyond the boundary yield a large finite distance instead of a
/// division by zero.
///
/// # Errors
/// Returns an error when the vectors differ in length, are empty, or when
/// `curvature` is not a finite negative number.
pub fn poincare_distance(
    a: ArrayView1<'_, f64>,
    b: ArrayView1<'_, f64>,
    curvature: f64,
) -> GeometryResult<f64> {
    check_same_len(a, b)?;
    let sqrt_negc = require_negative_curvature(curvature)?;

    let mut diff_sq = 0.0;
    let mut a_sq = 0.0;
    let mut b_sq = 0.0;
    for (&ai, &bi) in a.iter().zip(b.iter()) {
        let d = ai - bi;
        diff_sq += d * d;
        a_sq += ai * ai;
        b_sq += bi * bi;
    }

    let denom_a = (1.0 + curvature * a_sq).max(ORIGIN_EPS);
    let denom_b = (1.0 + curvature * b_sq).max(ORIGIN_EPS);
    let arg = 1.0 + 2.0 * (-curvature) * diff_sq / (denom_a * denom_b);

    // `acosh` is defined on [1, ∞) with acosh(1) == 0. Clamping at exactly 1
    // still guards against a rounded argument slightly below 1 (which would be
    // NaN) while keeping d(a, a) == 0; clamping at `1 + ORIGIN_EPS` instead
    // would impose a spurious minimum distance of roughly 4.7e-8.
    Ok(arg.max(1.0).acosh() / sqrt_negc)
}
/// Logarithmic map at the origin of the Poincaré ball.
///
/// Returns `atanh(min(sqrt(−c)·‖y‖, 1 − BOUNDARY_EPS)) / (sqrt(−c)·‖y‖) · y`.
/// Points with `‖y‖ <= ORIGIN_EPS` are returned unchanged.
///
/// # Errors
/// Returns an error when `curvature` is not a finite negative number.
pub fn log_origin(y: ArrayView1<'_, f64>, curvature: f64) -> GeometryResult<Array1<f64>> {
    let sqrt_negc = require_negative_curvature(curvature)?;
    let radius = y.iter().map(|c| c * c).sum::<f64>().sqrt();
    if radius <= ORIGIN_EPS {
        return Ok(y.to_owned());
    }

    let scaled = sqrt_negc * radius;
    // Only the `atanh` argument is clamped away from 1; the divisor keeps the raw value.
    let gain = scaled.min(1.0 - BOUNDARY_EPS).atanh() / scaled;
    Ok(y.mapv(|c| c * gain))
}
/// Exponential map at the origin of the Poincaré ball.
///
/// Returns `tanh(sqrt(−c)·‖v‖) / (sqrt(−c)·‖v‖) · v`. Tangent vectors with
/// `‖v‖ <= ORIGIN_EPS` are returned unchanged.
///
/// # Errors
/// Returns an error when `curvature` is not a finite negative number.
pub fn exp_origin(v: ArrayView1<'_, f64>, curvature: f64) -> GeometryResult<Array1<f64>> {
    let sqrt_negc = require_negative_curvature(curvature)?;
    let length = v.iter().map(|c| c * c).sum::<f64>().sqrt();
    if length <= ORIGIN_EPS {
        return Ok(v.to_owned());
    }

    let s = sqrt_negc * length;
    let gain = s.tanh() / s;
    Ok(v.mapv(|c| c * gain))
}
/// Intermediate values saved by `tangent_decode_forward` and consumed by
/// `tangent_decode_backward`.
#[derive(Debug, Clone)]
pub struct TangentDecodeCache {
/// Origin log-map of each projected atom, one row per atom (`F x ball_dim`).
pub tangents: Array2<f64>,
/// Gate-weighted tangent sums, `gates.dot(&tangents)` (`batch x ball_dim`).
pub v: Array2<f64>,
/// Atoms after clamping into the ball (`F x ball_dim`).
pub atoms_projected: Array2<f64>,
/// Owned copy of the gates passed to the forward pass.
pub gates: Array2<f64>,
/// Curvature the forward pass was evaluated with.
pub curvature: f64,
}
/// Validates that `atoms` is a non-empty `F x ball_dim` matrix and that `gates`
/// has one column per atom.
///
/// # Errors
/// * `GeometryError::InvalidPoint` when `atoms` has zero rows or zero columns.
/// * `GeometryError::DimensionMismatch` when the column count of `gates`
///   differs from the row count of `atoms`.
fn check_atoms_shape(atoms: ArrayView2<'_, f64>, gates: ArrayView2<'_, f64>) -> GeometryResult<()> {
    let (n_atoms, ball_dim) = atoms.dim();
    if n_atoms == 0 || ball_dim == 0 {
        return Err(GeometryError::InvalidPoint(
            "Poincaré atoms must have F>0 and ball_dim>0",
        ));
    }

    let (_, gate_cols) = gates.dim();
    if gate_cols != n_atoms {
        return Err(GeometryError::DimensionMismatch {
            context: "Poincaré decoder atom count",
            expected: n_atoms,
            got: gate_cols,
        });
    }
    Ok(())
}
/// Clamps every atom row into the ball and takes its origin log-map.
///
/// Returns `(projected, tangents)`, both shaped like `atoms`. A row whose
/// projected norm is `<= ORIGIN_EPS` keeps an all-zero tangent row.
///
/// # Errors
/// Returns an error when `curvature` is not a finite negative number.
fn project_and_log(
    atoms: ArrayView2<'_, f64>,
    curvature: f64,
) -> GeometryResult<(Array2<f64>, Array2<f64>)> {
    let sqrt_negc = require_negative_curvature(curvature)?;
    let radius_cap = (1.0 - BOUNDARY_EPS) / sqrt_negc;
    let shape = atoms.dim();
    let mut projected = Array2::<f64>::zeros(shape);
    let mut tangents = Array2::<f64>::zeros(shape);

    for f in 0..shape.0 {
        let atom = atoms.row(f);

        // Step 1: radial clamp onto the admissible radius when the atom is too far out.
        let raw_norm = atom.iter().map(|x| x * x).sum::<f64>().sqrt();
        let must_shrink = raw_norm.is_finite() && raw_norm > radius_cap && raw_norm > ORIGIN_EPS;
        let shrink = if must_shrink { radius_cap / raw_norm } else { 1.0 };
        let mut proj_row = projected.row_mut(f);
        for (dst, &src) in proj_row.iter_mut().zip(atom.iter()) {
            *dst = src * shrink;
        }

        // Step 2: origin log-map of the clamped atom.
        let proj_norm = proj_row.iter().map(|p| p * p).sum::<f64>().sqrt();
        if proj_norm <= ORIGIN_EPS {
            continue;
        }
        let scaled = sqrt_negc * proj_norm;
        let gain = scaled.min(1.0 - BOUNDARY_EPS).atanh() / scaled;
        for (dst, &p) in tangents.row_mut(f).iter_mut().zip(proj_row.iter()) {
            *dst = gain * p;
        }
    }
    Ok((projected, tangents))
}
/// Decodes gate activations into points of the Poincaré ball via the tangent
/// space at the origin.
///
/// Each atom is clamped into the ball and log-mapped, the tangents are mixed
/// linearly with `gates`, and every mixed row is mapped back with the origin
/// exponential map. Rows whose mixed tangent has norm `<= ORIGIN_EPS` decode to
/// an all-zero row.
///
/// Returns the decoded batch together with the cache required by
/// `tangent_decode_backward`.
///
/// # Errors
/// Returns an error on inconsistent `atoms`/`gates` shapes or when `curvature`
/// is not a finite negative number.
pub fn tangent_decode_forward(
    atoms: ArrayView2<'_, f64>,
    gates: ArrayView2<'_, f64>,
    curvature: f64,
) -> GeometryResult<(Array2<f64>, TangentDecodeCache)> {
    check_atoms_shape(atoms, gates)?;
    let sqrt_negc = require_negative_curvature(curvature)?;
    let (atoms_projected, tangents) = project_and_log(atoms, curvature)?;

    let v = gates.dot(&tangents);
    let mut x_hat = Array2::<f64>::zeros(v.dim());
    for (v_row, mut out_row) in v.outer_iter().zip(x_hat.outer_iter_mut()) {
        let length = v_row.iter().map(|c| c * c).sum::<f64>().sqrt();
        if length <= ORIGIN_EPS {
            continue;
        }
        let s = sqrt_negc * length;
        let gain = s.tanh() / s;
        for (dst, &src) in out_row.iter_mut().zip(v_row.iter()) {
            *dst = gain * src;
        }
    }

    Ok((
        x_hat,
        TangentDecodeCache {
            tangents,
            v,
            atoms_projected,
            gates: gates.to_owned(),
            curvature,
        },
    ))
}
/// Backward pass of the tangent-space decoder.
///
/// Given the cache of a forward pass and the gradient of a scalar loss with
/// respect to the decoded points (`batch x ball_dim`), returns
/// `(grad_gates, grad_atoms)`.
///
/// # Errors
/// Returns an error when the cached curvature is not a finite negative number
/// or when `grad_x_hat` does not have the shape of the cached `v`.
pub fn tangent_decode_backward(
cache: &TangentDecodeCache,
grad_x_hat: ArrayView2<'_, f64>,
) -> GeometryResult<(Array2<f64>, Array2<f64>)> {
let sqrt_negc = require_negative_curvature(cache.curvature)?;
let (batch, d) = cache.v.dim();
let n_atoms = cache.tangents.dim().0;
if grad_x_hat.dim() != (batch, d) {
// The error carries element counts (rows * cols), not the individual shapes.
return Err(GeometryError::DimensionMismatch {
context: "Poincaré tangent_decode_backward grad",
expected: batch * d,
got: grad_x_hat.dim().0 * grad_x_hat.dim().1,
});
}
// Stage 1: back-propagate through x_hat = phi(s) * v with s = sqrt(-c) * |v|
// and phi(s) = tanh(s) / s.
let mut grad_v = Array2::<f64>::zeros((batch, d));
for b in 0..batch {
let v_row = cache.v.row(b);
let g_row = grad_x_hat.row(b);
let nrm_sq: f64 = (0..d).map(|i| v_row[i] * v_row[i]).sum();
let nrm = nrm_sq.sqrt();
if nrm <= ORIGIN_EPS {
// Near-zero tangent: the incoming gradient is passed through unchanged.
for i in 0..d {
grad_v[[b, i]] = g_row[i];
}
continue;
}
let s = sqrt_negc * nrm;
let tanh_s = s.tanh();
let phi = tanh_s / s;
// phi'(s) = (s * sech^2(s) - tanh(s)) / s^2, using sech^2 = 1 - tanh^2.
let phi_prime = (s * (1.0 - tanh_s * tanh_s) - tanh_s) / (s * s);
// Chain rule through s: ds/dv_j = sqrt(-c) * v_j / |v|.
let dphi_dv_coeff = phi_prime * sqrt_negc / nrm;
let g_dot_v: f64 = (0..d).map(|i| g_row[i] * v_row[i]).sum();
for j in 0..d {
grad_v[[b, j]] = phi * g_row[j] + g_dot_v * dphi_dv_coeff * v_row[j];
}
}
// Stage 2: back-propagate through the linear mix v = gates . tangents.
let grad_gates = grad_v.dot(&cache.tangents.t());
let grad_tangents = cache.gates.t().dot(&grad_v);
// Stage 3: back-propagate through tangent = psi(t) * a with t = sqrt(-c) * |a|
// (clamped below 1) and psi(t) = atanh(t) / t.
let mut grad_atoms_proj = Array2::<f64>::zeros((n_atoms, d));
for f in 0..n_atoms {
let a_row = cache.atoms_projected.row(f);
let g_l_row = grad_tangents.row(f);
let r_sq: f64 = (0..d).map(|i| a_row[i] * a_row[i]).sum();
let r = r_sq.sqrt();
if r <= ORIGIN_EPS {
// Near-zero atom: the incoming gradient is passed through unchanged.
for i in 0..d {
grad_atoms_proj[[f, i]] = g_l_row[i];
}
continue;
}
let t = (sqrt_negc * r).min(1.0 - BOUNDARY_EPS);
let psi = t.atanh() / t;
// psi'(t) = (t / (1 - t^2) - atanh(t)) / t^2.
let psi_prime = (t / (1.0 - t * t) - t.atanh()) / (t * t);
let dpsi_da_coeff = psi_prime * sqrt_negc / r;
let g_l_dot_a: f64 = (0..d).map(|i| g_l_row[i] * a_row[i]).sum();
for j in 0..d {
grad_atoms_proj[[f, j]] = psi * g_l_row[j] + g_l_dot_a * dpsi_da_coeff * a_row[j];
}
}
// The gradient with respect to the projected atoms is returned as the atom
// gradient; no Jacobian of the clamp-into-ball rescaling is applied here.
// NOTE(review): exact only for atoms the forward pass did not rescale; confirm
// the pass-through is intended for clamped atoms.
let grad_atoms = grad_atoms_proj;
Ok((grad_gates, grad_atoms))
}
/// Maps a Poincaré-ball point to Lorentz coordinates.
///
/// With `ŷ = sqrt(−c)·y` the result is `(1 + ‖ŷ‖², 2ŷ) / (1 − ‖ŷ‖²)`, divided
/// by `sqrt(−c)`; the time-like coordinate is stored first. The denominator is
/// floored at `ORIGIN_EPS`.
///
/// # Errors
/// Returns an error when `curvature` is not a finite negative number or when
/// `y` is empty.
pub fn to_lorentz(y: ArrayView1<'_, f64>, curvature: f64) -> GeometryResult<Array1<f64>> {
    let sqrt_negc = require_negative_curvature(curvature)?;
    if y.is_empty() {
        return Err(GeometryError::InvalidPoint("to_lorentz requires d >= 1"));
    }

    let yhat_sq: f64 = y.iter().map(|c| (sqrt_negc * c).powi(2)).sum();
    let denom = (1.0 - yhat_sq).max(ORIGIN_EPS);

    let time = (1.0 + yhat_sq) / denom / sqrt_negc;
    let space = y.iter().map(|&c| (2.0 * sqrt_negc * c / denom) / sqrt_negc);
    Ok(Array1::from_iter(std::iter::once(time).chain(space)))
}
/// Maps Lorentz coordinates (time-like coordinate first) back to the Poincaré ball.
///
/// Each spatial coordinate is scaled by `sqrt(−c)`, divided by
/// `sqrt(−c)·x[0] + 1` (floored at `ORIGIN_EPS`) and scaled back by
/// `1 / sqrt(−c)`.
///
/// # Errors
/// Returns an error when `curvature` is not a finite negative number or when
/// `x` has fewer than two components.
pub fn from_lorentz(x: ArrayView1<'_, f64>, curvature: f64) -> GeometryResult<Array1<f64>> {
    let sqrt_negc = require_negative_curvature(curvature)?;
    if x.len() < 2 {
        return Err(GeometryError::InvalidPoint(
            "from_lorentz requires d+1 >= 2",
        ));
    }

    let denom = (x[0] * sqrt_negc + 1.0).max(ORIGIN_EPS);
    let ball = x
        .iter()
        .skip(1)
        .map(|&xi| (xi * sqrt_negc / denom) / sqrt_negc);
    Ok(Array1::from_iter(ball))
}
/// Logarithmic map at the origin for Lorentz coordinates (time-like coordinate first).
///
/// The geodesic distance to the origin is `acosh(sqrt(−c)·x[0]) / sqrt(−c)`;
/// the returned vector has the direction of the spatial part of `x` and that
/// distance as its length. The spatial norm is floored at `ORIGIN_EPS`, so the
/// origin itself maps to the zero vector.
///
/// # Errors
/// Returns an error when `curvature` is not a finite negative number or when
/// `x` has fewer than two components.
pub fn lorentz_log_origin(x: ArrayView1<'_, f64>, curvature: f64) -> GeometryResult<Array1<f64>> {
    let sqrt_negc = require_negative_curvature(curvature)?;
    if x.len() < 2 {
        return Err(GeometryError::InvalidPoint(
            "lorentz_log_origin requires d+1 >= 2",
        ));
    }

    // `acosh` is defined on [1, ∞) with acosh(1) == 0. Clamping at exactly 1
    // protects against a rounded argument slightly below 1 (NaN) without
    // forcing a minimum distance; a clamp at `1 + ORIGIN_EPS` would give every
    // near-origin point a tangent of norm ~4.7e-8 regardless of how close to
    // the origin it really is.
    let arg = (sqrt_negc * x[0]).max(1.0);
    let dist = arg.acosh() / sqrt_negc;

    let xs_norm_sq = x.iter().skip(1).fold(0.0, |acc, &xi| acc + xi * xi);
    let xs_norm = xs_norm_sq.sqrt().max(ORIGIN_EPS);

    let tangent = x.iter().skip(1).map(|&xi| dist * xi / xs_norm);
    Ok(Array1::from_iter(tangent))
}
/// Exponential map at the origin, producing Lorentz coordinates.
///
/// With `s = sqrt(−c)·max(‖v‖, ORIGIN_EPS)` the result is
/// `(cosh(s) / sqrt(−c), sinh(s) / s · v)`, time-like coordinate first.
///
/// # Errors
/// Returns an error when `curvature` is not a finite negative number.
pub fn lorentz_exp_origin(
    v_spatial: ArrayView1<'_, f64>,
    curvature: f64,
) -> GeometryResult<Array1<f64>> {
    let sqrt_negc = require_negative_curvature(curvature)?;

    let norm_sq: f64 = v_spatial.iter().map(|c| c * c).sum();
    // Flooring the norm keeps `sinh(s) / s` well defined for the zero vector.
    let s = sqrt_negc * norm_sq.sqrt().max(ORIGIN_EPS);

    let time = s.cosh() / sqrt_negc;
    let gain = s.sinh() / s;
    let space = v_spatial.iter().map(|&c| gain * c);
    Ok(Array1::from_iter(std::iter::once(time).chain(space)))
}
/// Tangent-space decoder evaluated through Lorentz coordinates.
///
/// Every atom is clamped into the ball, converted with `to_lorentz` and
/// log-mapped at the origin; the tangents are mixed with `gates`; each mixed
/// row is then mapped with `lorentz_exp_origin` and converted back with
/// `from_lorentz`. The result has one Poincaré-ball row per gate row.
///
/// # Errors
/// Returns an error on inconsistent `atoms`/`gates` shapes or when `curvature`
/// is not a finite negative number.
pub fn lorentz_decode_forward(
    atoms: ArrayView2<'_, f64>,
    gates: ArrayView2<'_, f64>,
    curvature: f64,
) -> GeometryResult<Array2<f64>> {
    check_atoms_shape(atoms, gates)?;
    let sqrt_negc = require_negative_curvature(curvature)?;
    let (n_atoms, ball_dim) = atoms.dim();
    let n_rows = gates.dim().0;
    let radius_cap = (1.0 - BOUNDARY_EPS) / sqrt_negc;

    // Atoms -> tangent vectors at the origin, via the hyperboloid.
    let mut tangents = Array2::<f64>::zeros((n_atoms, ball_dim));
    for f in 0..n_atoms {
        let atom = atoms.row(f);
        let raw_norm = atom.iter().map(|x| x * x).sum::<f64>().sqrt();
        let must_shrink = raw_norm.is_finite() && raw_norm > radius_cap && raw_norm > ORIGIN_EPS;
        let shrink = if must_shrink { radius_cap / raw_norm } else { 1.0 };

        let clamped = atom.mapv(|x| x * shrink);
        let on_hyperboloid = to_lorentz(clamped.view(), curvature)?;
        let tangent = lorentz_log_origin(on_hyperboloid.view(), curvature)?;
        tangents.row_mut(f).assign(&tangent);
    }

    // Mixed tangents -> hyperboloid -> ball.
    let mixed = gates.dot(&tangents);
    let mut decoded = Array2::<f64>::zeros((n_rows, ball_dim));
    for b in 0..n_rows {
        let on_hyperboloid = lorentz_exp_origin(mixed.row(b), curvature)?;
        let in_ball = from_lorentz(on_hyperboloid.view(), curvature)?;
        decoded.row_mut(b).assign(&in_ball);
    }
    Ok(decoded)
}
/// Backward pass for the Lorentz decoder.
///
/// Delegates unchanged to `tangent_decode_backward`. `lorentz_decode_forward`
/// returns no cache, so `cache` has to be produced by `tangent_decode_forward`
/// on the same inputs.
// NOTE(review): this presumes both forward paths compute the same function; the
// unit tests check that agreement numerically on small inputs.
pub fn lorentz_decode_backward(
cache: &TangentDecodeCache,
grad_x_hat: ArrayView2<'_, f64>,
) -> GeometryResult<(Array2<f64>, Array2<f64>)> {
tangent_decode_backward(cache, grad_x_hat)
}
// Unit tests for the Poincaré / Lorentz helpers; all of them use curvature -1
// except the curvature-rejection test.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
// Absolute tolerance for comparisons expected to hold up to rounding.
const TOL: f64 = 1.0e-10;
// The distance from a point to itself must be (numerically) zero.
#[test]
fn distance_self_is_zero() {
let a = array![0.1, -0.2, 0.05];
let d = poincare_distance(a.view(), a.view(), -1.0).expect("distance");
assert!(d.abs() < 1.0e-8, "got {d}");
}
// The zero vector is a two-sided identity for Möbius addition.
#[test]
fn mobius_add_zero_is_identity_on_either_side() {
let v = array![0.2, -0.1, 0.05];
let zero = Array1::<f64>::zeros(3);
let left = mobius_add(zero.view(), v.view(), -1.0).expect("0+v");
let right = mobius_add(v.view(), zero.view(), -1.0).expect("v+0");
for i in 0..3 {
assert!((left[i] - v[i]).abs() < TOL, "left mismatch at {i}");
assert!((right[i] - v[i]).abs() < TOL, "right mismatch at {i}");
}
}
// Compares against acosh(1 + 2|a-b|^2 / ((1-|a|^2)(1-|b|^2))) computed inline.
#[test]
fn distance_matches_textbook_formula_unit_curvature() {
let a = array![0.3_f64, 0.1];
let b = array![-0.2_f64, 0.4];
let diff_sq: f64 = (0..2).map(|i| (a[i] - b[i]).powi(2)).sum();
let a_sq: f64 = a.iter().map(|v| v * v).sum();
let b_sq: f64 = b.iter().map(|v| v * v).sum();
let expected = (1.0 + 2.0 * diff_sq / ((1.0 - a_sq) * (1.0 - b_sq))).acosh();
let got = poincare_distance(a.view(), b.view(), -1.0).expect("distance");
assert!(
(got - expected).abs() < 1.0e-12,
"got {got}, expected {expected}"
);
}
// exp_origin(log_origin(y)) must reproduce y for a point well inside the ball.
#[test]
fn log_exp_origin_round_trips() {
let y = array![0.2, -0.15, 0.05, 0.1];
let v = log_origin(y.view(), -1.0).expect("log");
let back = exp_origin(v.view(), -1.0).expect("exp");
for i in 0..4 {
assert!(
(back[i] - y[i]).abs() < 1.0e-12,
"round trip mismatch at {i}: {} vs {}",
back[i],
y[i]
);
}
}
// A point closer to the boundary than BOUNDARY_EPS allows is pulled back inside.
#[test]
fn project_into_ball_clamps_near_boundary() {
let raw = array![0.999, 0.0];
let proj = project_into_ball(raw.view(), -1.0).expect("project");
let norm = (proj[0] * proj[0] + proj[1] * proj[1]).sqrt();
assert!(norm < 1.0, "norm {} should be inside ball", norm);
assert!(norm <= 1.0 - BOUNDARY_EPS + 1e-12);
}
// For tiny atoms the decoder should agree with the plain linear mix gates . atoms.
#[test]
fn tangent_decode_collapses_to_linear_in_small_input_limit() {
let atoms = array![[0.001, 0.0, 0.0], [0.0, -0.001, 0.0]];
let gates = array![[0.5, -0.3]];
let (x_hat, cache) =
tangent_decode_forward(atoms.view(), gates.view(), -1.0).expect("forward");
assert_eq!(cache.tangents.dim(), (2, 3));
let linear = gates.dot(&atoms);
for i in 0..3 {
assert!(
(x_hat[[0, i]] - linear[[0, i]]).abs() < 1.0e-6,
"x_hat[{i}] = {} vs linear {}",
x_hat[[0, i]],
linear[[0, i]]
);
}
}
// The Poincaré and Lorentz forward passes must agree on the same inputs.
#[test]
fn poincare_and_lorentz_paths_agree_on_small_inputs() {
let atoms = array![
[0.05, 0.02, -0.01],
[-0.04, 0.03, 0.02],
[0.01, -0.02, 0.04],
];
let gates = array![[0.3, -0.2, 0.1], [-0.1, 0.4, 0.05]];
let (x_p, cache) =
tangent_decode_forward(atoms.view(), gates.view(), -1.0).expect("poincare forward");
assert_eq!(cache.tangents.dim(), (3, 3));
let x_l =
lorentz_decode_forward(atoms.view(), gates.view(), -1.0).expect("lorentz forward");
for b in 0..2 {
for i in 0..3 {
let diff = (x_p[[b, i]] - x_l[[b, i]]).abs();
assert!(
diff < 1.0e-5,
"p vs l mismatch at ({b},{i}): {} vs {}",
x_p[[b, i]],
x_l[[b, i]]
);
}
}
}
// Checks analytic gradients of the loss sum(x_hat^2) against central finite
// differences for one gate entry and one atom entry.
#[test]
fn tangent_backward_matches_finite_difference() {
let atoms = array![[0.05, 0.02], [-0.03, 0.04]];
let gates = array![[0.3, -0.2]];
let (x_hat, cache) =
tangent_decode_forward(atoms.view(), gates.view(), -1.0).expect("forward");
// d(sum x^2)/dx = 2x is the upstream gradient fed to the backward pass.
let mut grad_x = Array2::<f64>::zeros(x_hat.dim());
for i in 0..x_hat.dim().0 {
for j in 0..x_hat.dim().1 {
grad_x[[i, j]] = 2.0 * x_hat[[i, j]];
}
}
let (grad_gates, grad_atoms) =
tangent_decode_backward(&cache, grad_x.view()).expect("backward");
let eps = 1.0e-6;
let mut gates_p = gates.clone();
gates_p[[0, 0]] += eps;
let (x_p, _) = tangent_decode_forward(atoms.view(), gates_p.view(), -1.0).unwrap();
let mut gates_m = gates.clone();
gates_m[[0, 0]] -= eps;
let (x_m, _) = tangent_decode_forward(atoms.view(), gates_m.view(), -1.0).unwrap();
let loss_p: f64 = x_p.iter().map(|v| v * v).sum();
let loss_m: f64 = x_m.iter().map(|v| v * v).sum();
let fd_gate = (loss_p - loss_m) / (2.0 * eps);
assert!(
(fd_gate - grad_gates[[0, 0]]).abs() < 1.0e-5,
"gate grad: analytic {} vs FD {}",
grad_gates[[0, 0]],
fd_gate
);
let mut atoms_p = atoms.clone();
atoms_p[[1, 0]] += eps;
let (x_p2, _) = tangent_decode_forward(atoms_p.view(), gates.view(), -1.0).unwrap();
let mut atoms_m = atoms.clone();
atoms_m[[1, 0]] -= eps;
let (x_m2, _) = tangent_decode_forward(atoms_m.view(), gates.view(), -1.0).unwrap();
let lp: f64 = x_p2.iter().map(|v| v * v).sum();
let lm: f64 = x_m2.iter().map(|v| v * v).sum();
let fd_atom = (lp - lm) / (2.0 * eps);
assert!(
(fd_atom - grad_atoms[[1, 0]]).abs() < 1.0e-5,
"atom grad: analytic {} vs FD {}",
grad_atoms[[1, 0]],
fd_atom
);
}
// Same finite-difference check, but differentiating the Lorentz forward pass
// while taking the cache from the Poincaré forward pass.
#[test]
fn lorentz_backward_matches_finite_difference_of_lorentz_forward() {
let atoms = array![[0.05, 0.02], [-0.03, 0.04]];
let gates = array![[0.3, -0.2]];
let (x_hat_p, cache) = tangent_decode_forward(atoms.view(), gates.view(), -1.0)
.expect("poincare forward (for cache)");
let x_hat_l =
lorentz_decode_forward(atoms.view(), gates.view(), -1.0).expect("lorentz forward");
// Sharing the cache is only valid if both forward passes agree.
for b in 0..x_hat_l.dim().0 {
for i in 0..x_hat_l.dim().1 {
assert!(
(x_hat_p[[b, i]] - x_hat_l[[b, i]]).abs() < 1.0e-10,
"lorentz/poincare forward mismatch at ({b},{i})"
);
}
}
let mut grad_x = Array2::<f64>::zeros(x_hat_l.dim());
for i in 0..x_hat_l.dim().0 {
for j in 0..x_hat_l.dim().1 {
grad_x[[i, j]] = 2.0 * x_hat_l[[i, j]];
}
}
let (grad_gates, grad_atoms) =
lorentz_decode_backward(&cache, grad_x.view()).expect("lorentz backward");
let eps = 1.0e-6;
let mut gates_p = gates.clone();
gates_p[[0, 0]] += eps;
let x_p = lorentz_decode_forward(atoms.view(), gates_p.view(), -1.0).unwrap();
let mut gates_m = gates.clone();
gates_m[[0, 0]] -= eps;
let x_m = lorentz_decode_forward(atoms.view(), gates_m.view(), -1.0).unwrap();
let lp: f64 = x_p.iter().map(|v| v * v).sum();
let lm: f64 = x_m.iter().map(|v| v * v).sum();
let fd_gate = (lp - lm) / (2.0 * eps);
assert!(
(fd_gate - grad_gates[[0, 0]]).abs() < 1.0e-5,
"lorentz gate grad: analytic {} vs FD {}",
grad_gates[[0, 0]],
fd_gate
);
let mut atoms_p = atoms.clone();
atoms_p[[1, 0]] += eps;
let x_p2 = lorentz_decode_forward(atoms_p.view(), gates.view(), -1.0).unwrap();
let mut atoms_m = atoms.clone();
atoms_m[[1, 0]] -= eps;
let x_m2 = lorentz_decode_forward(atoms_m.view(), gates.view(), -1.0).unwrap();
let lp2: f64 = x_p2.iter().map(|v| v * v).sum();
let lm2: f64 = x_m2.iter().map(|v| v * v).sum();
let fd_atom = (lp2 - lm2) / (2.0 * eps);
assert!(
(fd_atom - grad_atoms[[1, 0]]).abs() < 1.0e-5,
"lorentz atom grad: analytic {} vs FD {}",
grad_atoms[[1, 0]],
fd_atom
);
}
// Zero and positive curvature must be rejected with an error.
#[test]
fn rejects_nonnegative_curvature() {
let v = array![0.1, 0.2];
assert!(log_origin(v.view(), 0.0).is_err());
assert!(log_origin(v.view(), 0.5).is_err());
assert!(mobius_add(v.view(), v.view(), 0.0).is_err());
}
}