use ndarray::{Array1, ArrayView1};
use num_traits::{Float, FromPrimitive, Zero};
/// The Lorentz (hyperboloid) model of hyperbolic space with constant curvature `-c`.
///
/// Points are coordinate vectors `x = (x₀, x₁, …, x_n)` on the upper sheet of the hyperboloid
/// `⟨x, x⟩_L = -1/c`, `x₀ > 0`, where `⟨·,·⟩_L` is the Minkowski inner product computed by
/// [`LorentzModel::minkowski_dot`].
#[derive(Debug, Clone, Copy)]
pub struct LorentzModel<T> {
    /// Curvature magnitude (the sectional curvature is `-c`). [`LorentzModel::new`] rejects
    /// non-positive values; the field is public, so direct construction bypasses that check.
    pub c: T,
}
impl<T> LorentzModel<T>
where
T: Float + FromPrimitive + Zero + ndarray::ScalarOperand + ndarray::LinalgScalar,
{
pub fn new(c: T) -> Self {
assert!(c > T::zero(), "curvature must be positive");
Self { c }
}
pub fn minkowski_dot(&self, x: &ArrayView1<T>, y: &ArrayView1<T>) -> T {
assert!(x.len() == y.len() && x.len() >= 2);
-x[0] * y[0] + x.slice(ndarray::s![1..]).dot(&y.slice(ndarray::s![1..]))
}
pub fn distance(&self, x: &ArrayView1<T>, y: &ArrayView1<T>) -> T {
let inner = self.minkowski_dot(x, y);
let arg = -self.c * inner;
let one = T::one();
let two = T::from_f64(2.0).unwrap();
let epsilon = T::from_f64(1e-7).unwrap();
if arg < one + epsilon {
if arg <= one {
return T::zero();
}
let x = arg - one;
return (two * x).sqrt() / self.c.sqrt();
}
arg.acosh() / self.c.sqrt()
}
pub fn is_on_manifold(&self, x: &ArrayView1<T>, tol: T) -> bool {
let inner = self.minkowski_dot(x, x);
(inner + T::one() / self.c).abs() < tol
}
pub fn project(&self, x: &ArrayView1<T>) -> Array1<T> {
let space_norm_sq = x.slice(ndarray::s![1..]).dot(&x.slice(ndarray::s![1..]));
let t = (space_norm_sq + T::one() / self.c).sqrt();
let mut result = x.to_owned();
result[0] = t;
result
}
pub fn from_euclidean(&self, v: &ArrayView1<T>) -> Array1<T> {
let space_norm_sq = v.dot(v);
let t = (space_norm_sq + T::one() / self.c).sqrt();
let mut result = Array1::zeros(v.len() + 1);
result[0] = t;
for (i, &val) in v.iter().enumerate() {
result[i + 1] = val;
}
result
}
pub fn to_euclidean(&self, x: &ArrayView1<T>) -> Array1<T> {
x.slice(ndarray::s![1..]).to_owned()
}
pub fn exp_map(&self, x: &ArrayView1<T>, v: &ArrayView1<T>) -> Array1<T> {
let v_norm_sq = self.minkowski_dot(v, v);
let epsilon = T::from_f64(1e-15).unwrap();
if v_norm_sq < epsilon {
return x.to_owned();
}
let v_norm = v_norm_sq.sqrt();
let c_sqrt = self.c.sqrt();
let cosh_term = (c_sqrt * v_norm).cosh();
let sinh_term = (c_sqrt * v_norm).sinh() / (c_sqrt * v_norm);
let term1 = x.mapv(|val| val * cosh_term);
let term2 = v.mapv(|val| val * sinh_term);
term1 + term2
}
pub fn log_map(&self, x: &ArrayView1<T>, y: &ArrayView1<T>) -> Array1<T> {
let inner = self.minkowski_dot(x, y);
let d = self.distance(x, y);
let epsilon = T::from_f64(1e-15).unwrap();
if d < epsilon {
return Array1::zeros(x.len());
}
let term_x = x.mapv(|val| val * self.c * inner);
let v = y.to_owned() + term_x;
let v_norm_sq = self.minkowski_dot(&v.view(), &v.view());
if v_norm_sq < epsilon {
return Array1::zeros(x.len());
}
let v_norm = v_norm_sq.sqrt();
v * (d / v_norm)
}
pub fn parallel_transport(
&self,
x: &ArrayView1<T>,
y: &ArrayView1<T>,
v: &ArrayView1<T>,
) -> Array1<T> {
let inner_xy = self.minkowski_dot(x, y);
let inner_vy = self.minkowski_dot(v, y);
let one = T::one();
let denom = one - self.c * inner_xy;
let epsilon = T::from_f64(1e-15).unwrap();
if denom.abs() < epsilon {
return v.to_owned();
}
let coeff = self.c * inner_vy / denom;
let sum_xy = x.to_owned() + y;
v.to_owned() - sum_xy.mapv(|val| val * coeff)
}
pub fn origin(&self, dim: usize) -> Array1<T> {
let mut o = Array1::zeros(dim + 1);
o[0] = T::one() / self.c.sqrt();
o
}
}
/// Conversions between the Poincaré ball and the Lorentz (hyperboloid) model with the same
/// curvature `-c`.
pub mod conversions {
    use super::*;
    use crate::PoincareBall;

    /// Maps a point `p` of the Poincaré ball (radius `1/√c`) onto the hyperboloid:
    ///
    /// `x₀ = (1 + c‖p‖²) / (√c (1 - c‖p‖²))`, `x[1..] = 2p / (1 - c‖p‖²)`.
    ///
    /// The result satisfies `⟨x, x⟩_L = -1/c`, and [`lorentz_to_poincare`] inverts this map.
    /// `p` is expected to lie strictly inside the ball (`c‖p‖² < 1`). On or outside the
    /// boundary the denominator is zero or negative, and the result is not a valid point.
    pub fn poincare_to_lorentz<T>(ball: &PoincareBall<T>, x: &ArrayView1<T>) -> Array1<T>
    where
        T: Float + FromPrimitive + Zero + ndarray::ScalarOperand,
    {
        let c = ball.c;
        let one = T::one();
        let two = T::from_f64(2.0).unwrap();
        let c_norm_sq = c * x.dot(x);
        let denom = one - c_norm_sq;
        let mut result = Array1::<T>::zeros(x.len() + 1);
        result[0] = (one + c_norm_sq) / (denom * c.sqrt());
        // Only the time coordinate carries the 1/√c factor. Applying it to the spatial part
        // too puts the point off the hyperboloid whenever c != 1, and the round trip through
        // `lorentz_to_poincare` then returns p/√c.
        let scale = two / denom;
        for (dst, &src) in result.iter_mut().skip(1).zip(x.iter()) {
            *dst = src * scale;
        }
        result
    }

    /// Maps a hyperboloid point `x` to the Poincaré ball of the same curvature via
    /// `p = x[1..] / (1 + √c·x₀)`.
    ///
    /// # Panics
    ///
    /// Panics if `x` is empty.
    pub fn lorentz_to_poincare<T>(lorentz: &LorentzModel<T>, x: &ArrayView1<T>) -> Array1<T>
    where
        T: Float + FromPrimitive + Zero + ndarray::ScalarOperand,
    {
        let denom = x[0] * lorentz.c.sqrt() + T::one();
        x.slice(ndarray::s![1..]).mapv(|val| val / denom)
    }
}
#[cfg(test)]
mod tests {
    //! Property checks for the Lorentz model: manifold membership, metric axioms,
    //! exp/log consistency, parallel transport and the Poincaré round trip.
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Absolute tolerance for the tighter checks below.
    const TOL: f64 = 1e-10;

    /// Builds a vector tangent to `base` whose spatial components are `spatial`.
    ///
    /// The time component solves `⟨v, base⟩_L = 0`, i.e. `v₀ = Σ sᵢ·base_{i+1} / base₀`.
    /// The constraint is re-checked before returning.
    fn tangent_at(lorentz: &LorentzModel<f64>, base: &Array1<f64>, spatial: &[f64]) -> Array1<f64> {
        let weighted: f64 = spatial
            .iter()
            .zip(base.iter().skip(1))
            .map(|(s, b)| s * b)
            .sum();
        let mut tangent = Array1::<f64>::zeros(base.len());
        tangent[0] = weighted / base[0];
        spatial
            .iter()
            .enumerate()
            .for_each(|(i, &s)| tangent[i + 1] = s);
        let check = lorentz.minkowski_dot(&tangent.view(), &base.view());
        assert!(check.abs() < 1e-10, "tangent constraint violated: {check}");
        tangent
    }

    #[test]
    fn test_origin_on_manifold() {
        let model = LorentzModel::new(1.0);
        let origin = model.origin(3);
        assert!(model.is_on_manifold(&origin.view(), TOL));
    }

    #[test]
    fn test_from_euclidean_on_manifold() {
        let model = LorentzModel::new(1.0);
        let lifted = model.from_euclidean(&array![0.5, -0.3, 0.2].view());
        assert!(model.is_on_manifold(&lifted.view(), TOL));
    }

    #[test]
    fn test_distance_self_zero() {
        let model = LorentzModel::new(1.0);
        let point = model.from_euclidean(&array![0.5, 0.3].view());
        let view = point.view();
        let self_distance = model.distance(&view, &view);
        assert!(self_distance.abs() < TOL);
    }

    #[test]
    fn test_distance_symmetric() {
        let model = LorentzModel::new(1.0);
        let a = model.from_euclidean(&array![0.5, 0.3].view());
        let b = model.from_euclidean(&array![-0.2, 0.4].view());
        let forward = model.distance(&a.view(), &b.view());
        let backward = model.distance(&b.view(), &a.view());
        assert_relative_eq!(forward, backward, epsilon = TOL);
    }

    #[test]
    fn test_distance_triangle_inequality() {
        let model = LorentzModel::new(1.0);
        let lift = |e: Array1<f64>| model.from_euclidean(&e.view());
        let a = lift(array![0.3, 0.0]);
        let b = lift(array![0.0, 0.3]);
        let c = lift(array![-0.3, 0.0]);
        let direct = model.distance(&a.view(), &c.view());
        let via_b = model.distance(&a.view(), &b.view()) + model.distance(&b.view(), &c.view());
        assert!(direct <= via_b + TOL);
    }

    #[test]
    fn test_exp_log_inverse() {
        let model = LorentzModel::new(1.0);
        let base = model.origin(3);
        let tangent = array![0.0, 0.3, 0.2, 0.1];
        let target = model.exp_map(&base.view(), &tangent.view());
        assert!(model.is_on_manifold(&target.view(), 1e-6));
        let recovered = model.log_map(&base.view(), &target.view());
        for (&want, &got) in tangent.iter().zip(recovered.iter()) {
            assert_relative_eq!(want, got, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_project_preserves_direction() {
        let model = LorentzModel::new(1.0);
        let raw = array![2.0, 0.5, 0.3];
        let projected = model.project(&raw.view());
        assert!(model.is_on_manifold(&projected.view(), TOL));
        let ratio = |w: &Array1<f64>| w[1] / w[2];
        assert_relative_eq!(ratio(&raw), ratio(&projected), epsilon = TOL);
    }

    #[test]
    fn test_conversion_round_trip() {
        let ball = crate::PoincareBall::new(1.0);
        let model = LorentzModel::new(1.0);
        let p = array![0.3, 0.2, -0.1];
        let lifted = conversions::poincare_to_lorentz(&ball, &p.view());
        let back = conversions::lorentz_to_poincare(&model, &lifted.view());
        for (&original, &round_tripped) in p.iter().zip(back.iter()) {
            assert_relative_eq!(original, round_tripped, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_parallel_transport_preserves_norm() {
        let model = LorentzModel::new(1.0);
        let start = model.from_euclidean(&array![0.1, 0.05].view());
        let end = model.from_euclidean(&array![0.12, 0.08].view());
        let v = tangent_at(&model, &start, &[0.5, -0.3]);
        let moved = model.parallel_transport(&start.view(), &end.view(), &v.view());
        let sq_norm = |w: &Array1<f64>| model.minkowski_dot(&w.view(), &w.view());
        let norm_v = sq_norm(&v);
        let norm_pt = sq_norm(&moved);
        assert!(
            (norm_v - norm_pt).abs() < 1e-4,
            "PT should preserve norm: {norm_v} vs {norm_pt}"
        );
    }

    #[test]
    fn test_parallel_transport_identity_when_same() {
        let model = LorentzModel::new(1.0);
        let point = model.from_euclidean(&array![0.5, 0.3].view());
        let v = tangent_at(&model, &point, &[0.4, -0.2]);
        let moved = model.parallel_transport(&point.view(), &point.view(), &v.view());
        for (&before, &after) in v.iter().zip(moved.iter()) {
            assert_relative_eq!(before, after, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_from_to_euclidean_round_trip() {
        let model = LorentzModel::new(1.0);
        let euclidean = array![0.5, -0.3, 0.2];
        let hyperbolic = model.from_euclidean(&euclidean.view());
        let restored = model.to_euclidean(&hyperbolic.view());
        for (&want, &got) in euclidean.iter().zip(restored.iter()) {
            assert_relative_eq!(want, got, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_exp_map_stays_on_manifold() {
        let model = LorentzModel::new(1.0);
        let base = model.from_euclidean(&array![0.3, 0.2].view());
        let v = tangent_at(&model, &base, &[5.0, -3.0]);
        let landed = model.exp_map(&base.view(), &v.view());
        assert!(
            model.is_on_manifold(&landed.view(), 1e-4),
            "exp_map result not on manifold"
        );
    }

    #[test]
    fn test_distance_nonneg() {
        let model = LorentzModel::new(1.0);
        let points = [
            array![0.5, 0.3],
            array![-0.2, 0.4],
            array![1.0, 1.0],
            array![0.0, 0.0],
        ];
        let embedded: Vec<Array1<f64>> = points
            .iter()
            .map(|p| model.from_euclidean(&p.view()))
            .collect();
        for (i, xi) in embedded.iter().enumerate() {
            for xj in &embedded[i..] {
                let d = model.distance(&xi.view(), &xj.view());
                assert!(d >= -1e-10, "negative distance: {d}");
            }
        }
    }

    #[test]
    fn test_different_curvatures() {
        let pair_distance = |c: f64| {
            let model = LorentzModel::new(c);
            let a = model.from_euclidean(&array![0.5, 0.3].view());
            let b = model.from_euclidean(&array![-0.2, 0.4].view());
            model.distance(&a.view(), &b.view())
        };
        let d1 = pair_distance(1.0);
        let d2 = pair_distance(4.0);
        assert!((d1 - d2).abs() > 1e-6);
    }
}