use crate::hyperbolic::{EPSILON, MAX_NORM};
use simsimd::SpatialSimilarity;
/// Hyperbolic space of constant negative curvature in the Poincaré ball model.
///
/// Points are represented in the open unit ball for every curvature; the
/// curvature only rescales the metric, so [`PoincareBall::distance`] is the
/// curvature −1 Poincaré distance divided by `sqrt(|curvature|)`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PoincareBall {
    /// Curvature of the space. [`PoincareBall::new`] asserts it is negative;
    /// note the field is public, so direct construction bypasses that check.
    pub curvature: f32,
}
impl PoincareBall {
    /// Creates a Poincaré ball with the given curvature.
    ///
    /// # Panics
    ///
    /// Panics with "Curvature must be negative" unless `curvature < 0.0` (NaN is
    /// rejected here too), and with "Curvature must be finite" for
    /// `f32::NEG_INFINITY`, which would otherwise collapse every distance to 0.
    pub fn new(curvature: f32) -> Self {
        assert!(curvature < 0.0, "Curvature must be negative");
        assert!(curvature.is_finite(), "Curvature must be finite");
        Self { curvature }
    }

    /// Squared Euclidean norm ‖x‖², computed with the `simsimd` dot product.
    ///
    /// A failed dot product (`None`) is treated as 0, and `f32::max` turns a NaN
    /// result into 0 as well.
    #[inline]
    fn norm_squared(&self, x: &[f32]) -> f32 {
        (f32::dot(x, x).unwrap_or(0.0) as f32).max(0.0)
    }

    /// Euclidean norm ‖x‖.
    #[inline]
    fn norm(&self, x: &[f32]) -> f32 {
        self.norm_squared(x).sqrt()
    }

    /// Conformal factor λ_x = 2 / (1 − ‖x‖²) of the unit-ball metric at `x`;
    /// `EPSILON` is added to the denominator to avoid dividing by zero on the
    /// unit sphere.
    #[inline]
    fn conformal_factor(&self, x: &[f32]) -> f32 {
        2.0 / (1.0 - self.norm_squared(x) + EPSILON)
    }

    /// Returns `x` unchanged if ‖x‖ < `MAX_NORM`; otherwise returns `x` rescaled
    /// to a norm just below `MAX_NORM` (presumably < 1, keeping points strictly
    /// inside the ball — TODO confirm against the constant's definition).
    pub fn project(&self, x: &[f32]) -> Vec<f32> {
        let norm = self.norm(x);
        if norm < MAX_NORM {
            x.to_vec()
        } else {
            let scale = MAX_NORM / (norm + EPSILON);
            x.iter().map(|&v| v * scale).collect()
        }
    }

    /// Geodesic distance between `x` and `y`.
    ///
    /// This is the curvature −1 Poincaré distance
    /// `arcosh(1 + 2‖x − y‖² / ((1 − ‖x‖²)(1 − ‖y‖²)))` divided by
    /// `sqrt(|curvature|)`. Returns `f32::INFINITY` if either point has
    /// `1 − ‖p‖² <= EPSILON` (on or outside the unit sphere).
    ///
    /// # Panics
    ///
    /// Panics if `x` and `y` have different lengths.
    pub fn distance(&self, x: &[f32], y: &[f32]) -> f32 {
        assert_eq!(x.len(), y.len(), "Vectors must have same dimension");
        let x_factor = 1.0 - self.norm_squared(x);
        let y_factor = 1.0 - self.norm_squared(y);
        if x_factor <= EPSILON || y_factor <= EPSILON {
            return f32::INFINITY;
        }
        // Summed directly from component differences: no temporary Vec, and no
        // cancellation (as with ‖x‖² + ‖y‖² − 2⟨x, y⟩) for nearby points.
        let diff_norm_sq: f32 = x.iter().zip(y).map(|(&a, &b)| (a - b) * (a - b)).sum();
        let u = diff_norm_sq / (x_factor * y_factor + EPSILON);
        // arcosh(1 + 2u) == 2·asinh(√u). The asinh form matters in f32: `1 + 2u`
        // rounds to exactly 1 once u < ~3e-8, which would report distance 0 for
        // nearby distinct points.
        let unit_distance = 2.0 * u.sqrt().asinh();
        unit_distance / self.curvature.abs().sqrt()
    }

    /// Möbius addition `x ⊕ y` on the unit ball:
    /// `((1 + 2⟨x,y⟩ + ‖y‖²) x + (1 − ‖x‖²) y) / (1 + 2⟨x,y⟩ + ‖x‖²‖y‖²)`,
    /// with `EPSILON` added to the denominator. The result is passed through
    /// [`Self::project`].
    ///
    /// # Panics
    ///
    /// Panics if `x` and `y` have different lengths.
    pub fn mobius_add(&self, x: &[f32], y: &[f32]) -> Vec<f32> {
        assert_eq!(x.len(), y.len(), "Vectors must have same dimension");
        let x_norm_sq = self.norm_squared(x);
        let y_norm_sq = self.norm_squared(y);
        let xy_dot = f32::dot(x, y).unwrap_or(0.0) as f32;
        let x_coeff = 1.0 + 2.0 * xy_dot + y_norm_sq;
        let y_coeff = 1.0 - x_norm_sq;
        let denominator = 1.0 + 2.0 * xy_dot + x_norm_sq * y_norm_sq + EPSILON;
        let result: Vec<f32> = x
            .iter()
            .zip(y)
            .map(|(&xi, &yi)| (x_coeff * xi + y_coeff * yi) / denominator)
            .collect();
        self.project(&result)
    }

    /// Exponential map at `base`: the point reached by following the geodesic
    /// that leaves `base` with initial velocity `tangent` for unit time,
    /// computed as `base ⊕ tanh(λ_base ‖v‖ / 2) · v / ‖v‖`.
    ///
    /// The curvature deliberately does not appear. This type realises curvature
    /// `c` as the unit-ball metric scaled by the constant `1/|c|` (see
    /// [`Self::distance`]), and a constant rescaling leaves the Levi-Civita
    /// connection — hence geodesics and the exponential map — unchanged. As a
    /// result `distance(base, exp_map(base, v)) ≈ λ_base · ‖v‖ / sqrt(|c|)`.
    /// Do not reintroduce `sqrt(|c|)` factors from the radius-`1/sqrt(|c|)`
    /// formulation: mixed with unit-ball Möbius addition they disagree with
    /// `distance` whenever `c != -1`.
    ///
    /// Returns `base` unchanged if ‖tangent‖ < `EPSILON`.
    ///
    /// # Panics
    ///
    /// Panics if `base` and `tangent` have different lengths.
    pub fn exp_map(&self, base: &[f32], tangent: &[f32]) -> Vec<f32> {
        assert_eq!(
            base.len(),
            tangent.len(),
            "Vectors must have same dimension"
        );
        let tangent_norm = self.norm(tangent);
        if tangent_norm < EPSILON {
            return base.to_vec();
        }
        let lambda_base = self.conformal_factor(base);
        let coeff = (lambda_base * tangent_norm / 2.0).tanh() / (tangent_norm + EPSILON);
        let scaled_tangent: Vec<f32> = tangent.iter().map(|&v| v * coeff).collect();
        self.mobius_add(base, &scaled_tangent)
    }

    /// Logarithmic map at `base`: the tangent vector `v` with
    /// `exp_map(base, v) ≈ target`, computed as
    /// `(2 / λ_base) · artanh(‖u‖) · u / ‖u‖` where `u = (−base) ⊕ target`.
    ///
    /// Curvature-independent for the same reason as [`Self::exp_map`]. Because
    /// `mobius_add` projects its result, ‖u‖ stays below `MAX_NORM`, which keeps
    /// `artanh` finite. A `sqrt(|c|)·‖u‖` argument, as in the
    /// radius-`1/sqrt(|c|)` formulation, would exceed 1 and yield NaN for valid
    /// points when |c| > 1.
    ///
    /// Returns the zero vector if ‖u‖ < `EPSILON`.
    ///
    /// # Panics
    ///
    /// Panics if `base` and `target` have different lengths.
    pub fn log_map(&self, base: &[f32], target: &[f32]) -> Vec<f32> {
        assert_eq!(base.len(), target.len(), "Vectors must have same dimension");
        let neg_base: Vec<f32> = base.iter().map(|&v| -v).collect();
        let diff = self.mobius_add(&neg_base, target);
        let diff_norm = self.norm(&diff);
        if diff_norm < EPSILON {
            return vec![0.0; base.len()];
        }
        let lambda_base = self.conformal_factor(base);
        let coeff = 2.0 / (lambda_base + EPSILON) * diff_norm.atanh() / (diff_norm + EPSILON);
        diff.iter().map(|&v| v * coeff).collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const TOL: f32 = 1e-4;

    /// Asserts two vectors have equal length and agree element-wise within `TOL`.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < TOL, "{actual:?} != {expected:?}");
        }
    }

    #[test]
    fn test_poincare_ball_creation() {
        let ball = PoincareBall::new(-1.0);
        assert_eq!(ball.curvature, -1.0);
    }

    #[test]
    #[should_panic(expected = "Curvature must be negative")]
    fn test_poincare_positive_curvature_panics() {
        let _ball = PoincareBall::new(1.0);
    }

    #[test]
    #[should_panic(expected = "Curvature must be finite")]
    fn test_poincare_infinite_curvature_panics() {
        let _ball = PoincareBall::new(f32::NEG_INFINITY);
    }

    #[test]
    fn test_project_within_ball() {
        let ball = PoincareBall::new(-1.0);
        let x = vec![0.5, 0.5];
        let projected = ball.project(&x);
        assert_eq!(projected, x);
    }

    #[test]
    fn test_project_outside_ball() {
        let ball = PoincareBall::new(-1.0);
        let x = vec![1.5, 1.5];
        let projected = ball.project(&x);
        let norm = ball.norm(&projected);
        assert!(norm <= MAX_NORM);
    }

    #[test]
    fn test_distance_origin() {
        let ball = PoincareBall::new(-1.0);
        let origin = vec![0.0, 0.0];
        let point = vec![0.5, 0.0];
        let dist = ball.distance(&origin, &point);
        assert!(dist > 0.0);
        assert!(dist < f32::INFINITY);
    }

    #[test]
    fn test_distance_symmetric() {
        let ball = PoincareBall::new(-1.0);
        let x = vec![0.3, 0.4];
        let y = vec![0.1, 0.2];
        let d1 = ball.distance(&x, &y);
        let d2 = ball.distance(&y, &x);
        assert!((d1 - d2).abs() < TOL);
    }

    #[test]
    fn test_distance_same_point() {
        let ball = PoincareBall::new(-1.0);
        let x = vec![0.3, 0.4];
        let dist = ball.distance(&x, &x);
        assert!(dist < TOL);
    }

    #[test]
    fn test_distance_resolves_nearby_points() {
        let ball = PoincareBall::new(-1.0);
        let x: Vec<f32> = vec![0.1, 0.0];
        let y: Vec<f32> = vec![0.1001, 0.0];
        let dx = y[0] - x[0];
        // For nearby points the distance is ≈ λ_x · ‖x − y‖ with λ_x = 2 / (1 − ‖x‖²).
        let expected = 2.0 * dx / (1.0 - x[0] * x[0]);
        let dist = ball.distance(&x, &y);
        assert!(dist > 0.0);
        assert!((dist - expected).abs() / expected < 1e-2);
    }

    #[test]
    fn test_mobius_add_identity() {
        let ball = PoincareBall::new(-1.0);
        let x = vec![0.3, 0.4];
        let origin = vec![0.0, 0.0];
        let result = ball.mobius_add(&x, &origin);
        assert_close(&result, &x);
    }

    #[test]
    fn test_exp_map_zero_tangent() {
        let ball = PoincareBall::new(-1.0);
        let base = vec![0.3, 0.4];
        let tangent = vec![0.0, 0.0];
        let result = ball.exp_map(&base, &tangent);
        assert_eq!(result, base);
    }

    #[test]
    fn test_log_exp_inverse() {
        let ball = PoincareBall::new(-1.0);
        let base = vec![0.2, 0.3];
        let tangent = vec![0.1, 0.1];
        let point = ball.exp_map(&base, &tangent);
        let recovered = ball.log_map(&base, &point);
        assert_close(&recovered, &tangent);
    }

    #[test]
    fn test_log_map_same_point() {
        let ball = PoincareBall::new(-1.0);
        let base = vec![0.3, 0.4];
        let result = ball.log_map(&base, &base);
        assert!(result.iter().all(|v| v.abs() < TOL));
    }

    #[test]
    fn test_log_map_finite_for_strong_curvature() {
        let ball = PoincareBall::new(-4.0);
        let base: Vec<f32> = vec![0.0, 0.0];
        let target: Vec<f32> = vec![0.6, 0.0];
        let tangent = ball.log_map(&base, &target);
        assert!(tangent.iter().all(|v| v.is_finite()));
        let back = ball.exp_map(&base, &tangent);
        assert_close(&back, &target);
    }

    #[test]
    fn test_exp_map_distance_matches_tangent_length() {
        let ball = PoincareBall::new(-4.0);
        let base: Vec<f32> = vec![0.2, 0.3];
        let tangent: Vec<f32> = vec![0.1, 0.1];
        let point = ball.exp_map(&base, &tangent);
        let lambda = 2.0 / (1.0 - (0.2f32 * 0.2 + 0.3 * 0.3));
        let tangent_len = (0.1f32 * 0.1 + 0.1 * 0.1).sqrt();
        // Riemannian length of the tangent: λ_base · ‖v‖ / sqrt(|c|), sqrt(|c|) = 2.
        let expected = lambda * tangent_len / 2.0;
        let dist = ball.distance(&base, &point);
        assert!((dist - expected).abs() < 1e-3);
    }

    #[test]
    fn test_curvature_scaling() {
        let ball1 = PoincareBall::new(-1.0);
        let ball2 = PoincareBall::new(-4.0);
        let x = vec![0.3, 0.4];
        let y = vec![0.1, 0.2];
        let d1 = ball1.distance(&x, &y);
        let d2 = ball2.distance(&x, &y);
        assert!(d2 < d1);
    }
}