pub mod core;
#[cfg(feature = "ndarray")]
use ndarray::{Array1, ArrayView1};
#[cfg(feature = "ndarray")]
use num_traits::{Float, FromPrimitive, Zero};
#[cfg(feature = "ndarray")]
pub mod lorentz;
#[cfg(feature = "ndarray")]
pub use lorentz::LorentzModel;
#[cfg(feature = "ndarray")]
/// Poincaré ball model of hyperbolic space.
///
/// Points are vectors `x` with `‖x‖² < 1 / c`, i.e. the open ball of radius `1 / √c`.
/// The operations in the `impl` below use the formulas for the ball of constant
/// curvature `-c`, so `c` is expected to be strictly positive.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PoincareBall<T> {
    /// Curvature magnitude; the ball's radius is `1 / √c`.
    pub c: T,
}
#[cfg(feature = "ndarray")]
impl<T> PoincareBall<T>
where
    T: Float + FromPrimitive + Zero + ndarray::ScalarOperand + ndarray::LinalgScalar,
{
    /// Creates a ball with curvature magnitude `c`.
    ///
    /// `c` is stored as-is without validation; a non-positive `c` leads to NaN or
    /// infinite intermediate values in the other methods.
    pub fn new(c: T) -> Self {
        Self { c }
    }

    /// Converts an `f64` literal into `T`.
    ///
    /// # Panics
    /// Panics if `T::from_f64` rejects `value` (it does not for `f32`/`f64`).
    fn lit(value: f64) -> T {
        T::from_f64(value).expect("floating-point literal must be representable in T")
    }

    /// `atanh(z)` with `z` capped at the largest `1 - ε` below one.
    ///
    /// For points inside the ball `z = √c·‖·‖` is mathematically `< 1`, but rounding
    /// (or an input on/outside the boundary) can push it to `>= 1`, where plain
    /// `atanh` returns `inf` or NaN. NaN inputs are passed through unchanged because
    /// the comparison below is false for NaN.
    fn clamped_atanh(z: T) -> T {
        let limit = T::one() - T::epsilon();
        if z > limit {
            limit.atanh()
        } else {
            z.atanh()
        }
    }

    /// Möbius addition `x ⊕_c y`:
    ///
    /// `((1 + 2c⟨x,y⟩ + c‖y‖²)·x + (1 − c‖x‖²)·y) / (1 + 2c⟨x,y⟩ + c²‖x‖²‖y‖²)`.
    ///
    /// # Panics
    /// Panics if `x` and `y` have different lengths.
    pub fn mobius_add(&self, x: &ArrayView1<T>, y: &ArrayView1<T>) -> Array1<T> {
        let c = self.c;
        let one = T::one();
        let two = one + one;
        let x_norm_sq = x.dot(x);
        let y_norm_sq = y.dot(y);
        let xy = x.dot(y);
        let x_coeff = one + two * c * xy + c * y_norm_sq;
        let y_coeff = one - c * x_norm_sq;
        let denom = one + two * c * xy + c * c * x_norm_sq * y_norm_sq;
        (x * x_coeff + y * y_coeff) / denom
    }

    /// Geodesic distance `d_c(x, y) = (2 / √c) · atanh(√c · ‖(−x) ⊕_c y‖)`.
    ///
    /// The `atanh` argument is capped just below 1, so nearly-boundary inputs give a
    /// large finite distance instead of `inf`/NaN.
    pub fn distance(&self, x: &ArrayView1<T>, y: &ArrayView1<T>) -> T {
        let neg_x = x.mapv(|v| -v);
        let diff = self.mobius_add(&neg_x.view(), y);
        let diff_norm = diff.dot(&diff).sqrt();
        let c_sqrt = self.c.sqrt();
        let two = T::one() + T::one();
        two / c_sqrt * Self::clamped_atanh(c_sqrt * diff_norm)
    }

    /// Logarithmic map at the origin: `log_0(y) = atanh(√c‖y‖) · y / (√c‖y‖)`.
    ///
    /// Vectors with norm below `1e-7` are returned unchanged (the scale factor tends
    /// to 1 there and the ratio would otherwise risk `0 / 0`). The `atanh` argument
    /// is capped just below 1, so points on or past the boundary give a large finite
    /// tangent vector rather than `inf`/NaN.
    pub fn log_map_zero(&self, y: &ArrayView1<T>) -> Array1<T> {
        let y_norm = y.dot(y).sqrt();
        if y_norm < Self::lit(1e-7) {
            return y.to_owned();
        }
        let scaled_norm = self.c.sqrt() * y_norm;
        y * (Self::clamped_atanh(scaled_norm) / scaled_norm)
    }

    /// Exponential map at the origin: `exp_0(v) = tanh(√c‖v‖) · v / (√c‖v‖)`.
    ///
    /// Vectors with norm below `1e-7` are returned unchanged. For large `‖v‖`,
    /// `tanh` rounds to exactly 1 and the result would sit on the boundary (outside
    /// the open ball); only in that case is it pulled back inside via [`Self::project`].
    pub fn exp_map_zero(&self, v: &ArrayView1<T>) -> Array1<T> {
        let v_norm = v.dot(v).sqrt();
        if v_norm < Self::lit(1e-7) {
            return v.to_owned();
        }
        let scaled_norm = self.c.sqrt() * v_norm;
        let mapped = v * (scaled_norm.tanh() / scaled_norm);
        if self.is_in_ball(&mapped.view()) {
            mapped
        } else {
            self.project(&mapped.view())
        }
    }

    /// Returns `true` if `‖x‖² < 1 / c` (the ball is open, so boundary points are excluded).
    pub fn is_in_ball(&self, x: &ArrayView1<T>) -> bool {
        let norm_sq = x.dot(x);
        norm_sq < T::one() / self.c
    }

    /// Scales `x` back inside the ball when its norm exceeds `(1 − 1e-5) / √c`;
    /// otherwise returns an unchanged copy. The direction of `x` is preserved.
    pub fn project(&self, x: &ArrayView1<T>) -> Array1<T> {
        let norm = x.dot(x).sqrt();
        let one = T::one();
        // Relative margin: an absolute `1/√c - 1e-5` becomes non-positive once
        // c > 1e10, which would reflect points through the origin (or yield NaN for
        // the zero vector). For c = 1 both forms give the same bound.
        let radius = (one / self.c).sqrt();
        let max_norm = radius * (one - Self::lit(1e-5));
        if norm > max_norm {
            x * (max_norm / norm)
        } else {
            x.to_owned()
        }
    }
}
#[cfg(all(test, feature = "ndarray"))]
mod tests {
    use super::*;
    use ndarray::array;

    /// Absolute tolerance for results that should be exact up to rounding.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_distance_self_is_zero() {
        let ball = PoincareBall::new(1.0);
        let x = array![0.1, 0.2, 0.3];
        let d = ball.distance(&x.view(), &x.view());
        assert!(d.abs() < EPS, "distance to self should be 0, got {}", d);
    }

    #[test]
    fn test_distance_symmetric() {
        let ball = PoincareBall::new(1.0);
        let x = array![0.1, 0.2];
        let y = array![0.3, -0.1];
        let d_xy = ball.distance(&x.view(), &y.view());
        let d_yx = ball.distance(&y.view(), &x.view());
        assert!((d_xy - d_yx).abs() < EPS, "distance not symmetric");
    }

    #[test]
    fn test_distance_non_negative() {
        let ball = PoincareBall::new(1.0);
        let x = array![0.1, 0.2];
        let y = array![0.3, -0.1];
        let d = ball.distance(&x.view(), &y.view());
        assert!(d >= 0.0, "distance should be non-negative");
    }

    #[test]
    fn test_distance_triangle_inequality() {
        let ball = PoincareBall::new(1.0);
        let a = array![0.1, 0.0];
        let b = array![0.0, 0.1];
        let c = array![-0.1, 0.0];
        let d_ac = ball.distance(&a.view(), &c.view());
        let d_ab = ball.distance(&a.view(), &b.view());
        let d_bc = ball.distance(&b.view(), &c.view());
        assert!(
            d_ac <= d_ab + d_bc + EPS,
            "triangle inequality violated: {} > {} + {}",
            d_ac,
            d_ab,
            d_bc
        );
    }

    #[test]
    fn test_distance_from_origin_matches_closed_form() {
        // d_c(0, x) = (2 / √c) · atanh(√c · ‖x‖).
        let origin = array![0.0, 0.0];

        let unit = PoincareBall::new(1.0);
        let x = array![0.3, 0.4]; // ‖x‖ = 0.5, so d = 2·atanh(0.5) = ln 3
        let d = unit.distance(&origin.view(), &x.view());
        assert!((d - 3.0_f64.ln()).abs() < 1e-12, "c=1: expected ln 3, got {}", d);

        let steep = PoincareBall::new(4.0);
        let y = array![0.0, 0.3]; // √c·‖y‖ = 0.6, so d = atanh(0.6) = ln 2
        let d = steep.distance(&origin.view(), &y.view());
        assert!((d - 2.0_f64.ln()).abs() < 1e-12, "c=4: expected ln 2, got {}", d);
    }

    #[test]
    fn test_mobius_add_identity() {
        let ball = PoincareBall::new(1.0);
        let x = array![0.1, 0.2, 0.3];
        let zero = array![0.0, 0.0, 0.0];
        let result = ball.mobius_add(&x.view(), &zero.view());
        for i in 0..3 {
            assert!(
                (result[i] - x[i]).abs() < EPS,
                "mobius_add with zero failed at index {}",
                i
            );
        }
    }

    #[test]
    fn test_mobius_add_left_identity() {
        let ball = PoincareBall::new(2.0);
        let zero = array![0.0, 0.0, 0.0];
        let x = array![0.1, -0.2, 0.3];
        let result = ball.mobius_add(&zero.view(), &x.view());
        for i in 0..3 {
            assert!(
                (result[i] - x[i]).abs() < EPS,
                "zero ⊕ x should equal x at index {}",
                i
            );
        }
    }

    #[test]
    fn test_exp_log_round_trip() {
        let ball = PoincareBall::new(1.0);
        let v = array![0.3, 0.2, 0.1];
        let on_manifold = ball.exp_map_zero(&v.view());
        let recovered = ball.log_map_zero(&on_manifold.view());
        for i in 0..3 {
            assert!(
                (recovered[i] - v[i]).abs() < 1e-6,
                "exp/log round trip failed at index {}: {} vs {}",
                i,
                recovered[i],
                v[i]
            );
        }
    }

    #[test]
    fn test_exp_log_round_trip_non_unit_curvature() {
        let ball = PoincareBall::new(2.0);
        let v = array![0.3, 0.2, 0.1];
        let on_manifold = ball.exp_map_zero(&v.view());
        let recovered = ball.log_map_zero(&on_manifold.view());
        for i in 0..3 {
            assert!(
                (recovered[i] - v[i]).abs() < 1e-6,
                "exp/log round trip (c=2) failed at index {}: {} vs {}",
                i,
                recovered[i],
                v[i]
            );
        }
    }

    #[test]
    fn test_exp_map_stays_in_ball() {
        let ball = PoincareBall::new(1.0);
        let large_v = array![10.0, 10.0, 10.0];
        let result = ball.exp_map_zero(&large_v.view());
        assert!(
            ball.is_in_ball(&result.view()),
            "exp_map result escaped the ball"
        );
    }

    #[test]
    fn test_is_in_ball_excludes_boundary() {
        let ball = PoincareBall::new(4.0);
        let origin = array![0.0, 0.0];
        let on_boundary = array![0.5, 0.0]; // ‖x‖² = 0.25 = 1 / c
        assert!(ball.is_in_ball(&origin.view()), "origin must be inside the ball");
        assert!(
            !ball.is_in_ball(&on_boundary.view()),
            "the ball is open: boundary points must be rejected"
        );
    }

    #[test]
    fn test_project_inside_unchanged() {
        let ball = PoincareBall::new(1.0);
        let inside = array![0.1, 0.2];
        let projected = ball.project(&inside.view());
        for i in 0..2 {
            assert!(
                (projected[i] - inside[i]).abs() < EPS,
                "projection changed point already inside ball"
            );
        }
    }

    #[test]
    fn test_project_outside_onto_boundary() {
        let ball = PoincareBall::new(1.0);
        let outside = array![2.0, 0.0];
        let projected = ball.project(&outside.view());
        assert!(
            ball.is_in_ball(&projected.view()),
            "projection did not bring point inside ball"
        );
    }

    #[test]
    fn test_project_preserves_direction() {
        let ball = PoincareBall::new(1.0);
        let outside = array![3.0, 4.0];
        let projected = ball.project(&outside.view());
        assert!(
            ball.is_in_ball(&projected.view()),
            "projection did not bring point inside ball"
        );
        assert!(
            projected[0] > 0.0 && projected[1] > 0.0,
            "projection flipped the point through the origin"
        );
        assert!(
            (projected[0] / projected[1] - 0.75).abs() < EPS,
            "projection changed the direction of the point"
        );
    }

    #[test]
    fn test_curvature_affects_distance() {
        let ball_c1 = PoincareBall::new(1.0);
        let ball_c2 = PoincareBall::new(4.0);
        let x = array![0.1, 0.0];
        let y = array![0.0, 0.1];
        let d1 = ball_c1.distance(&x.view(), &y.view());
        let d2 = ball_c2.distance(&x.view(), &y.view());
        assert!(
            (d1 - d2).abs() > 1e-6,
            "curvature should affect distance: c=1 gives {}, c=4 gives {}",
            d1,
            d2
        );
    }
}