use serde::{Deserialize, Serialize};
/// Configuration consumed by `TangentSpaceMapper`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TangentSpaceConfig {
/// Number of components in the all-zero origin vector that
/// `TangentSpaceMapper::new` allocates.
pub hyperbolic_dim: usize,
/// Curvature of the space. The mapper works with `c = -curvature` and takes
/// `c.sqrt()`, so a negative value is presumably expected (the default is `-1.0`).
pub curvature: f32,
/// Not read by any code visible in this file; presumably tells a training
/// loop whether the origin may be updated via `set_origin` — TODO confirm.
pub learnable_origin: bool,
}
impl Default for TangentSpaceConfig {
fn default() -> Self {
Self {
hyperbolic_dim: 32,
curvature: -1.0,
learnable_origin: true,
}
}
}
/// Maps points to the tangent space at a configurable origin via a
/// logarithmic map, so that similarities can be computed as plain dot products.
#[derive(Debug, Clone)]
pub struct TangentSpaceMapper {
/// Configuration supplied at construction; only `curvature` is read after `new`.
config: TangentSpaceConfig,
/// Base point of the tangent space; all zeros after `new`, replaced by `set_origin`.
origin: Vec<f32>,
/// Cached factor `2 / max(1 - c * ||origin||^2, 1e-8)` with `c = -curvature`,
/// recomputed whenever the origin changes.
lambda_origin: f32,
}
impl TangentSpaceMapper {
pub fn new(config: TangentSpaceConfig) -> Self {
let origin = vec![0.0f32; config.hyperbolic_dim];
let c = -config.curvature;
let origin_norm_sq: f32 = origin.iter().map(|x| x * x).sum();
let lambda_origin = 2.0 / (1.0 - c * origin_norm_sq).max(1e-8);
Self {
config,
origin,
lambda_origin,
}
}
pub fn set_origin(&mut self, origin: Vec<f32>) {
let c = -self.config.curvature;
let origin_norm_sq: f32 = origin.iter().map(|x| x * x).sum();
self.lambda_origin = 2.0 / (1.0 - c * origin_norm_sq).max(1e-8);
self.origin = origin;
}
#[inline]
pub fn log_map(&self, point: &[f32]) -> Vec<f32> {
let c = -self.config.curvature;
let sqrt_c = c.sqrt();
if self.origin.iter().all(|&x| x.abs() < 1e-8) {
return self.log_map_at_origin(point, sqrt_c);
}
let neg_origin: Vec<f32> = self.origin.iter().map(|x| -x).collect();
let diff = self.mobius_add(&neg_origin, point, c);
let diff_norm: f32 = diff.iter().map(|x| x * x).sum::<f32>().sqrt();
if diff_norm < 1e-8 {
return vec![0.0f32; point.len()];
}
let scale =
(2.0 / self.lambda_origin) * (sqrt_c * diff_norm).atanh() / (sqrt_c * diff_norm);
diff.iter().map(|&d| scale * d).collect()
}
#[inline]
fn log_map_at_origin(&self, point: &[f32], sqrt_c: f32) -> Vec<f32> {
let norm: f32 = point.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm < 1e-8 {
return vec![0.0f32; point.len()];
}
let arg = (sqrt_c * norm).min(0.99);
let scale = 2.0 * arg.atanh() / (sqrt_c * norm);
point.iter().map(|&p| scale * p).collect()
}
fn mobius_add(&self, x: &[f32], y: &[f32], c: f32) -> Vec<f32> {
let x_norm_sq: f32 = x.iter().map(|xi| xi * xi).sum();
let y_norm_sq: f32 = y.iter().map(|yi| yi * yi).sum();
let xy_dot: f32 = x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum();
let num_coef = 1.0 + 2.0 * c * xy_dot + c * y_norm_sq;
let denom = 1.0 + 2.0 * c * xy_dot + c * c * x_norm_sq * y_norm_sq;
if denom.abs() < 1e-8 {
return x.to_vec();
}
let y_coef = 1.0 - c * x_norm_sq;
x.iter()
.zip(y.iter())
.map(|(&xi, &yi)| (num_coef * xi + y_coef * yi) / denom)
.collect()
}
#[inline]
pub fn tangent_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
let ta = self.log_map(a);
let tb = self.log_map(b);
ta.iter().zip(tb.iter()).map(|(&ai, &bi)| ai * bi).sum()
}
pub fn batch_log_map(&self, points: &[&[f32]]) -> Vec<Vec<f32>> {
points.iter().map(|p| self.log_map(p)).collect()
}
pub fn batch_tangent_similarity(
&self,
query_tangent: &[f32],
keys_tangent: &[&[f32]],
) -> Vec<f32> {
keys_tangent
.iter()
.map(|k| Self::dot_product_simd(query_tangent, k))
.collect()
}
#[inline(always)]
fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
let len = a.len().min(b.len());
let chunks = len / 4;
let remainder = len % 4;
let mut sum0 = 0.0f32;
let mut sum1 = 0.0f32;
let mut sum2 = 0.0f32;
let mut sum3 = 0.0f32;
for i in 0..chunks {
let base = i * 4;
sum0 += a[base] * b[base];
sum1 += a[base + 1] * b[base + 1];
sum2 += a[base + 2] * b[base + 2];
sum3 += a[base + 3] * b[base + 3];
}
let base = chunks * 4;
for i in 0..remainder {
sum0 += a[base + i] * b[base + i];
}
sum0 + sum1 + sum2 + sum3
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a mapper over 4 dimensions with curvature `-1.0` and
    /// `learnable_origin` switched off; the origin stays at its zero default.
    fn mapper_4d() -> TangentSpaceMapper {
        TangentSpaceMapper::new(TangentSpaceConfig {
            hyperbolic_dim: 4,
            curvature: -1.0,
            learnable_origin: false,
        })
    }

    /// The origin maps to (approximately) the zero vector, and a non-trivial
    /// point maps to a vector of the same dimensionality.
    #[test]
    fn test_log_map_at_origin() {
        let mapper = mapper_4d();

        let at_origin = mapper.log_map(&[0.0f32; 4]);
        assert!(at_origin.iter().all(|&x| x.abs() < 1e-6));

        let tangent = mapper.log_map(&[0.1, 0.2, 0.0, 0.0]);
        assert_eq!(tangent.len(), 4);
    }

    /// A point compared with itself has strictly positive similarity.
    #[test]
    fn test_tangent_similarity() {
        let mapper = mapper_4d();
        let a = [0.1f32, 0.1, 0.0, 0.0];
        let b = a;
        assert!(mapper.tangent_similarity(&a, &b) > 0.0);
    }

    /// Batch mapping returns exactly one tangent vector per input point.
    #[test]
    fn test_batch_operations() {
        let mapper = TangentSpaceMapper::new(TangentSpaceConfig::default());
        let points: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 * 0.05; 32]).collect();
        let points_refs: Vec<&[f32]> = points.iter().map(Vec::as_slice).collect();
        assert_eq!(mapper.batch_log_map(&points_refs).len(), 10);
    }
}