use super::{HyperbolicCoherenceConfig, HyperbolicCoherenceError, NodeId, Result};
use std::collections::HashMap;
/// Small positive constant used throughout this module as a numerical guard:
/// it is the lower clamp for denominators, the padding added to a norm before
/// dividing by it, and the threshold below which a tangent vector is treated
/// as zero.
const EPS: f32 = 1e-5;
/// In-memory store of per-node vectors together with Poincaré-ball helper
/// operations (projection, distance, Fréchet mean, brute-force search).
#[derive(Debug)]
pub struct HyperbolicAdapter {
/// Settings read by the geometric routines (`epsilon`, `curvature`,
/// `frechet_max_iters`, `frechet_tolerance`).
config: HyperbolicCoherenceConfig,
/// Stored vector for each node id.
vectors: HashMap<NodeId, Vec<f32>>,
/// Reset to `false` by `new`, `insert`, and `update`; nothing in the visible
/// code sets it to `true` or reads it.
index_built: bool,
}
impl HyperbolicAdapter {
pub fn new(config: HyperbolicCoherenceConfig) -> Self {
Self {
config,
vectors: HashMap::new(),
index_built: false,
}
}
pub fn project_to_ball(&self, vector: &[f32]) -> Result<Vec<f32>> {
let norm_sq: f32 = vector.iter().map(|x| x * x).sum();
let norm = norm_sq.sqrt();
if norm < 1.0 - self.config.epsilon {
return Ok(vector.to_vec());
}
let max_norm = 1.0 - self.config.epsilon;
let scale = max_norm / (norm + EPS);
let projected: Vec<f32> = vector.iter().map(|x| x * scale).collect();
Ok(projected)
}
pub fn insert(&mut self, node_id: NodeId, vector: Vec<f32>) -> Result<()> {
self.vectors.insert(node_id, vector);
self.index_built = false; Ok(())
}
pub fn update(&mut self, node_id: NodeId, vector: Vec<f32>) -> Result<()> {
if !self.vectors.contains_key(&node_id) {
return Err(HyperbolicCoherenceError::NodeNotFound(node_id));
}
self.vectors.insert(node_id, vector);
self.index_built = false;
Ok(())
}
pub fn get(&self, node_id: NodeId) -> Option<&Vec<f32>> {
self.vectors.get(&node_id)
}
pub fn poincare_distance(&self, x: &[f32], y: &[f32]) -> f32 {
let c = -self.config.curvature;
let norm_x_sq: f32 = x.iter().map(|v| v * v).sum();
let norm_y_sq: f32 = y.iter().map(|v| v * v).sum();
let diff_sq: f32 = x
.iter()
.zip(y.iter())
.map(|(a, b)| (a - b) * (a - b))
.sum();
let denom = (1.0 - norm_x_sq).max(EPS) * (1.0 - norm_y_sq).max(EPS);
let inner = 1.0 + 2.0 * diff_sq / denom;
let acosh_inner = if inner >= 1.0 {
(inner + (inner * inner - 1.0).sqrt()).ln()
} else {
0.0
};
acosh_inner / c.sqrt()
}
pub fn frechet_mean(&self, points: &[&Vec<f32>]) -> Result<Vec<f32>> {
if points.is_empty() {
return Err(HyperbolicCoherenceError::EmptyCollection);
}
if points.len() == 1 {
return Ok(points[0].clone());
}
let dim = points[0].len();
let mut mean: Vec<f32> = vec![0.0; dim];
for p in points {
for (m, &v) in mean.iter_mut().zip(p.iter()) {
*m += v;
}
}
for m in mean.iter_mut() {
*m /= points.len() as f32;
}
mean = self.project_to_ball(&mean)?;
for _ in 0..self.config.frechet_max_iters {
let mut grad = vec![0.0f32; dim];
let mut total_dist = 0.0f32;
for &p in points {
let log = self.log_map(&mean, p);
for (g, l) in grad.iter_mut().zip(log.iter()) {
*g += l;
}
total_dist += self.poincare_distance(&mean, p);
}
for g in grad.iter_mut() {
*g /= points.len() as f32;
}
let grad_norm: f32 = grad.iter().map(|x| x * x).sum::<f32>().sqrt();
if grad_norm < self.config.frechet_tolerance {
break;
}
let step_size = 0.1f32.min(1.0 / (total_dist + 1.0));
let step: Vec<f32> = grad.iter().map(|g| g * step_size).collect();
mean = self.exp_map(&mean, &step)?;
mean = self.project_to_ball(&mean)?;
}
Ok(mean)
}
fn log_map(&self, base: &[f32], point: &[f32]) -> Vec<f32> {
let c = -self.config.curvature;
let diff: Vec<f32> = point.iter().zip(base.iter()).map(|(p, b)| p - b).collect();
let diff_norm: f32 = diff.iter().map(|x| x * x).sum::<f32>().sqrt().max(EPS);
let base_norm_sq: f32 = base.iter().map(|x| x * x).sum();
let lambda_base = 2.0 / (1.0 - base_norm_sq).max(EPS);
let dist = self.poincare_distance(base, point);
let scale = dist * lambda_base.sqrt() / (c.sqrt() * diff_norm);
diff.iter().map(|d| d * scale).collect()
}
fn exp_map(&self, base: &[f32], tangent: &[f32]) -> Result<Vec<f32>> {
let c = -self.config.curvature;
let tangent_norm: f32 = tangent.iter().map(|x| x * x).sum::<f32>().sqrt();
if tangent_norm < EPS {
return Ok(base.to_vec());
}
let base_norm_sq: f32 = base.iter().map(|x| x * x).sum();
let lambda_base = 2.0 / (1.0 - base_norm_sq).max(EPS);
let normalized: Vec<f32> = tangent.iter().map(|t| t / tangent_norm).collect();
let scaled_norm = tangent_norm / lambda_base.sqrt();
let tanh_arg = c.sqrt() * scaled_norm;
let tanh_val = tanh_arg.tanh();
let scale = tanh_val / c.sqrt();
let mut result: Vec<f32> = base.to_vec();
for (r, n) in result.iter_mut().zip(normalized.iter()) {
*r += scale * n;
}
self.project_to_ball(&result)
}
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>> {
if self.vectors.is_empty() {
return Ok(vec![]);
}
let mut distances: Vec<(NodeId, f32)> = self
.vectors
.iter()
.map(|(&id, vec)| (id, self.poincare_distance(query, vec)))
.collect();
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
distances.truncate(k);
Ok(distances)
}
pub fn len(&self) -> usize {
self.vectors.len()
}
pub fn is_empty(&self) -> bool {
self.vectors.is_empty()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the adapter every test uses: dimension 4, curvature -1, and
    /// defaults for all remaining settings.
    fn unit_curvature_adapter() -> HyperbolicAdapter {
        HyperbolicAdapter::new(HyperbolicCoherenceConfig {
            dimension: 4,
            curvature: -1.0,
            ..Default::default()
        })
    }

    /// Euclidean norm of `v`.
    fn euclidean_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn test_projection() {
        let adapter = unit_curvature_adapter();

        // A point well inside the ball must come back (essentially) unchanged
        // in every coordinate, not just the first.
        let inside = vec![0.1, 0.1, 0.1, 0.1];
        let projected = adapter.project_to_ball(&inside).unwrap();
        assert_eq!(projected.len(), inside.len());
        for (p, i) in projected.iter().zip(&inside) {
            assert!((p - i).abs() < 0.01);
        }

        // A point outside the ball must be pulled strictly inside it.
        let outside = vec![0.9, 0.9, 0.9, 0.9];
        let projected = adapter.project_to_ball(&outside).unwrap();
        assert!(euclidean_norm(&projected) < 1.0);
    }

    #[test]
    fn test_poincare_distance() {
        let adapter = unit_curvature_adapter();
        let origin = vec![0.0, 0.0, 0.0, 0.0];
        let point = vec![0.5, 0.0, 0.0, 0.0];

        let dist = adapter.poincare_distance(&origin, &point);
        assert!(dist > 0.0);
        // With curvature -1 the distance from the origin to radius 0.5 is
        // acosh(1 + 2 * 0.25 / 0.75) = ln(3).
        assert!((dist - 3.0f32.ln()).abs() < 1e-3);

        // The distance is symmetric in its arguments.
        let reversed = adapter.poincare_distance(&point, &origin);
        assert!((dist - reversed).abs() < 1e-6);

        let self_dist = adapter.poincare_distance(&point, &point);
        assert!(self_dist < 0.01);
    }

    #[test]
    fn test_frechet_mean() {
        let adapter = unit_curvature_adapter();
        // Four points placed symmetrically around the origin.
        let points = vec![
            vec![0.1, 0.0, 0.0, 0.0],
            vec![-0.1, 0.0, 0.0, 0.0],
            vec![0.0, 0.1, 0.0, 0.0],
            vec![0.0, -0.1, 0.0, 0.0],
        ];
        let refs: Vec<&Vec<f32>> = points.iter().collect();
        let mean = adapter.frechet_mean(&refs).unwrap();
        assert!(euclidean_norm(&mean) < 0.1);
    }

    #[test]
    fn test_frechet_mean_rejects_empty_input() {
        let adapter = unit_curvature_adapter();
        assert!(adapter.frechet_mean(&[]).is_err());
    }

    #[test]
    fn test_update_requires_existing_node() {
        let mut adapter = unit_curvature_adapter();
        assert!(adapter.update(42, vec![0.1, 0.0, 0.0, 0.0]).is_err());
        assert!(adapter.is_empty());

        adapter.insert(42, vec![0.1, 0.0, 0.0, 0.0]).unwrap();
        adapter.update(42, vec![0.2, 0.0, 0.0, 0.0]).unwrap();
        assert_eq!(adapter.len(), 1);
        assert_eq!(adapter.get(42), Some(&vec![0.2, 0.0, 0.0, 0.0]));
    }

    #[test]
    fn test_search() {
        let mut adapter = unit_curvature_adapter();
        adapter.insert(1, vec![0.1, 0.0, 0.0, 0.0]).unwrap();
        adapter.insert(2, vec![0.2, 0.0, 0.0, 0.0]).unwrap();
        adapter.insert(3, vec![0.5, 0.0, 0.0, 0.0]).unwrap();

        let query = vec![0.15, 0.0, 0.0, 0.0];
        let results = adapter.search(&query, 2).unwrap();
        assert_eq!(results.len(), 2);

        // Node 1 and node 2 are equally far from the query in Euclidean terms,
        // but node 2 sits closer to the boundary, so its Poincaré distance is
        // larger and node 1 must rank first.
        let ids: Vec<_> = results.iter().map(|r| r.0).collect();
        assert_eq!(ids, vec![1, 2]);
        assert!(results[0].1 <= results[1].1);
    }
}