use num_traits::{Float, FromPrimitive};
/// Core arithmetic for the Poincaré ball model of hyperbolic space.
///
/// The ball has curvature `-c` (with `c > 0`) and radius `1 / sqrt(c)`.
/// Points are plain coordinate slices; every method returns a freshly
/// allocated `Vec` rather than mutating its input.
#[derive(Debug, Clone, Copy)]
pub struct PoincareBallCore<T> {
    /// Positive curvature magnitude; the ball is `{x : c·|x|² < 1}`.
    pub c: T,
}

impl<T> PoincareBallCore<T>
where
    T: Float + FromPrimitive,
{
    /// Creates a ball with curvature magnitude `c`.
    ///
    /// # Panics
    /// Panics unless `c > 0` (so NaN is rejected as well).
    pub fn new(c: T) -> Self {
        assert!(c > T::zero(), "curvature must be positive");
        Self { c }
    }

    /// Möbius addition `x ⊕_c y`:
    /// `((1 + 2c<x,y> + c|y|²)·x + (1 - c|x|²)·y) / (1 + 2c<x,y> + c²|x|²|y|²)`.
    ///
    /// # Panics
    /// Panics if `x` and `y` have different lengths.
    pub fn mobius_add(&self, x: &[T], y: &[T]) -> Vec<T> {
        assert_eq!(x.len(), y.len());
        let c = self.c;
        let one = T::one();
        let two = T::from_f64(2.0).unwrap();
        let x_norm_sq = dot(x, x);
        let y_norm_sq = dot(y, y);
        let xy = dot(x, y);
        let denom = one + two * c * xy + c * c * x_norm_sq * y_norm_sq;
        let s1 = one + two * c * xy + c * y_norm_sq;
        let s2 = one - c * x_norm_sq;
        x.iter()
            .zip(y)
            .map(|(&xi, &yi)| (xi * s1 + yi * s2) / denom)
            .collect()
    }

    /// Geodesic distance `(2 / sqrt(c)) · atanh(sqrt(c) · |(-x) ⊕_c y|)`.
    ///
    /// Rounding can push the `atanh` argument to 1 or slightly above for
    /// points very close to the boundary. That argument is capped just below 1,
    /// so such inputs give a large finite value instead of `inf`/NaN.
    ///
    /// # Panics
    /// Panics if `x` and `y` have different lengths.
    pub fn distance(&self, x: &[T], y: &[T]) -> T {
        assert_eq!(x.len(), y.len());
        let neg_x: Vec<T> = x.iter().map(|&v| -v).collect();
        let diff = self.mobius_add(&neg_x, y);
        let diff_norm = dot(&diff, &diff).sqrt();
        let c_sqrt = self.c.sqrt();
        let two = T::from_f64(2.0).unwrap();
        two / c_sqrt * Self::atanh_below_one(c_sqrt * diff_norm)
    }

    /// Logarithmic map at the origin: `atanh(sqrt(c)|y|) / (sqrt(c)|y|) · y`.
    ///
    /// When the dimensionless radius `sqrt(c)·|y|` is below `1e-7` the
    /// scale factor is 1 to working precision, so `y` is returned as-is.
    pub fn log_map_zero(&self, y: &[T]) -> Vec<T> {
        let c_sqrt = self.c.sqrt();
        // The cut-off is applied to sqrt(c)·|y|, not |y|. The ball radius is
        // 1/sqrt(c), so an absolute threshold on |y| would skip the rescaling
        // for every in-ball point once c is large.
        let r = c_sqrt * dot(y, y).sqrt();
        let eps = T::from_f64(1e-7).unwrap();
        if r < eps {
            return y.to_vec();
        }
        let scale = Self::atanh_below_one(r) / r;
        y.iter().map(|&v| v * scale).collect()
    }

    /// Exponential map at the origin: `tanh(sqrt(c)|v|) / (sqrt(c)|v|) · v`.
    ///
    /// When `sqrt(c)·|v|` is below `1e-7` the scale factor is 1 to working
    /// precision, so `v` is returned as-is.
    pub fn exp_map_zero(&self, v: &[T]) -> Vec<T> {
        let c_sqrt = self.c.sqrt();
        // Dimensionless argument. See `log_map_zero` for why the cut-off
        // must not be on |v| alone.
        let s = c_sqrt * dot(v, v).sqrt();
        let eps = T::from_f64(1e-7).unwrap();
        if s < eps {
            return v.to_vec();
        }
        let scale = s.tanh() / s;
        v.iter().map(|&x| x * scale).collect()
    }

    /// Returns `true` when `x` lies strictly inside the ball (`|x|² < 1/c`).
    pub fn is_in_ball(&self, x: &[T]) -> bool {
        let norm_sq = dot(x, x);
        norm_sq < T::one() / self.c
    }

    /// Radially rescales `x` so that its norm is at most `(1 - 1e-5) / sqrt(c)`.
    /// Points already within that radius are returned unchanged.
    pub fn project(&self, x: &[T]) -> Vec<T> {
        let norm = dot(x, x).sqrt();
        let eps = T::from_f64(1e-5).unwrap();
        // Use a relative margin. An absolute `1/sqrt(c) - eps` becomes
        // non-positive once c >= 1e10, which mirrored points through the
        // origin (or produced NaN for the zero vector). At c = 1 both forms
        // give the same radius.
        let max_norm = (T::one() - eps) / self.c.sqrt();
        if norm > max_norm {
            let scale = max_norm / norm;
            x.iter().map(|&v| v * scale).collect()
        } else {
            x.to_vec()
        }
    }

    /// `atanh(z)` with `z` capped at `1 - T::epsilon()`.
    ///
    /// NaN is passed through unchanged (the comparison is false), so invalid
    /// data is not silently turned into a finite value.
    fn atanh_below_one(z: T) -> T {
        let limit = T::one() - T::epsilon();
        if z > limit {
            limit.atanh()
        } else {
            z.atanh()
        }
    }
}
/// Euclidean inner product of two equal-length slices, summed left to right.
///
/// # Panics
/// Panics if the slices have different lengths.
fn dot<T: Float>(a: &[T], b: &[T]) -> T {
    assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b)
        .fold(T::zero(), |sum, (&p, &q)| sum + p * q)
}
/// Core arithmetic for the hyperboloid (Lorentz) model of hyperbolic space.
///
/// Points satisfy `<x, x>_L = -1/c` with a positive time coordinate `x[0]`.
/// `<., .>_L` is [`LorentzModelCore::minkowski_dot`], whose index 0 is the
/// time-like axis.
#[derive(Debug, Clone, Copy)]
pub struct LorentzModelCore<T> {
    /// Positive curvature magnitude; the manifold is `<x, x>_L = -1/c`.
    pub c: T,
}

impl<T> LorentzModelCore<T>
where
    T: Float + FromPrimitive,
{
    /// Creates a model with curvature magnitude `c`.
    ///
    /// # Panics
    /// Panics unless `c > 0` (so NaN is rejected as well).
    pub fn new(c: T) -> Self {
        assert!(c > T::zero(), "curvature must be positive");
        Self { c }
    }

    /// Minkowski inner product `-x0·y0 + Σ_{i>=1} xi·yi`.
    ///
    /// # Panics
    /// Panics if the lengths differ or are below 2.
    pub fn minkowski_dot(&self, x: &[T], y: &[T]) -> T {
        assert_eq!(x.len(), y.len());
        assert!(x.len() >= 2);
        x.iter()
            .zip(y)
            .skip(1)
            .fold(-x[0] * y[0], |acc, (&a, &b)| acc + a * b)
    }

    /// Geodesic distance `acosh(-c·<x, y>_L) / sqrt(c)`.
    ///
    /// `acosh` is ill-conditioned near 1. Within `1e-7` of 1 the series
    /// `acosh(1 + u) ≈ sqrt(2u)` is used instead. Arguments `<= 1`, which
    /// rounding can produce, give exactly zero.
    pub fn distance(&self, x: &[T], y: &[T]) -> T {
        let arg = -self.c * self.minkowski_dot(x, y);
        let one = T::one();
        let two = T::from_f64(2.0).unwrap();
        let eps = T::from_f64(1e-7).unwrap();
        if arg < one + eps {
            if arg <= one {
                return T::zero();
            }
            return (two * (arg - one)).sqrt() / self.c.sqrt();
        }
        arg.acosh() / self.c.sqrt()
    }

    /// Returns `true` when `|<x, x>_L + 1/c| < tol`.
    ///
    /// The sign of the time coordinate is not checked.
    pub fn is_on_manifold(&self, x: &[T], tol: T) -> bool {
        let inner = self.minkowski_dot(x, x);
        (inner + T::one() / self.c).abs() < tol
    }

    /// Recomputes the time coordinate as `sqrt(|x[1..]|² + 1/c)`.
    /// The spatial coordinates are kept as they are.
    ///
    /// # Panics
    /// Panics if `x.len() < 2`.
    pub fn project(&self, x: &[T]) -> Vec<T> {
        assert!(x.len() >= 2);
        let space_norm_sq = x[1..].iter().fold(T::zero(), |acc, &v| acc + v * v);
        let mut out = x.to_vec();
        out[0] = (space_norm_sq + T::one() / self.c).sqrt();
        out
    }

    /// Lifts a spatial vector `v` to the point `(sqrt(|v|² + 1/c), v)`.
    pub fn from_euclidean(&self, v: &[T]) -> Vec<T> {
        let space_norm_sq = v.iter().fold(T::zero(), |acc, &val| acc + val * val);
        let t = (space_norm_sq + T::one() / self.c).sqrt();
        std::iter::once(t).chain(v.iter().copied()).collect()
    }

    /// Returns the spatial coordinates `x[1..]`.
    ///
    /// # Panics
    /// Panics if `x.len() < 2`.
    pub fn to_euclidean(&self, x: &[T]) -> Vec<T> {
        assert!(x.len() >= 2);
        x[1..].to_vec()
    }

    /// Exponential map at `x`: `cosh(s)·x + (sinh(s)/s)·v`, with
    /// `s = sqrt(c)·sqrt(<v, v>_L)`.
    ///
    /// `v` is assumed to be tangent at `x`, so `<v, v>_L >= 0`. When
    /// `<v, v>_L <= 0` (the zero vector, or a non-tangent input), `x` is
    /// returned unchanged.
    ///
    /// # Panics
    /// Panics if the lengths differ or are below 2.
    pub fn exp_map(&self, x: &[T], v: &[T]) -> Vec<T> {
        assert_eq!(x.len(), v.len());
        let v_norm_sq = self.minkowski_dot(v, v);
        // There is no absolute magnitude cut-off. Coordinates scale like
        // 1/sqrt(c), so a fixed threshold (the former `< 1e-15`) would discard
        // genuine geodesic steps for large c.
        if v_norm_sq <= T::zero() {
            return x.to_vec();
        }
        let s = self.c.sqrt() * v_norm_sq.sqrt();
        let cosh_term = s.cosh();
        // sinh(s)/s is accurate for any s > 0. The guard only covers the case
        // where the product underflows to zero, where the limit is 1.
        let sinh_term = if s > T::zero() { s.sinh() / s } else { T::one() };
        x.iter()
            .zip(v)
            .map(|(&xi, &vi)| xi * cosh_term + vi * sinh_term)
            .collect()
    }

    /// Logarithmic map at `x`: the tangent vector at `x` pointing towards `y`
    /// whose Minkowski norm is `distance(x, y)`.
    ///
    /// It is computed by projecting `y` onto the tangent space at `x`
    /// (`y + c·<x, y>_L·x`) and rescaling. The zero vector is returned when
    /// the points are numerically coincident.
    ///
    /// # Panics
    /// Panics if the lengths differ or are below 2.
    pub fn log_map(&self, x: &[T], y: &[T]) -> Vec<T> {
        assert_eq!(x.len(), y.len());
        let inner = self.minkowski_dot(x, y);
        let d = self.distance(x, y);
        let eps = T::from_f64(1e-15).unwrap();
        // Both cut-offs use dimensionless quantities (sqrt(c)·d and
        // c·<v,v>_L = sinh²(sqrt(c)·d)), matching the dimensionless cut-off in
        // `distance`. Absolute thresholds returned zero for O(1) geodesic
        // separations when c was large. At c = 1 the behavior is unchanged.
        let c_sqrt = self.c.sqrt();
        if c_sqrt * d < eps {
            return vec![T::zero(); x.len()];
        }
        let c_inner = self.c * inner;
        let v: Vec<T> = x
            .iter()
            .zip(y)
            .map(|(&xi, &yi)| yi + xi * c_inner)
            .collect();
        let v_norm_sq = self.minkowski_dot(&v, &v);
        if self.c * v_norm_sq < eps {
            return vec![T::zero(); x.len()];
        }
        let scale = d / v_norm_sq.sqrt();
        v.into_iter().map(|vi| vi * scale).collect()
    }

    /// The base point `(1/sqrt(c), 0, …, 0)` with `space_dim` spatial zeros.
    pub fn origin(&self, space_dim: usize) -> Vec<T> {
        let mut out = vec![T::zero(); space_dim + 1];
        out[0] = T::one() / self.c.sqrt();
        out
    }
}
/// Coordinate changes between the Poincaré ball and the Lorentz hyperboloid.
pub mod conversions {
    use super::{Float, FromPrimitive, LorentzModelCore, PoincareBallCore};

    /// Maps a point of the Poincaré ball (radius `1/sqrt(c)`) onto the
    /// hyperboloid `<x, x>_L = -1/c` of the same curvature.
    ///
    /// With `D = 1 - c·|x|²` the image is
    /// `((1 + c·|x|²) / (sqrt(c)·D), 2·x / D)`.
    /// When both models share the same `c`, this is the inverse of
    /// [`lorentz_to_poincare`].
    ///
    /// Only points strictly inside the ball (`D > 0`) give valid results.
    pub fn poincare_to_lorentz<T>(ball: &PoincareBallCore<T>, x: &[T]) -> Vec<T>
    where
        T: Float + FromPrimitive,
    {
        let c = ball.c;
        let one = T::one();
        let two = T::from_f64(2.0).unwrap();
        let x_norm_sq = x.iter().fold(T::zero(), |acc, &v| acc + v * v);
        let denom = one - c * x_norm_sq;
        let t = (one + c * x_norm_sq) / (denom * c.sqrt());
        // The spatial part has no 1/sqrt(c) factor. Only with this scaling
        // does -t² + |s|² = (-(1 + c|x|²)² + 4c|x|²) / (c·D²) = -1/c hold.
        // The extra 1/sqrt(c) put the point off the hyperboloid whenever
        // c != 1.
        let scale = two / denom;
        std::iter::once(t)
            .chain(x.iter().map(|&v| v * scale))
            .collect()
    }

    /// Maps a hyperboloid point back into the Poincaré ball:
    /// `x[1..] / (sqrt(c)·x[0] + 1)`.
    ///
    /// # Panics
    /// Panics if `x.len() < 2`.
    pub fn lorentz_to_poincare<T>(lorentz: &LorentzModelCore<T>, x: &[T]) -> Vec<T>
    where
        T: Float + FromPrimitive,
    {
        assert!(x.len() >= 2);
        let denom = x[0] * lorentz.c.sqrt() + T::one();
        x[1..].iter().map(|&v| v / denom).collect()
    }
}
/// Metric diagnostics over dense, row-major `n × n` distance matrices.
pub mod diagnostics {
    /// Largest violation of the ultrametric inequality
    /// `d(i, k) <= max(d(i, j), d(j, k))` over all ordered triples of
    /// pairwise-distinct indices.
    ///
    /// `dist[row * n + col]` is read as the distance from `row` to `col`.
    /// The result is `0.0` when nothing is violated, including `n < 3`.
    ///
    /// # Panics
    /// Panics if `dist.len() != n * n`.
    pub fn ultrametric_max_violation_f64(dist: &[f64], n: usize) -> f64 {
        assert_eq!(dist.len(), n * n);
        let at = |row: usize, col: usize| dist[row * n + col];
        let mut worst = 0.0f64;
        for i in 0..n {
            for j in (0..n).filter(|&j| j != i) {
                let d_ij = at(i, j);
                for k in (0..n).filter(|&k| k != i && k != j) {
                    let excess = at(i, k) - d_ij.max(at(j, k));
                    if excess > worst {
                        worst = excess;
                    }
                }
            }
        }
        // `worst` starts at +0.0 and only ever increases, so it is never negative.
        worst
    }

    /// Gromov δ-hyperbolicity by the four-point condition, checked exhaustively.
    ///
    /// For each ordered quadruple of distinct indices, the three pairwise sums
    /// `d(a,b)+d(c,d)`, `d(a,c)+d(b,d)` and `d(a,d)+d(b,c)` are ranked. Half
    /// the gap between the two largest is that quadruple's δ. The maximum
    /// over all quadruples is returned (`0.0` when `n < 4`). Runtime is O(n⁴).
    ///
    /// # Panics
    /// Panics if `dist.len() != n * n`.
    pub fn delta_hyperbolicity_four_point_exact_f64(dist: &[f64], n: usize) -> f64 {
        assert_eq!(dist.len(), n * n);
        let at = |row: usize, col: usize| dist[row * n + col];
        let mut best = 0.0f64;
        for a in 0..n {
            for b in (0..n).filter(|&b| b != a) {
                for c in (0..n).filter(|&c| c != a && c != b) {
                    for d in (0..n).filter(|&d| d != a && d != b && d != c) {
                        let (largest, second) = top_two([
                            at(a, b) + at(c, d),
                            at(a, c) + at(b, d),
                            at(a, d) + at(b, c),
                        ]);
                        let delta = 0.5 * (largest - second);
                        if delta > best {
                            best = delta;
                        }
                    }
                }
            }
        }
        best
    }

    /// Largest and second-largest of three values.
    ///
    /// Uses a descending compare-and-swap network: positions (0,1), then
    /// (1,2), then (0,1) again.
    fn top_two(mut s: [f64; 3]) -> (f64, f64) {
        if s[0] < s[1] {
            s.swap(0, 1);
        }
        if s[1] < s[2] {
            s.swap(1, 2);
        }
        if s[0] < s[1] {
            s.swap(0, 1);
        }
        (s[0], s[1])
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance comparison shared by the float tests below.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() <= tol
    }

    /// Asserts that `got[i]` is within `tol` of `want[i]` for every index of `want`.
    fn assert_all_close(got: &[f64], want: &[f64], tol: f64) {
        for (i, &w) in want.iter().enumerate() {
            assert!(approx_eq(got[i], w, tol));
        }
    }

    /// Builds an `n × n` row-major matrix with zero diagonal, setting both
    /// `(i, j)` and `(j, i)` for each `(i, j, value)` entry.
    fn symmetric(n: usize, entries: &[(usize, usize, f64)]) -> Vec<f64> {
        let mut m = vec![0.0f64; n * n];
        for &(i, j, value) in entries {
            m[i * n + j] = value;
            m[j * n + i] = value;
        }
        m
    }

    #[test]
    fn poincare_core_mobius_identity() {
        let ball = PoincareBallCore::new(1.0f64);
        let x = [0.2, -0.1, 0.05];
        let sum = ball.mobius_add(&x, &[0.0; 3]);
        assert_all_close(&sum, &x, 1e-12);
    }

    #[test]
    fn poincare_core_distance_self_zero() {
        let ball = PoincareBallCore::new(1.0f64);
        let p = [0.3, 0.1];
        assert!(ball.distance(&p, &p).abs() < 1e-12);
    }

    #[test]
    fn poincare_core_exp_log_round_trip_small() {
        let ball = PoincareBallCore::new(1.0f64);
        let v = [0.3, -0.2];
        let back = ball.log_map_zero(&ball.exp_map_zero(&v));
        assert_all_close(&back, &v, 1e-7);
    }

    #[test]
    fn lorentz_core_origin_is_on_manifold() {
        let lorentz = LorentzModelCore::new(1.0f64);
        assert!(lorentz.is_on_manifold(&lorentz.origin(3), 1e-12));
    }

    #[test]
    fn conversion_core_round_trip_poincare() {
        let ball = PoincareBallCore::new(1.0f64);
        let lorentz = LorentzModelCore::new(1.0f64);
        let x = [0.2, 0.1, -0.05];
        assert!(ball.is_in_ball(&x));
        let lifted = conversions::poincare_to_lorentz(&ball, &x);
        assert!(lorentz.is_on_manifold(&lifted, 1e-10));
        let back = conversions::lorentz_to_poincare(&lorentz, &lifted);
        assert_all_close(&back, &x, 1e-10);
    }

    #[test]
    fn ultrametric_violation_zero_for_simple_ultrametric() {
        let n = 4usize;
        // Two tight pairs {0,1} and {2,3} at distance 1, everything across at 2.
        let d = symmetric(
            n,
            &[
                (0, 1, 1.0),
                (2, 3, 1.0),
                (0, 2, 2.0),
                (0, 3, 2.0),
                (1, 2, 2.0),
                (1, 3, 2.0),
            ],
        );
        let v = diagnostics::ultrametric_max_violation_f64(&d, n);
        assert!(v.abs() < 1e-12, "expected 0, got {v}");
    }

    #[test]
    fn ultrametric_violation_positive_for_non_ultrametric() {
        let n = 3usize;
        // d(0,2) = 3 exceeds max(d(0,1), d(1,2)) = 1 by 2.
        let d = symmetric(n, &[(0, 1, 1.0), (1, 2, 1.0), (0, 2, 3.0)]);
        let v = diagnostics::ultrametric_max_violation_f64(&d, n);
        assert!((v - 2.0).abs() < 1e-12, "expected 2, got {v}");
    }

    #[test]
    fn delta_hyperbolicity_c4_is_one() {
        let n = 4usize;
        // 4-cycle 0-1-2-3-0: neighbours at distance 1, opposite corners at 2.
        let d = symmetric(
            n,
            &[
                (0, 1, 1.0),
                (1, 2, 1.0),
                (2, 3, 1.0),
                (0, 3, 1.0),
                (0, 2, 2.0),
                (1, 3, 2.0),
            ],
        );
        let delta = diagnostics::delta_hyperbolicity_four_point_exact_f64(&d, n);
        assert!((delta - 1.0).abs() < 1e-12, "expected 1, got {delta}");
    }
}