1use serde::{Deserialize, Serialize};
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct HyperbolicConfig {
25 pub curvature: f32,
28 pub dimensions: usize,
30 pub epsilon: f32,
32}
33
34impl Default for HyperbolicConfig {
35 fn default() -> Self {
36 Self {
37 curvature: -1.0,
38 dimensions: 64,
39 epsilon: 1e-5,
40 }
41 }
42}
43
44impl HyperbolicConfig {
45 pub fn new(curvature: f32, dimensions: usize) -> Self {
47 assert!(
48 curvature < 0.0,
49 "Curvature must be negative for hyperbolic space"
50 );
51 Self {
52 curvature,
53 dimensions,
54 epsilon: 1e-5,
55 }
56 }
57
58 #[allow(dead_code)]
60 fn c(&self) -> f32 {
61 self.curvature.abs()
62 }
63}
64
/// Maps a Euclidean vector into the Poincaré ball of curvature `curvature`.
///
/// The direction of `vector` is kept and its length `‖v‖` becomes
/// `tanh(‖v‖) / sqrt(|curvature|)`, so a finite, non-zero input lands inside the ball of
/// radius `1 / sqrt(|curvature|)`. The zero vector maps to the origin.
///
/// In `f32`, `tanh` rounds to exactly `1.0` for moderately large arguments, which would place
/// the point *on* the boundary, where [`hyperbolic_distance`] is undefined. The radius is
/// therefore capped at `(1 - 1e-5) / sqrt(|curvature|)`.
///
/// Only `|curvature|` is used; it is expected to be non-zero.
///
/// NOTE(review): the `tanh` argument is `‖v‖`, not `sqrt(|curvature|)·‖v‖`, so for
/// `|curvature| != 1` this is not the standard exponential map at the origin — confirm intended.
pub fn euclidean_to_poincare(vector: &[f32], curvature: f32) -> Vec<f32> {
    // Relative gap kept between mapped points and the ball boundary.
    const BOUNDARY_MARGIN: f32 = 1e-5;

    let c = curvature.abs();
    let max_norm = 1.0 / c.sqrt();

    let norm = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm == 0.0 {
        return vec![0.0; vector.len()];
    }

    let radius = max_norm * norm.tanh().min(1.0 - BOUNDARY_MARGIN);
    let scale = radius / norm;
    vector.iter().map(|&v| v * scale).collect()
}
97
98pub fn batch_euclidean_to_poincare(vectors: &[Vec<f32>], curvature: f32) -> Vec<Vec<f32>> {
100 vectors
101 .iter()
102 .map(|v| euclidean_to_poincare(v, curvature))
103 .collect()
104}
105
/// Geodesic distance between two points of the Poincaré ball with curvature `-c`
/// (`c = |curvature|`):
///
/// `d(a, b) = arcosh(1 + 2c‖a − b‖² / ((1 − c‖a‖²)(1 − c‖b‖²))) / √c`
///
/// Returns `0.0` for identical points. Returns `f32::MAX` when either point is not strictly
/// inside the ball (`c‖x‖² ≥ 1`) or an input is NaN, since the formula is undefined there.
///
/// Each squared norm uses its whole slice, while `‖a − b‖²` only covers the first
/// `min(a.len(), b.len())` coordinates; callers are expected to pass equal-length vectors.
pub fn hyperbolic_distance(a: &[f32], b: &[f32], curvature: f32) -> f32 {
    let c = curvature.abs();

    let norm_a_sq: f32 = a.iter().map(|v| v * v).sum();
    let norm_b_sq: f32 = b.iter().map(|v| v * v).sum();
    let diff_sq: f32 = a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum();

    let gap_a = 1.0 - c * norm_a_sq;
    let gap_b = 1.0 - c * norm_b_sq;

    // Each point must be strictly inside the ball. Checking only the product would accept
    // two points that are both outside (two negative factors). `NaN > 0.0` is false, so NaN
    // inputs also take this branch instead of leaking into distance rankings.
    let strictly_inside = |gap: f32| gap > 0.0;
    if !strictly_inside(gap_a) || !strictly_inside(gap_b) {
        return f32::MAX;
    }

    let t = 2.0 * c * diff_sq / (gap_a * gap_b);
    if t <= 0.0 {
        return 0.0;
    }

    // arcosh(1 + t) = ln(1 + t + sqrt(t² + 2t)). Working with t directly and `ln_1p` avoids
    // the precision loss of forming `1 + t` when the points are close together.
    (t + (t * (t + 2.0)).sqrt()).ln_1p() / c.sqrt()
}
135
/// Möbius addition `a ⊕ b` on the Poincaré ball with `c = |curvature|`:
///
/// `((1 + 2c⟨a,b⟩ + c‖b‖²)·a + (1 − c‖a‖²)·b) / (1 + 2c⟨a,b⟩ + c²‖a‖²‖b‖²)`
///
/// The result has `min(a.len(), b.len())` components. If the denominator's magnitude is below
/// `1e-10`, a zero vector of length `a.len()` is returned instead.
pub fn mobius_add(a: &[f32], b: &[f32], curvature: f32) -> Vec<f32> {
    let k = curvature.abs();

    let squared_len = |v: &[f32]| -> f32 { v.iter().map(|x| x * x).sum() };
    let a2 = squared_len(a);
    let b2 = squared_len(b);
    let ab: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();

    // Shared by the denominator and the coefficient of `a`.
    let two_k_ab = 2.0 * k * ab;

    let denom = 1.0 + two_k_ab + k * k * a2 * b2;
    if denom.abs() < 1e-10 {
        return vec![0.0; a.len()];
    }

    let coef_a = 1.0 + two_k_ab + k * b2;
    let coef_b = 1.0 - k * a2;

    a.iter()
        .zip(b)
        .map(|(&x, &y)| (coef_a * x + coef_b * y) / denom)
        .collect()
}
162
/// Möbius scalar multiplication `scalar ⊗ v` on the Poincaré ball with `c = |curvature|`:
///
/// `tanh(scalar · artanh(√c‖v‖)) / √c · v / ‖v‖`
///
/// Vectors with `‖v‖ < epsilon` map to the zero vector. `√c‖v‖` is clamped to at most
/// `1 - epsilon` so `artanh` stays finite for points on or beyond the boundary. A negative
/// `scalar` reverses the direction of `v`.
pub fn mobius_scalar_mul(scalar: f32, v: &[f32], curvature: f32, epsilon: f32) -> Vec<f32> {
    let magnitude = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if magnitude < epsilon {
        return vec![0.0; v.len()];
    }

    let root_c = curvature.abs().sqrt();
    let bounded = f32::min(root_c * magnitude, 1.0 - epsilon);
    let target_magnitude = (1.0 / root_c) * (scalar * bounded.atanh()).tanh();

    let factor = target_magnitude / magnitude;
    v.iter().map(|&x| x * factor).collect()
}
191
192pub struct HyperbolicEmbedding {
201 config: HyperbolicConfig,
202 embeddings: Vec<(String, Vec<f32>)>,
204}
205
206impl HyperbolicEmbedding {
207 pub fn new(config: HyperbolicConfig) -> Self {
209 Self {
210 config,
211 embeddings: Vec::new(),
212 }
213 }
214
215 pub fn with_dimensions(dimensions: usize) -> Self {
217 let config = HyperbolicConfig {
218 dimensions,
219 ..Default::default()
220 };
221 Self::new(config)
222 }
223
224 pub fn add(&mut self, id: &str, euclidean: &[f32]) {
228 let poincare = euclidean_to_poincare(euclidean, self.config.curvature);
229 if let Some(pos) = self.embeddings.iter().position(|(name, _)| name == id) {
231 self.embeddings[pos] = (id.to_string(), poincare);
232 } else {
233 self.embeddings.push((id.to_string(), poincare));
234 }
235 }
236
237 pub fn add_child(&mut self, parent_id: &str, child_id: &str, child_euclidean: &[f32]) {
242 let child_on_ball = euclidean_to_poincare(child_euclidean, self.config.curvature);
243
244 let child_point = if let Some((_, parent_vec)) =
245 self.embeddings.iter().find(|(name, _)| name == parent_id)
246 {
247 mobius_add(parent_vec, &child_on_ball, self.config.curvature)
250 } else {
251 child_on_ball
252 };
253
254 if let Some(pos) = self
255 .embeddings
256 .iter()
257 .position(|(name, _)| name == child_id)
258 {
259 self.embeddings[pos] = (child_id.to_string(), child_point);
260 } else {
261 self.embeddings.push((child_id.to_string(), child_point));
262 }
263 }
264
265 pub fn get(&self, id: &str) -> Option<&[f32]> {
267 self.embeddings
268 .iter()
269 .find(|(name, _)| name == id)
270 .map(|(_, v)| v.as_slice())
271 }
272
273 pub fn nearest_neighbors(&self, query_id: &str, k: usize) -> Vec<(String, f32)> {
277 let query = match self.get(query_id) {
278 Some(v) => v.to_vec(),
279 None => return Vec::new(),
280 };
281
282 let mut results: Vec<(String, f32)> = self
283 .embeddings
284 .iter()
285 .filter(|(name, _)| name != query_id)
286 .map(|(name, vec)| {
287 let dist = hyperbolic_distance(&query, vec, self.config.curvature);
288 (name.clone(), dist)
289 })
290 .collect();
291
292 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
293 results.truncate(k);
294 results
295 }
296
297 pub fn search(&self, query: &[f32], k: usize) -> Vec<(String, f32)> {
299 let query_poincare = euclidean_to_poincare(query, self.config.curvature);
300
301 let mut results: Vec<(String, f32)> = self
302 .embeddings
303 .iter()
304 .map(|(name, vec)| {
305 let dist = hyperbolic_distance(&query_poincare, vec, self.config.curvature);
306 (name.clone(), dist)
307 })
308 .collect();
309
310 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
311 results.truncate(k);
312 results
313 }
314
315 pub fn hierarchical_distance(&self, id_a: &str, id_b: &str) -> f32 {
321 let a = match self.get(id_a) {
322 Some(v) => v,
323 None => return f32::MAX,
324 };
325 let b = match self.get(id_b) {
326 Some(v) => v,
327 None => return f32::MAX,
328 };
329
330 hyperbolic_distance(a, b, self.config.curvature)
331 }
332
333 pub fn len(&self) -> usize {
335 self.embeddings.len()
336 }
337
338 pub fn is_empty(&self) -> bool {
340 self.embeddings.is_empty()
341 }
342
343 pub fn ids(&self) -> Vec<&str> {
345 self.embeddings
346 .iter()
347 .map(|(name, _)| name.as_str())
348 .collect()
349 }
350
351 pub fn depth(&self, id: &str) -> f32 {
355 match self.get(id) {
356 Some(v) => hyperbolic_distance(&vec![0.0; v.len()], v, self.config.curvature),
357 None => f32::MAX,
358 }
359 }
360
361 pub fn rank_by_depth(&self) -> Vec<(String, f32)> {
366 let mut ranked: Vec<(String, f32)> = self
367 .embeddings
368 .iter()
369 .map(|(name, vec)| {
370 let origin = vec![0.0; vec.len()];
371 let d = hyperbolic_distance(&origin, vec, self.config.curvature);
372 (name.clone(), d)
373 })
374 .collect();
375
376 ranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
377 ranked
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_euclidean_to_poincare_zero() {
        let result = euclidean_to_poincare(&[0.0, 0.0, 0.0], -1.0);
        assert_eq!(result, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_euclidean_to_poincare_bounded() {
        let c = -1.0;
        let result = euclidean_to_poincare(&[100.0, 100.0, 100.0], c);
        let norm: f32 = result.iter().map(|v| v * v).sum::<f32>().sqrt();
        let max_norm = 1.0 / c.abs().sqrt();
        assert!(
            norm < max_norm,
            "Result should be inside the ball: norm={}, max={}",
            norm,
            max_norm
        );
    }

    #[test]
    fn test_hyperbolic_distance_same_point() {
        let point = euclidean_to_poincare(&[0.5, 0.3], -1.0);
        let dist = hyperbolic_distance(&point, &point, -1.0);
        assert!(dist < 1e-5, "Distance from self should be ~0, got {}", dist);
    }

    #[test]
    fn test_hyperbolic_distance_symmetry() {
        let a = euclidean_to_poincare(&[1.0, 2.0], -1.0);
        let b = euclidean_to_poincare(&[3.0, 1.0], -1.0);
        let d_ab = hyperbolic_distance(&a, &b, -1.0);
        let d_ba = hyperbolic_distance(&b, &a, -1.0);
        assert!(
            (d_ab - d_ba).abs() < 1e-4,
            "Distance should be symmetric: {} vs {}",
            d_ab,
            d_ba
        );
    }

    #[test]
    fn test_hyperbolic_distance_triangle_inequality() {
        let a = euclidean_to_poincare(&[1.0, 0.0], -1.0);
        let b = euclidean_to_poincare(&[0.0, 1.0], -1.0);
        let c = euclidean_to_poincare(&[2.0, 2.0], -1.0);

        let d_ab = hyperbolic_distance(&a, &b, -1.0);
        let d_bc = hyperbolic_distance(&b, &c, -1.0);
        let d_ac = hyperbolic_distance(&a, &c, -1.0);

        assert!(
            d_ac <= d_ab + d_bc + 1e-4,
            "Triangle inequality: d(a,c)={} should be <= d(a,b)+d(b,c)={}",
            d_ac,
            d_ab + d_bc
        );
    }

    #[test]
    fn test_mobius_add_identity() {
        let a = euclidean_to_poincare(&[0.5, 0.3], -1.0);
        let zero = vec![0.0, 0.0];
        let result = mobius_add(&a, &zero, -1.0);
        for (r, expected) in result.iter().zip(a.iter()) {
            assert!((r - expected).abs() < 1e-4, "a ⊕ 0 should equal a");
        }
    }

    #[test]
    fn test_mobius_scalar_mul_zero() {
        let v = euclidean_to_poincare(&[1.0, 2.0], -1.0);
        let result = mobius_scalar_mul(0.0, &v, -1.0, 1e-5);
        for r in &result {
            assert!(r.abs() < 1e-4, "0 ⊗ v should be ~0, got {}", r);
        }
    }

    #[test]
    fn test_mobius_scalar_mul_one() {
        let v = euclidean_to_poincare(&[1.0, 2.0], -1.0);
        let result = mobius_scalar_mul(1.0, &v, -1.0, 1e-5);
        for (r, expected) in result.iter().zip(v.iter()) {
            assert!((r - expected).abs() < 1e-4, "1 ⊗ v should equal v");
        }
    }

    #[test]
    fn test_hyperbolic_embedding_add_and_search() {
        let mut he = HyperbolicEmbedding::with_dimensions(3);

        he.add("root", &[0.0, 0.0, 0.0]);
        he.add("child_a", &[1.0, 0.0, 0.0]);
        he.add("child_b", &[0.0, 1.0, 0.0]);
        he.add("grandchild", &[1.0, 1.0, 0.0]);

        assert_eq!(he.len(), 4);

        // `k` caps the result size and the query itself is never returned.
        let nn = he.nearest_neighbors("child_a", 2);
        assert_eq!(nn.len(), 2);
        assert!(nn.iter().all(|(name, _)| name != "child_a"));

        // Request every other point so both candidates are guaranteed to be present; with
        // k = 2 one of them can fall outside the result, which previously made the
        // comparison below silently skip.
        let all = he.nearest_neighbors("child_a", 3);
        let dist_to = |id: &str| -> f32 {
            all.iter()
                .find(|(name, _)| name == id)
                .map(|(_, d)| *d)
                .unwrap_or_else(|| panic!("{} missing from neighbours", id))
        };
        let gc = dist_to("grandchild");
        let cb = dist_to("child_b");
        assert!(
            gc < cb,
            "grandchild should be closer to child_a than child_b"
        );
    }

    #[test]
    fn test_hyperbolic_embedding_depth() {
        let mut he = HyperbolicEmbedding::with_dimensions(2);

        he.add("root", &[0.0, 0.0]);
        he.add("level1", &[0.5, 0.0]);
        he.add("level2", &[1.0, 0.0]);

        let root_depth = he.depth("root");
        let l1_depth = he.depth("level1");
        let l2_depth = he.depth("level2");

        assert!(
            root_depth < l1_depth,
            "Root should be shallower: root={}, l1={}",
            root_depth,
            l1_depth
        );
        assert!(
            l1_depth < l2_depth,
            "Level1 should be shallower: l1={}, l2={}",
            l1_depth,
            l2_depth
        );
    }

    #[test]
    fn test_rank_by_depth() {
        let mut he = HyperbolicEmbedding::with_dimensions(2);

        he.add("leaf", &[2.0, 2.0]);
        he.add("root", &[0.0, 0.0]);
        he.add("mid", &[0.5, 0.5]);

        let ranked = he.rank_by_depth();
        assert_eq!(ranked[0].0, "root");
        assert_eq!(ranked[1].0, "mid");
        assert_eq!(ranked[2].0, "leaf");
    }

    #[test]
    fn test_batch_conversion() {
        let vectors = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![0.0, 0.0]];
        let results = batch_euclidean_to_poincare(&vectors, -1.0);
        assert_eq!(results.len(), 3);
        // Batch conversion must agree with the single-vector conversion.
        assert_eq!(results[0], euclidean_to_poincare(&vectors[0], -1.0));
        assert_eq!(results[2], vec![0.0, 0.0]);
    }

    #[test]
    fn test_curvature_effect() {
        let v = [1.0, 1.0];

        let p1 = euclidean_to_poincare(&v, -1.0);
        let p2 = euclidean_to_poincare(&v, -2.0);

        let norm1: f32 = p1.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm2: f32 = p2.iter().map(|x| x * x).sum::<f32>().sqrt();

        assert!(
            norm2 < norm1,
            "Higher curvature should produce smaller ball: {} vs {}",
            norm2,
            norm1
        );
    }

    #[test]
    fn test_add_child_hierarchy() {
        let mut he = HyperbolicEmbedding::with_dimensions(3);

        he.add("parent", &[1.0, 0.0, 0.0]);
        he.add_child("parent", "child", &[0.5, 0.5, 0.0]);

        assert_eq!(he.len(), 2);

        assert!(he.get("parent").is_some());
        assert!(he.get("child").is_some());

        // The child's offset points partly along the parent's direction, so Möbius-adding it
        // to the parent lands the child further from the origin than the parent.
        let parent_depth = he.depth("parent");
        let child_depth = he.depth("child");
        assert!(
            child_depth > parent_depth,
            "child should be deeper than parent: child={}, parent={}",
            child_depth,
            parent_depth
        );
    }
}