#![cfg(feature = "ndarray")]
use ndarray::{Array1, ArrayView1};
use num_traits::{Float, FromPrimitive, Zero};
/// The Lorentz (hyperboloid) model of hyperbolic space with curvature `-c`.
///
/// Construct it with [`LorentzModel::new`], which checks that `c > 0`; the field is public,
/// so building the struct literally bypasses that check.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LorentzModel<T> {
    /// Positive curvature magnitude; points satisfy `<x, x>_L = -1/c`.
    pub c: T,
}
impl<T> LorentzModel<T>
where
T: Float + FromPrimitive + Zero + ndarray::ScalarOperand + ndarray::LinalgScalar,
{
pub fn new(c: T) -> Self {
assert!(c > T::zero(), "curvature must be positive");
Self { c }
}
pub fn minkowski_dot(&self, x: &ArrayView1<T>, y: &ArrayView1<T>) -> T {
assert!(x.len() == y.len() && x.len() >= 2);
-x[0] * y[0] + x.slice(ndarray::s![1..]).dot(&y.slice(ndarray::s![1..]))
}
pub fn distance(&self, x: &ArrayView1<T>, y: &ArrayView1<T>) -> T {
let inner = self.minkowski_dot(x, y);
let arg = -self.c * inner;
let one = T::one();
let two = T::from_f64(2.0).unwrap();
let epsilon = T::from_f64(1e-7).unwrap();
if arg < one + epsilon {
if arg <= one {
return T::zero();
}
let x = arg - one;
return (two * x).sqrt() / self.c.sqrt();
}
arg.acosh() / self.c.sqrt()
}
pub fn is_on_manifold(&self, x: &ArrayView1<T>, tol: T) -> bool {
let inner = self.minkowski_dot(x, x);
(inner + T::one() / self.c).abs() < tol
}
pub fn project(&self, x: &ArrayView1<T>) -> Array1<T> {
let space_norm_sq = x.slice(ndarray::s![1..]).dot(&x.slice(ndarray::s![1..]));
let t = (space_norm_sq + T::one() / self.c).sqrt();
let mut result = x.to_owned();
result[0] = t;
result
}
pub fn from_euclidean(&self, v: &ArrayView1<T>) -> Array1<T> {
let space_norm_sq = v.dot(v);
let t = (space_norm_sq + T::one() / self.c).sqrt();
let mut result = Array1::zeros(v.len() + 1);
result[0] = t;
for (i, &val) in v.iter().enumerate() {
result[i + 1] = val;
}
result
}
pub fn to_euclidean(&self, x: &ArrayView1<T>) -> Array1<T> {
x.slice(ndarray::s![1..]).to_owned()
}
pub fn exp_map(&self, x: &ArrayView1<T>, v: &ArrayView1<T>) -> Array1<T> {
let v_norm_sq = self.minkowski_dot(v, v);
let epsilon = T::from_f64(1e-15).unwrap();
if v_norm_sq < epsilon {
return x.to_owned();
}
let v_norm = v_norm_sq.sqrt();
let c_sqrt = self.c.sqrt();
let cosh_term = (c_sqrt * v_norm).cosh();
let sinh_term = (c_sqrt * v_norm).sinh() / (c_sqrt * v_norm);
let term1 = x.mapv(|val| val * cosh_term);
let term2 = v.mapv(|val| val * sinh_term);
term1 + term2
}
pub fn log_map(&self, x: &ArrayView1<T>, y: &ArrayView1<T>) -> Array1<T> {
let inner = self.minkowski_dot(x, y);
let d = self.distance(x, y);
let epsilon = T::from_f64(1e-15).unwrap();
if d < epsilon {
return Array1::zeros(x.len());
}
let term_x = x.mapv(|val| val * self.c * inner);
let v = y.to_owned() + term_x;
let v_norm_sq = self.minkowski_dot(&v.view(), &v.view());
if v_norm_sq < epsilon {
return Array1::zeros(x.len());
}
let v_norm = v_norm_sq.sqrt();
v * (d / v_norm)
}
pub fn parallel_transport(
&self,
x: &ArrayView1<T>,
y: &ArrayView1<T>,
v: &ArrayView1<T>,
) -> Array1<T> {
let inner_xy = self.minkowski_dot(x, y);
let inner_vy = self.minkowski_dot(v, y);
let one = T::one();
let denom = one - self.c * inner_xy;
let epsilon = T::from_f64(1e-15).unwrap();
if denom.abs() < epsilon {
return v.to_owned();
}
let coeff = self.c * inner_vy / denom;
let sum_xy = x.to_owned() + y;
v.to_owned() - sum_xy.mapv(|val| val * coeff)
}
pub fn origin(&self, dim: usize) -> Array1<T> {
let mut o = Array1::zeros(dim + 1);
o[0] = T::one() / self.c.sqrt();
o
}
}
/// Isometries between the Poincaré ball and the Lorentz (hyperboloid) model.
///
/// For curvature `-c`, a ball point `p` with `c |p|^2 < 1` corresponds to the hyperboloid
/// point `x = ((1 + c|p|^2) / (sqrt(c) (1 - c|p|^2)), 2 p / (1 - c|p|^2))`. Conversely,
/// `p = x[1..] / (1 + sqrt(c) x_0)`.
pub mod conversions {
    use super::*;
    use crate::PoincareBall;

    /// Maps a point of the Poincaré ball `ball` onto the hyperboloid of the same curvature.
    ///
    /// Returns a vector with one more component than `x`: the leading time coordinate.
    /// `x` is expected to lie strictly inside the ball (`c |x|^2 < 1`); this is not checked.
    /// On the boundary the denominator is zero, and outside it the time coordinate turns
    /// negative.
    pub fn poincare_to_lorentz<T>(ball: &PoincareBall<T>, x: &ArrayView1<T>) -> Array1<T>
    where
        T: Float + FromPrimitive + Zero + ndarray::ScalarOperand,
    {
        let c = ball.c;
        let one = T::one();
        let two = T::from_f64(2.0).unwrap();
        let c_norm_sq = c * x.dot(x);
        let denom = one - c_norm_sq;
        let mut result = Array1::zeros(x.len() + 1);
        result[0] = (one + c_norm_sq) / (denom * c.sqrt());
        // Only the time coordinate carries the 1/sqrt(c) factor. Scaling the spatial part by
        // it too would leave the hyperboloid for c != 1 and break the round trip through
        // `lorentz_to_poincare`.
        let scale = two / denom;
        for (dst, &val) in result.iter_mut().skip(1).zip(x.iter()) {
            *dst = val * scale;
        }
        result
    }

    /// Maps a hyperboloid point back into the Poincaré ball of the same curvature.
    ///
    /// Returns a vector with one fewer component than `x`. This is the inverse of
    /// [`poincare_to_lorentz`] for points on the upper sheet.
    ///
    /// # Panics
    ///
    /// Panics if `x` is empty.
    pub fn lorentz_to_poincare<T>(lorentz: &LorentzModel<T>, x: &ArrayView1<T>) -> Array1<T>
    where
        T: Float + FromPrimitive + Zero + ndarray::ScalarOperand,
    {
        let denom = x[0] * lorentz.c.sqrt() + T::one();
        x.slice(ndarray::s![1..]).mapv(|val| val / denom)
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use approx::assert_relative_eq;
        use ndarray::array;

        #[test]
        fn round_trip_and_on_manifold_for_non_unit_curvature() {
            let ball = PoincareBall::new(4.0);
            let lorentz = LorentzModel::new(4.0);
            let p = array![0.1, 0.2, -0.05];
            let l = poincare_to_lorentz(&ball, &p.view());
            assert!(lorentz.is_on_manifold(&l.view(), 1e-10));
            let p_back = lorentz_to_poincare(&lorentz, &l.view());
            for (a, b) in p.iter().zip(p_back.iter()) {
                assert_relative_eq!(*a, *b, epsilon = 1e-10);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    const TOL: f64 = 1e-10;

    #[test]
    fn test_origin_on_manifold() {
        let lorentz = LorentzModel::new(1.0);
        let o = lorentz.origin(3);
        assert!(lorentz.is_on_manifold(&o.view(), TOL));
    }

    #[test]
    fn test_from_euclidean_on_manifold() {
        let lorentz = LorentzModel::new(1.0);
        let v = array![0.5, -0.3, 0.2];
        let x = lorentz.from_euclidean(&v.view());
        assert!(lorentz.is_on_manifold(&x.view(), TOL));
    }

    #[test]
    fn test_distance_self_zero() {
        let lorentz = LorentzModel::new(1.0);
        let x = lorentz.from_euclidean(&array![0.5, 0.3].view());
        let d = lorentz.distance(&x.view(), &x.view());
        assert!(d.abs() < TOL);
    }

    #[test]
    fn test_distance_symmetric() {
        let lorentz = LorentzModel::new(1.0);
        let x = lorentz.from_euclidean(&array![0.5, 0.3].view());
        let y = lorentz.from_euclidean(&array![-0.2, 0.4].view());
        let d_xy = lorentz.distance(&x.view(), &y.view());
        let d_yx = lorentz.distance(&y.view(), &x.view());
        assert_relative_eq!(d_xy, d_yx, epsilon = TOL);
    }

    #[test]
    fn test_distance_triangle_inequality() {
        let lorentz = LorentzModel::new(1.0);
        let a = lorentz.from_euclidean(&array![0.3, 0.0].view());
        let b = lorentz.from_euclidean(&array![0.0, 0.3].view());
        let c = lorentz.from_euclidean(&array![-0.3, 0.0].view());
        let d_ac = lorentz.distance(&a.view(), &c.view());
        let d_ab = lorentz.distance(&a.view(), &b.view());
        let d_bc = lorentz.distance(&b.view(), &c.view());
        assert!(d_ac <= d_ab + d_bc + TOL);
    }

    #[test]
    fn test_exp_log_inverse() {
        let lorentz = LorentzModel::new(1.0);
        let x = lorentz.origin(3);
        let v = array![0.0, 0.3, 0.2, 0.1];
        let y = lorentz.exp_map(&x.view(), &v.view());
        assert!(lorentz.is_on_manifold(&y.view(), 1e-6));
        let v_recovered = lorentz.log_map(&x.view(), &y.view());
        for i in 0..v.len() {
            assert_relative_eq!(v[i], v_recovered[i], epsilon = 1e-6);
        }
    }

    #[test]
    fn test_project_preserves_direction() {
        let lorentz = LorentzModel::new(1.0);
        let x = array![2.0, 0.5, 0.3];
        let projected = lorentz.project(&x.view());
        assert!(lorentz.is_on_manifold(&projected.view(), TOL));
        assert_relative_eq!(x[1] / x[2], projected[1] / projected[2], epsilon = TOL);
    }

    #[test]
    fn test_conversion_round_trip() {
        use conversions::*;
        let ball = crate::PoincareBall::new(1.0);
        let lorentz = LorentzModel::new(1.0);
        let p = array![0.3, 0.2, -0.1];
        let l = poincare_to_lorentz(&ball, &p.view());
        let p_back = lorentz_to_poincare(&lorentz, &l.view());
        for i in 0..p.len() {
            assert_relative_eq!(p[i], p_back[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_different_curvatures() {
        let l1 = LorentzModel::new(1.0);
        let l2 = LorentzModel::new(4.0);
        let x1 = l1.from_euclidean(&array![0.5, 0.3].view());
        let y1 = l1.from_euclidean(&array![-0.2, 0.4].view());
        let x2 = l2.from_euclidean(&array![0.5, 0.3].view());
        let y2 = l2.from_euclidean(&array![-0.2, 0.4].view());
        let d1 = l1.distance(&x1.view(), &y1.view());
        let d2 = l2.distance(&x2.view(), &y2.view());
        assert!((d1 - d2).abs() > 1e-6);
    }

    // exp/log must invert each other away from the origin and for c != 1 as well.
    #[test]
    fn test_exp_log_inverse_off_origin() {
        let lorentz = LorentzModel::new(2.0);
        let x = lorentz.from_euclidean(&array![0.4, -0.1].view());
        let w = array![0.1, 0.2, 0.3];
        // Project w onto the tangent space at x: v = w + c <x, w>_L x, so <x, v>_L = 0.
        let coeff = lorentz.c * lorentz.minkowski_dot(&x.view(), &w.view());
        let v = w + x.mapv(|val| val * coeff);
        assert!(lorentz.minkowski_dot(&x.view(), &v.view()).abs() < 1e-12);
        let y = lorentz.exp_map(&x.view(), &v.view());
        assert!(lorentz.is_on_manifold(&y.view(), 1e-9));
        let v_recovered = lorentz.log_map(&x.view(), &y.view());
        for i in 0..v.len() {
            assert_relative_eq!(v[i], v_recovered[i], epsilon = 1e-9);
        }
    }

    #[test]
    fn test_exp_map_zero_vector_is_identity() {
        let lorentz = LorentzModel::new(1.0);
        let x = lorentz.from_euclidean(&array![0.2, -0.7].view());
        let zero = Array1::<f64>::zeros(3);
        let y = lorentz.exp_map(&x.view(), &zero.view());
        assert_eq!(x, y);
        let v = lorentz.log_map(&x.view(), &x.view());
        assert!(v.iter().all(|&val| val == 0.0));
    }

    // From the origin, d(o, y) = acosh(sqrt(c) y_0) / sqrt(c).
    #[test]
    fn test_distance_from_origin_closed_form() {
        let lorentz = LorentzModel::new(4.0);
        let o = lorentz.origin(2);
        assert!(lorentz.is_on_manifold(&o.view(), TOL));
        let y = lorentz.from_euclidean(&array![0.5, 0.3].view());
        let d = lorentz.distance(&o.view(), &y.view());
        assert_relative_eq!(d, (2.0 * y[0]).acosh() / 2.0, epsilon = 1e-12);
    }
}