/// Small positive threshold shared by the functions in this file: values are
/// clamped to at least `EPS` before `sqrt`/`ln`, and norms or sums below `EPS`
/// are treated as zero.
const EPS: f32 = 1e-7;
/// Lorentzian (Minkowski) inner product `-x[0]*y[0] + Σ_{i≥1} x[i]*y[i]`.
///
/// Index 0 is the coordinate that enters with a negative sign; all remaining
/// coordinates enter with a positive sign. Slices with fewer than two
/// components yield `0.0`. In debug builds both slices must have equal length.
#[inline]
pub fn lorentz_inner(x: &[f32], y: &[f32]) -> f32 {
    debug_assert!(x.len() == y.len());
    match x.len() {
        0 | 1 => 0.0,
        _ => {
            let time_part = -x[0] * y[0];
            let space_part = x[1..]
                .iter()
                .zip(y[1..].iter())
                .map(|(xi, yi)| xi * yi)
                .sum::<f32>();
            time_part + space_part
        }
    }
}
/// Lorentzian squared norm of `x`, i.e. the Lorentz inner product of `x` with itself.
///
/// Returns exactly what `lorentz_inner(x, x)` returns, including `0.0` for
/// slices with fewer than two components.
#[inline]
pub fn lorentz_norm_sq(x: &[f32]) -> f32 {
lorentz_inner(x, x)
}
/// Projects `x` onto the hyperboloid `⟨p, p⟩_L = -1/c` by recomputing its first coordinate.
///
/// The components `x[1..]` are copied unchanged. The first component of the
/// result is `sqrt(max(‖x[1..]‖² + 1/c, EPS))`; the incoming `x[0]` is ignored.
///
/// An empty `x` yields an empty vector (previously this panicked when slicing
/// `x[1..]`). `c` is not validated here; callers are expected to pass a
/// positive value.
#[inline]
pub fn project_hyperboloid(x: &[f32], c: f32) -> Vec<f32> {
    let space = match x.split_first() {
        Some((_, rest)) => rest,
        None => return Vec::new(),
    };
    let space_norm_sq: f32 = space.iter().map(|v| v * v).sum();
    let target = -1.0 / c;
    // Clamp before the square root so rounding can never produce a NaN.
    let time = (space_norm_sq - target).max(EPS).sqrt();
    let mut result = Vec::with_capacity(x.len());
    result.push(time);
    result.extend_from_slice(space);
    result
}
/// Distance between `x` and `y` computed from their Lorentz inner product.
///
/// Evaluates `acosh(max(-c·⟨x, y⟩_L, 1)) / sqrt(c)`. The clamp to `1` keeps the
/// `acosh` argument inside its domain when the inner product is slightly off.
#[inline]
pub fn lorentz_distance(x: &[f32], y: &[f32], c: f32) -> f32 {
    let scaled_inner = -c * lorentz_inner(x, y);
    let clamped = scaled_inner.max(1.0);
    clamped.acosh() / c.sqrt()
}
/// Logarithmic score of `x` relative to the direction `xi`.
///
/// Returns `ln(max(-⟨x, xi⟩_L, EPS))`; the clamp keeps the logarithm finite
/// when the negated inner product is zero or negative.
#[inline]
pub fn busemann_score(x: &[f32], xi: &[f32]) -> f32 {
    let negated_inner = -lorentz_inner(x, xi);
    negated_inner.max(EPS).ln()
}
/// Softmax attention weights derived from Busemann depth differences.
///
/// For every key the score is `-(depth(key) - depth(query)) / temperature`,
/// where `depth` is [`busemann_score`] with respect to `focal_direction`.
/// Scores are shifted by their maximum before exponentiation and then
/// normalised to sum to one.
///
/// Returns an empty vector when `keys` is empty. When the normaliser is not a
/// usable number (NaN/infinite, e.g. because `temperature` is zero or a depth
/// overflowed) or is below `EPS`, the weights fall back to a uniform
/// distribution instead of propagating NaN.
pub fn horosphere_attention_weights(
    query: &[f32],
    keys: &[&[f32]],
    focal_direction: &[f32],
    temperature: f32,
) -> Vec<f32> {
    if keys.is_empty() {
        return Vec::new();
    }
    let query_depth = busemann_score(query, focal_direction);
    let scores: Vec<f32> = keys
        .iter()
        .map(|k| {
            let key_depth = busemann_score(k, focal_direction);
            -(key_depth - query_depth) / temperature
        })
        .collect();
    // Subtracting the maximum keeps every exponent <= 0.
    let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut weights: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
    let sum: f32 = weights.iter().sum();
    // `sum < EPS` alone is false for NaN, so non-finite sums are checked explicitly.
    if !sum.is_finite() || sum < EPS {
        return vec![1.0 / keys.len() as f32; keys.len()];
    }
    for w in &mut weights {
        *w /= sum;
    }
    weights
}
/// Weighted aggregate of hyperboloid points, re-projected onto the hyperboloid.
///
/// Each point contributes `weight / sqrt(1 + c·‖point[1..]‖²)` times its
/// coordinates to a running sum whose length is taken from the first point.
/// The sum is then passed through [`project_hyperboloid`]. Points and weights
/// are paired positionally; surplus entries on either side are ignored.
///
/// Returns an empty vector when `points` is empty.
pub fn einstein_midpoint(points: &[&[f32]], weights: &[f32], c: f32) -> Vec<f32> {
    let dim = match points.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };
    let mut accumulator = vec![0.0f32; dim];
    points.iter().zip(weights).for_each(|(point, &weight)| {
        let spatial_sq: f32 = point[1..].iter().map(|v| v * v).sum();
        let gamma = 1.0 / (1.0 + c * spatial_sq).sqrt();
        let factor = weight * gamma;
        for (idx, &coord) in point.iter().enumerate() {
            accumulator[idx] += factor * coord;
        }
    });
    project_hyperboloid(&accumulator, c)
}
/// Parameters of a single attention head.
#[derive(Debug, Clone)]
pub struct CascadeHead {
    /// Curvature value this head passes to the hyperboloid projection and midpoint.
    pub curvature: f32,
    /// Direction against which Busemann scores are evaluated for this head.
    pub focal_direction: Vec<f32>,
    /// Divisor applied to depth differences before the softmax.
    pub temperature: f32,
    /// Relative weight of this head when head outputs are blended.
    pub weight: f32,
}

impl CascadeHead {
    /// Creates a head with the given `curvature` for vectors of length `dim`.
    ///
    /// The focal direction is `[1, 1, 0, …, 0]` truncated to `dim` entries, so
    /// `dim` values of 0 or 1 yield `[]` or `[1]` instead of panicking.
    /// `temperature` and `weight` both start at `1.0`.
    pub fn new(curvature: f32, dim: usize) -> Self {
        let mut focal = vec![0.0; dim];
        // Only touch the leading components that actually exist.
        for component in focal.iter_mut().take(2) {
            *component = 1.0;
        }
        Self {
            curvature,
            focal_direction: focal,
            temperature: 1.0,
            weight: 1.0,
        }
    }
}
/// Multi-head attention state: a set of [`CascadeHead`]s plus the output size.
#[derive(Debug, Clone)]
pub struct LorentzCascadeAttention {
    /// Length of the vectors produced by `attend`.
    pub dim: usize,
    /// Heads whose outputs are blended according to their `weight`.
    pub heads: Vec<CascadeHead>,
    /// SIMD toggle. `new` sets it to `true`; no code visible in this file reads it.
    pub use_simd: bool,
}
/// Construction parameters for `LorentzCascadeAttention::new`.
#[derive(Debug, Clone)]
pub struct LCAConfig {
    /// Vector length handed to every head and used for the output.
    pub dim: usize,
    /// Number of heads to create.
    pub num_heads: usize,
    /// `(min, max)` curvature; head curvatures are spread between these two values.
    pub curvature_range: (f32, f32),
    /// Temperature assigned to every head.
    pub temperature: f32,
}

impl Default for LCAConfig {
    /// 128-dimensional, four heads, curvatures in `[0.1, 2.0]`, temperature `1.0`.
    fn default() -> Self {
        let curvature_range = (0.1, 2.0);
        LCAConfig {
            dim: 128,
            num_heads: 4,
            curvature_range,
            temperature: 1.0,
        }
    }
}
impl LorentzCascadeAttention {
    /// Builds the attention module described by `config`.
    ///
    /// Head curvatures are interpolated linearly in log space from
    /// `curvature_range.0` to `curvature_range.1`; with a single head the
    /// interpolation parameter is `0.5`. Every head gets `config.temperature`
    /// and a weight of `1 / num_heads`. `use_simd` is set to `true`.
    pub fn new(config: LCAConfig) -> Self {
        let (c_min, c_max) = config.curvature_range;
        let log_min = c_min.ln();
        let log_max = c_max.ln();
        let heads: Vec<CascadeHead> = (0..config.num_heads)
            .map(|i| {
                let t = if config.num_heads > 1 {
                    i as f32 / (config.num_heads - 1) as f32
                } else {
                    0.5
                };
                let curvature = (log_min + t * (log_max - log_min)).exp();
                let mut head = CascadeHead::new(curvature, config.dim);
                head.temperature = config.temperature;
                head.weight = 1.0 / config.num_heads as f32;
                head
            })
            .collect();
        Self {
            dim: config.dim,
            heads,
            use_simd: true,
        }
    }

    /// Runs one head: projects query, keys and values with the head's
    /// curvature, computes horosphere weights from the keys, and aggregates
    /// the projected values with those weights.
    fn attend_single_head(
        &self,
        head: &CascadeHead,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> Vec<f32> {
        let query_h = project_hyperboloid(query, head.curvature);
        let keys_h: Vec<Vec<f32>> = keys
            .iter()
            .map(|k| project_hyperboloid(k, head.curvature))
            .collect();
        let values_h: Vec<Vec<f32>> = values
            .iter()
            .map(|v| project_hyperboloid(v, head.curvature))
            .collect();
        let keys_refs: Vec<&[f32]> = keys_h.iter().map(|k| k.as_slice()).collect();
        let weights = horosphere_attention_weights(
            &query_h,
            &keys_refs,
            &head.focal_direction,
            head.temperature,
        );
        let values_refs: Vec<&[f32]> = values_h.iter().map(|v| v.as_slice()).collect();
        einstein_midpoint(&values_refs, &weights, head.curvature)
    }

    /// Attends `query` over `keys`/`values` with every head and returns the
    /// weight-averaged head outputs as a vector of length `self.dim`.
    ///
    /// Returns all zeros when `keys` or `values` is empty. Head outputs longer
    /// than `self.dim` are truncated; shorter ones leave the tail untouched.
    /// The sum is divided by the total head weight only when that total
    /// exceeds `EPS`.
    pub fn attend(&self, query: &[f32], keys: &[&[f32]], values: &[&[f32]]) -> Vec<f32> {
        if keys.is_empty() || values.is_empty() {
            return vec![0.0; self.dim];
        }
        let mut result = vec![0.0; self.dim];
        let mut total_weight = 0.0;
        for head in &self.heads {
            let output = self.attend_single_head(head, query, keys, values);
            // `zip` stops at the shorter side, matching the old `i < result.len()` guard.
            for (acc, &val) in result.iter_mut().zip(&output) {
                *acc += head.weight * val;
            }
            total_weight += head.weight;
        }
        if total_weight > EPS {
            for val in &mut result {
                *val /= total_weight;
            }
        }
        result
    }

    /// Like [`attend`](Self::attend), but first narrows the keys to the
    /// `top_k` whose Busemann score (under the first head) is closest to the
    /// query's score.
    ///
    /// Falls back to plain `attend` when there are at most `top_k` keys or
    /// when there are no heads to rank with. Ranking uses a total order on
    /// the score distances, so NaN distances sort last instead of panicking.
    /// Keys that have no value at the same index are skipped rather than
    /// causing an out-of-bounds panic.
    pub fn attend_sparse(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        top_k: usize,
    ) -> Vec<f32> {
        if keys.len() <= top_k {
            return self.attend(query, keys, values);
        }
        let coarse_head = match self.heads.first() {
            Some(head) => head,
            // Nothing to rank with; `attend` already handles the head-less case.
            None => return self.attend(query, keys, values),
        };
        let query_h = project_hyperboloid(query, coarse_head.curvature);
        let query_score = busemann_score(&query_h, &coarse_head.focal_direction);
        let mut ranked: Vec<(usize, f32)> = keys
            .iter()
            .enumerate()
            .map(|(i, k)| {
                let key_h = project_hyperboloid(k, coarse_head.curvature);
                let score = busemann_score(&key_h, &coarse_head.focal_direction);
                (i, (score - query_score).abs())
            })
            .collect();
        // Stable sort keeps the original index order among equal distances.
        ranked.sort_by(|a, b| a.1.total_cmp(&b.1));
        let (selected_keys, selected_values): (Vec<&[f32]>, Vec<&[f32]>) = ranked
            .iter()
            .filter_map(|&(i, _)| values.get(i).map(|&v| (keys[i], v)))
            .take(top_k)
            .unzip();
        self.attend(query, &selected_keys, &selected_values)
    }
}
/// Maps between the hyperboloid and the tangent space at its origin
/// `(1/√c, 0, …, 0)`.
pub mod tangent {
    use super::*;

    /// Logarithmic map at the origin: hyperboloid point `x` → tangent vector.
    ///
    /// The result has one component fewer than `x` (the first coordinate is
    /// dropped) and equals `acosh(√c·x[0]) / (√c·‖x[1..]‖) · x[1..]`, which is
    /// the inverse of [`exp_map_origin`] for every `c > 0`. The `acosh`
    /// argument is clamped to at least `1` so rounding cannot yield NaN.
    ///
    /// Returns the zero vector when `‖x[1..]‖ < EPS` and an empty vector when
    /// `x` is empty.
    pub fn log_map_origin(x: &[f32], c: f32) -> Vec<f32> {
        let (x0, space) = match x.split_first() {
            Some((&first, rest)) => (first, rest),
            None => return Vec::new(),
        };
        let space_norm: f32 = space.iter().map(|v| v * v).sum::<f32>().sqrt();
        if space_norm < EPS {
            return vec![0.0; space.len()];
        }
        let sqrt_c = c.sqrt();
        // Geodesic distance from the origin; dividing by sqrt(c) mirrors the
        // `sqrt_c * v_norm` argument used by `exp_map_origin`.
        let distance = (sqrt_c * x0).max(1.0).acosh() / sqrt_c;
        let factor = distance / space_norm;
        space.iter().map(|&v| factor * v).collect()
    }

    /// Exponential map at the origin: tangent vector `v` → hyperboloid point.
    ///
    /// The result has one component more than `v`: the first coordinate is
    /// `cosh(√c·‖v‖)/√c` and the rest is `sinh(√c·‖v‖)/(√c·‖v‖) · v`.
    /// When `‖v‖ < EPS` the origin `(1/√c, 0, …, 0)` is returned.
    pub fn exp_map_origin(v: &[f32], c: f32) -> Vec<f32> {
        let v_norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let sqrt_c = c.sqrt();
        if v_norm < EPS {
            let mut origin = vec![0.0; v.len() + 1];
            origin[0] = 1.0 / sqrt_c;
            return origin;
        }
        let scaled_norm = sqrt_c * v_norm;
        let x0 = scaled_norm.cosh() / sqrt_c;
        let factor = scaled_norm.sinh() / scaled_norm;
        let mut result = Vec::with_capacity(v.len() + 1);
        result.push(x0);
        result.extend(v.iter().map(|&vi| factor * vi));
        result
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `(cosh 1, sinh 1, 0, 0)` should have a Lorentz squared norm close to -1.
    #[test]
    fn test_lorentz_inner_hyperboloid() {
        let point = vec![1.5430806, 1.1752012, 0.0, 0.0];
        let deviation = (lorentz_norm_sq(&point) + 1.0).abs();
        assert!(deviation < 0.01);
    }

    /// The midpoint of two mirrored points stays on the hyperboloid and is
    /// centred along the mirrored axis.
    #[test]
    fn test_einstein_midpoint_two_points() {
        let c = 1.0;
        let left = project_hyperboloid(&[1.0, 0.5, 0.0], c);
        let right = project_hyperboloid(&[1.0, -0.5, 0.0], c);
        let weights = vec![0.5, 0.5];
        let points = [left.as_slice(), right.as_slice()];
        let midpoint = einstein_midpoint(&points, &weights, c);
        let norm_sq = lorentz_norm_sq(&midpoint);
        assert!((norm_sq + 1.0 / c).abs() < 0.1);
        assert!(midpoint[1].abs() < 0.1);
    }

    /// A point near the origin must score lower than one farther out.
    #[test]
    fn test_busemann_hierarchy() {
        let focal = vec![1.0, -1.0, 0.0, 0.0];
        let root = project_hyperboloid(&[0.0, 0.1, 0.0, 0.0], 1.0);
        let leaf = project_hyperboloid(&[0.0, 0.9, 0.0, 0.0], 1.0);
        let root_score = busemann_score(&root, &focal);
        let leaf_score = busemann_score(&leaf, &focal);
        assert!(
            root_score < leaf_score,
            "root_score={:.4} should be < leaf_score={:.4}\nroot={:?}, leaf={:?}",
            root_score,
            leaf_score,
            root,
            leaf
        );
    }

    /// `attend` returns a finite vector with the configured dimension.
    #[test]
    fn test_cascade_attention_shapes() {
        let lca = LorentzCascadeAttention::new(LCAConfig {
            dim: 8,
            num_heads: 3,
            curvature_range: (0.5, 2.0),
            temperature: 1.0,
        });
        let query = vec![1.0, 0.5, 0.3, 0.1, 0.0, 0.0, 0.0, 0.0];
        let key1 = vec![1.0, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0];
        let key2 = vec![1.0, 0.8, 0.4, 0.2, 0.0, 0.0, 0.0, 0.0];
        let keys: Vec<&[f32]> = vec![&key1, &key2];
        let values = keys.clone();
        let output = lca.attend(&query, &keys, &values);
        assert_eq!(output.len(), 8);
        assert!(output.iter().all(|x| x.is_finite()));
    }

    /// Horosphere weights form a probability distribution.
    #[test]
    fn test_horosphere_weights_sum_to_one() {
        let focal = vec![1.0, 1.0, 0.0, 0.0];
        let query = project_hyperboloid(&[0.0, 0.5, 0.0, 0.0], 1.0);
        let k1 = project_hyperboloid(&[0.0, 0.2, 0.0, 0.0], 1.0);
        let k2 = project_hyperboloid(&[0.0, 0.6, 0.0, 0.0], 1.0);
        let k3 = project_hyperboloid(&[0.0, 0.9, 0.0, 0.0], 1.0);
        let keys: Vec<&[f32]> = vec![&k1, &k2, &k3];
        let weights = horosphere_attention_weights(&query, &keys, &focal, 1.0);
        let total: f32 = weights.iter().sum();
        assert!((total - 1.0).abs() < 1e-5);
    }
}
#[cfg(feature = "benchmark")]
pub mod bench {
    use super::*;
    use std::time::Instant;

    /// Times a Poincaré-distance + Fréchet-mean baseline against
    /// `LorentzCascadeAttention::attend` on synthetic data and prints both
    /// durations and their ratio.
    ///
    /// `n_keys` keys of length `dim` are generated deterministically from
    /// sine/cosine values; each workload is repeated `iterations` times.
    pub fn compare_performance(n_keys: usize, dim: usize, iterations: usize) {
        use crate::hyperbolic::poincare::{frechet_mean, poincare_distance};

        let query: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1).sin() * 0.5).collect();
        let keys: Vec<Vec<f32>> = (0..n_keys)
            .map(|j| {
                (0..dim)
                    .map(|i| ((i + j) as f32 * 0.1).cos() * 0.5)
                    .collect()
            })
            .collect();
        let keys_refs: Vec<&[f32]> = keys.iter().map(Vec::as_slice).collect();

        // Baseline: per-key Poincaré distances followed by a Fréchet mean.
        let baseline_timer = Instant::now();
        for _ in 0..iterations {
            let _scores: Vec<f32> = keys_refs
                .iter()
                .map(|k| -poincare_distance(&query, k, 1.0))
                .collect();
            let _mean = frechet_mean(&keys_refs, None, 1.0, 50, 1e-5);
        }
        let poincare_time = baseline_timer.elapsed();

        // Candidate: Lorentz cascade attention with four heads.
        let config = LCAConfig {
            dim,
            num_heads: 4,
            curvature_range: (0.1, 2.0),
            temperature: 1.0,
        };
        let lca = LorentzCascadeAttention::new(config);
        let lca_timer = Instant::now();
        for _ in 0..iterations {
            let _output = lca.attend(&query, &keys_refs, &keys_refs);
        }
        let lca_time = lca_timer.elapsed();

        let speedup = poincare_time.as_nanos() as f64 / lca_time.as_nanos() as f64;
        println!(
            "=== Performance Comparison (n={}, d={}, iter={}) ===",
            n_keys, dim, iterations
        );
        println!("Poincaré Attention: {:?}", poincare_time);
        println!("Lorentz Cascade: {:?}", lca_time);
        println!("Speedup: {:.2}x", speedup);
    }
}