Skip to main content

oxios_kernel/memory/
hyperbolic.rs

//! Hyperbolic embeddings using the Poincaré ball model.
//!
//! The Poincaré ball model embeds hierarchical data (trees, taxonomies,
//! ontologies) in hyperbolic space, where distances naturally encode
//! hierarchical relationships. Nodes close to the root lie near the
//! origin; leaf nodes lie near the boundary.
//!
//! Use cases in Oxios:
//! - Persona hierarchy (parent → child relationships)
//! - Skill graph (prerequisite chains)
//! - Memory category taxonomy
//!
//! Reference: "Poincaré Embeddings for Learning Hierarchical
//! Representations" (Nickel & Kiela, 2017)
15
16use serde::{Deserialize, Serialize};
17
// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------
21
22/// Configuration for hyperbolic operations.
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct HyperbolicConfig {
25    /// Curvature of the hyperbolic space.
26    /// Must be negative. Default: -1.0 (standard Poincaré ball).
27    pub curvature: f32,
28    /// Embedding dimensionality.
29    pub dimensions: usize,
30    /// Numerical stability epsilon.
31    pub epsilon: f32,
32}
33
34impl Default for HyperbolicConfig {
35    fn default() -> Self {
36        Self {
37            curvature: -1.0,
38            dimensions: 64,
39            epsilon: 1e-5,
40        }
41    }
42}
43
44impl HyperbolicConfig {
45    /// Create a new config with validation.
46    pub fn new(curvature: f32, dimensions: usize) -> Self {
47        assert!(
48            curvature < 0.0,
49            "Curvature must be negative for hyperbolic space"
50        );
51        Self {
52            curvature,
53            dimensions,
54            epsilon: 1e-5,
55        }
56    }
57
58    /// Returns the absolute value of curvature (c = |K|).
59    #[allow(dead_code)]
60    fn c(&self) -> f32 {
61        self.curvature.abs()
62    }
63}
64
// ---------------------------------------------------------------------------
// Poincaré ball operations
// ---------------------------------------------------------------------------
68
/// Convert a Euclidean vector to a point on the Poincaré ball.
///
/// The vector keeps its direction; its norm `r` is mapped to `tanh(r) / √c`,
/// which lies in `[0, 1/√c)` in exact arithmetic. In `f32`, `tanh` rounds to
/// exactly `1.0` once `r` exceeds roughly 9. That would put the point on the
/// boundary, where the metric is undefined and `hyperbolic_distance` reports
/// `f32::MAX`. The resulting norm is therefore also clipped to
/// `(1 - 1e-5) / √c`, so the point stays strictly inside the ball. In
/// practice, norms above roughly 6.1 all map to that clipped radius.
///
/// # Arguments
/// * `vector` - Euclidean vector
/// * `curvature` - Negative curvature K (e.g., -1.0); only its magnitude is used
///
/// # Returns
/// Point on the Poincaré ball with the same length as `vector`. The zero
/// vector (and the empty vector) maps to the origin.
pub fn euclidean_to_poincare(vector: &[f32], curvature: f32) -> Vec<f32> {
    // Relative margin kept between projected points and the ball's boundary.
    const BOUNDARY_MARGIN: f32 = 1e-5;

    let c = curvature.abs();
    let max_norm = 1.0 / c.sqrt();

    let norm = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm == 0.0 {
        return vec![0.0; vector.len()];
    }

    // tanh gives a smooth, bounded projection; the `min` guards against tanh
    // saturating to 1.0 and placing the point exactly on the boundary.
    let target_norm = (max_norm * norm.tanh()).min(max_norm * (1.0 - BOUNDARY_MARGIN));
    let scale = target_norm / norm;
    vector.iter().map(|&v| v * scale).collect()
}
97
98/// Batch-convert Euclidean vectors to Poincaré ball points.
99pub fn batch_euclidean_to_poincare(vectors: &[Vec<f32>], curvature: f32) -> Vec<Vec<f32>> {
100    vectors
101        .iter()
102        .map(|v| euclidean_to_poincare(v, curvature))
103        .collect()
104}
105
/// Compute the hyperbolic distance between two points on the Poincaré ball.
///
/// d(x, y) = (1/√c) * arcosh(1 + 2c * δ(x, y) / ((1 - c||x||²)(1 - c||y||²)))
///
/// where δ(x, y) = ||x - y||² and c = |K|.
///
/// `arcosh(1 + z)` is evaluated as `ln_1p(z + √z·√(z + 2))`. This is the same
/// value, but it stays accurate for nearby points (where `1 + z` would round
/// to exactly 1 in `f32`) and avoids overflowing `z²` for distant ones.
///
/// # Returns
/// The distance (0.0 for identical points), or `f32::MAX` if either point
/// lies on or outside the ball boundary or the inputs contain NaN.
///
/// Both slices are expected to have the same length: `δ` only covers the
/// common prefix, while each norm covers its whole slice.
pub fn hyperbolic_distance(a: &[f32], b: &[f32], curvature: f32) -> f32 {
    let c = curvature.abs();

    let norm_a_sq: f32 = a.iter().map(|v| v * v).sum();
    let norm_b_sq: f32 = b.iter().map(|v| v * v).sum();
    let diff_sq: f32 = a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum();

    // Check each `1 - c||x||²` factor on its own: when both points are outside
    // the ball, both factors are negative and their product would look valid.
    // The `> 0.0` comparisons are also false for NaN.
    let gap_a = 1.0 - c * norm_a_sq;
    let gap_b = 1.0 - c * norm_b_sq;
    let both_inside = gap_a > 0.0 && gap_b > 0.0;
    if !both_inside {
        return f32::MAX;
    }

    let z = 2.0 * c * diff_sq / (gap_a * gap_b);
    if z <= 0.0 {
        // Identical points (or a difference that underflows to zero).
        return 0.0;
    }

    // arcosh(1 + z) = ln(1 + z + sqrt(z * (z + 2)))
    (z + z.sqrt() * (z + 2.0).sqrt()).ln_1p() / c.sqrt()
}
135
/// Möbius addition: the hyperbolic analog of vector addition.
///
/// a ⊕_c b = ((1 + 2c⟨a,b⟩ + c||b||²)a + (1 - c||a||²)b) /
///           (1 + 2c⟨a,b⟩ + c²||a||²||b||²)
///
/// Only the first `min(a.len(), b.len())` coordinates are combined, while the
/// squared norms are taken over each full slice. If the denominator is
/// numerically zero (`|den| < 1e-10`), the origin of length `a.len()` (the
/// neutral element) is returned instead.
pub fn mobius_add(a: &[f32], b: &[f32], curvature: f32) -> Vec<f32> {
    let c = curvature.abs();
    let squared_norm = |xs: &[f32]| -> f32 { xs.iter().map(|x| x * x).sum() };

    let a_sq = squared_norm(a);
    let b_sq = squared_norm(b);
    let inner: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();

    let two_c_inner = 2.0 * c * inner;
    let den = 1.0 + two_c_inner + c * c * a_sq * b_sq;
    if den.abs() < 1e-10 {
        return vec![0.0; a.len()];
    }

    let coef_a = 1.0 + two_c_inner + c * b_sq;
    let coef_b = 1.0 - c * a_sq;

    let mut sum = Vec::with_capacity(a.len().min(b.len()));
    for (&x, &y) in a.iter().zip(b) {
        sum.push((coef_a * x + coef_b * y) / den);
    }
    sum
}
162
/// Möbius scalar multiplication: scaling in hyperbolic space.
///
/// s ⊗_c v = (1/√c) * tanh(s * arctanh(√c * ||v||)) * v / ||v||
///
/// # Arguments
/// * `scalar` - Multiplication factor (a negative value flips the direction)
/// * `v` - Vector on the Poincaré ball
/// * `curvature` - Negative curvature K
/// * `epsilon` - Numerical stability margin (e.g. 1e-5)
///
/// # Returns
/// A vector of the same length as `v`. Vectors whose norm is below `epsilon`
/// map to the origin. Both `√c·||v||` and the `tanh` factor are clamped to
/// magnitude `1 - epsilon`. This keeps the result strictly inside the ball even
/// when `tanh` saturates to ±1.0, e.g. for a large `scalar`.
pub fn mobius_scalar_mul(scalar: f32, v: &[f32], curvature: f32, epsilon: f32) -> Vec<f32> {
    let c = curvature.abs();
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm < epsilon {
        return vec![0.0; v.len()];
    }

    let c_sqrt = c.sqrt();
    let limit = 1.0 - epsilon;

    // Clamp w to strictly less than 1 so atanh stays finite for points on
    // (or numerically beyond) the boundary.
    let w = (c_sqrt * norm).min(limit);
    // tanh rounds to exactly ±1.0 in f32 for large arguments, which would put
    // the result on the boundary; clamp its magnitude the same way. (min/max
    // rather than `clamp`, which panics on NaN bounds.)
    let t = (scalar * w.atanh()).tanh().min(limit).max(-limit);
    let result_norm = (1.0 / c_sqrt) * t;

    let scale = result_norm / norm;
    v.iter().map(|&vi| vi * scale).collect()
}
191
// ---------------------------------------------------------------------------
// HyperbolicEmbedding — higher-level interface
// ---------------------------------------------------------------------------
195
196/// Hyperbolic embedding manager for hierarchical data.
197///
198/// Provides a convenient interface for storing and querying
199/// hierarchical embeddings in Poincaré ball space.
200pub struct HyperbolicEmbedding {
201    config: HyperbolicConfig,
202    /// Named embeddings: id → Poincaré ball point.
203    embeddings: Vec<(String, Vec<f32>)>,
204}
205
206impl HyperbolicEmbedding {
207    /// Create a new hyperbolic embedding manager.
208    pub fn new(config: HyperbolicConfig) -> Self {
209        Self {
210            config,
211            embeddings: Vec::new(),
212        }
213    }
214
215    /// Create with default configuration.
216    pub fn with_dimensions(dimensions: usize) -> Self {
217        let config = HyperbolicConfig {
218            dimensions,
219            ..Default::default()
220        };
221        Self::new(config)
222    }
223
224    /// Add a Euclidean vector as a named embedding.
225    ///
226    /// Converts to Poincaré ball coordinates.
227    pub fn add(&mut self, id: &str, euclidean: &[f32]) {
228        let poincare = euclidean_to_poincare(euclidean, self.config.curvature);
229        // Replace if exists
230        if let Some(pos) = self.embeddings.iter().position(|(name, _)| name == id) {
231            self.embeddings[pos] = (id.to_string(), poincare);
232        } else {
233            self.embeddings.push((id.to_string(), poincare));
234        }
235    }
236
237    /// Add a parent-child relationship using Möbius addition.
238    ///
239    /// The child is placed at `parent ⊕ child_euclidean`, which naturally
240    /// positions it farther from the origin along the parent's direction.
241    pub fn add_child(&mut self, parent_id: &str, child_id: &str, child_euclidean: &[f32]) {
242        let child_on_ball = euclidean_to_poincare(child_euclidean, self.config.curvature);
243
244        let child_point = if let Some((_, parent_vec)) =
245            self.embeddings.iter().find(|(name, _)| name == parent_id)
246        {
247            // Use Möbius addition: child = parent ⊕ child_offset
248            // This naturally places the child deeper in the hierarchy
249            mobius_add(parent_vec, &child_on_ball, self.config.curvature)
250        } else {
251            child_on_ball
252        };
253
254        if let Some(pos) = self
255            .embeddings
256            .iter()
257            .position(|(name, _)| name == child_id)
258        {
259            self.embeddings[pos] = (child_id.to_string(), child_point);
260        } else {
261            self.embeddings.push((child_id.to_string(), child_point));
262        }
263    }
264
265    /// Get the hyperbolic embedding for a given id.
266    pub fn get(&self, id: &str) -> Option<&[f32]> {
267        self.embeddings
268            .iter()
269            .find(|(name, _)| name == id)
270            .map(|(_, v)| v.as_slice())
271    }
272
273    /// Find the k nearest neighbors in hyperbolic space.
274    ///
275    /// Returns (id, distance) pairs sorted by distance.
276    pub fn nearest_neighbors(&self, query_id: &str, k: usize) -> Vec<(String, f32)> {
277        let query = match self.get(query_id) {
278            Some(v) => v.to_vec(),
279            None => return Vec::new(),
280        };
281
282        let mut results: Vec<(String, f32)> = self
283            .embeddings
284            .iter()
285            .filter(|(name, _)| name != query_id)
286            .map(|(name, vec)| {
287                let dist = hyperbolic_distance(&query, vec, self.config.curvature);
288                (name.clone(), dist)
289            })
290            .collect();
291
292        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
293        results.truncate(k);
294        results
295    }
296
297    /// Find nearest neighbors for an arbitrary Euclidean query.
298    pub fn search(&self, query: &[f32], k: usize) -> Vec<(String, f32)> {
299        let query_poincare = euclidean_to_poincare(query, self.config.curvature);
300
301        let mut results: Vec<(String, f32)> = self
302            .embeddings
303            .iter()
304            .map(|(name, vec)| {
305                let dist = hyperbolic_distance(&query_poincare, vec, self.config.curvature);
306                (name.clone(), dist)
307            })
308            .collect();
309
310        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
311        results.truncate(k);
312        results
313    }
314
315    /// Compute the hierarchical distance between two embeddings.
316    ///
317    /// In hierarchical data, nodes deeper in the tree are farther from
318    /// the origin. This function returns the hyperbolic distance plus
319    /// a depth penalty.
320    pub fn hierarchical_distance(&self, id_a: &str, id_b: &str) -> f32 {
321        let a = match self.get(id_a) {
322            Some(v) => v,
323            None => return f32::MAX,
324        };
325        let b = match self.get(id_b) {
326            Some(v) => v,
327            None => return f32::MAX,
328        };
329
330        hyperbolic_distance(a, b, self.config.curvature)
331    }
332
333    /// Returns the number of stored embeddings.
334    pub fn len(&self) -> usize {
335        self.embeddings.len()
336    }
337
338    /// Returns true if no embeddings stored.
339    pub fn is_empty(&self) -> bool {
340        self.embeddings.is_empty()
341    }
342
343    /// Returns all embedding ids.
344    pub fn ids(&self) -> Vec<&str> {
345        self.embeddings
346            .iter()
347            .map(|(name, _)| name.as_str())
348            .collect()
349    }
350
351    /// Get the hyperbolic distance of a point from the origin.
352    ///
353    /// Points closer to the origin are "higher" in the hierarchy.
354    pub fn depth(&self, id: &str) -> f32 {
355        match self.get(id) {
356            Some(v) => hyperbolic_distance(&vec![0.0; v.len()], v, self.config.curvature),
357            None => f32::MAX,
358        }
359    }
360
361    /// Rank all embeddings by depth (origin distance).
362    ///
363    /// Returns (id, depth) pairs sorted by depth ascending.
364    /// Items with lower depth are closer to the root of the hierarchy.
365    pub fn rank_by_depth(&self) -> Vec<(String, f32)> {
366        let mut ranked: Vec<(String, f32)> = self
367            .embeddings
368            .iter()
369            .map(|(name, vec)| {
370                let origin = vec![0.0; vec.len()];
371                let d = hyperbolic_distance(&origin, vec, self.config.curvature);
372                (name.clone(), d)
373            })
374            .collect();
375
376        ranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
377        ranked
378    }
379}
380
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
384
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_euclidean_to_poincare_zero() {
        let result = euclidean_to_poincare(&[0.0, 0.0, 0.0], -1.0);
        assert_eq!(result, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_euclidean_to_poincare_bounded() {
        let c = -1.0;
        // Large vector should be projected inside the ball
        let result = euclidean_to_poincare(&[100.0, 100.0, 100.0], c);
        let norm: f32 = result.iter().map(|v| v * v).sum::<f32>().sqrt();
        let max_norm = 1.0 / c.abs().sqrt();
        assert!(
            norm < max_norm,
            "Result should be inside the ball: norm={}, max={}",
            norm,
            max_norm
        );
    }

    #[test]
    fn test_hyperbolic_distance_same_point() {
        let point = euclidean_to_poincare(&[0.5, 0.3], -1.0);
        let dist = hyperbolic_distance(&point, &point, -1.0);
        assert!(dist < 1e-5, "Distance from self should be ~0, got {}", dist);
    }

    #[test]
    fn test_hyperbolic_distance_symmetry() {
        let a = euclidean_to_poincare(&[1.0, 2.0], -1.0);
        let b = euclidean_to_poincare(&[3.0, 1.0], -1.0);
        let d_ab = hyperbolic_distance(&a, &b, -1.0);
        let d_ba = hyperbolic_distance(&b, &a, -1.0);
        assert!(
            (d_ab - d_ba).abs() < 1e-4,
            "Distance should be symmetric: {} vs {}",
            d_ab,
            d_ba
        );
    }

    #[test]
    fn test_hyperbolic_distance_triangle_inequality() {
        let a = euclidean_to_poincare(&[1.0, 0.0], -1.0);
        let b = euclidean_to_poincare(&[0.0, 1.0], -1.0);
        let c = euclidean_to_poincare(&[2.0, 2.0], -1.0);

        let d_ab = hyperbolic_distance(&a, &b, -1.0);
        let d_bc = hyperbolic_distance(&b, &c, -1.0);
        let d_ac = hyperbolic_distance(&a, &c, -1.0);

        assert!(
            d_ac <= d_ab + d_bc + 1e-4,
            "Triangle inequality: d(a,c)={} should be <= d(a,b)+d(b,c)={}",
            d_ac,
            d_ab + d_bc
        );
    }

    #[test]
    fn test_mobius_add_identity() {
        let a = euclidean_to_poincare(&[0.5, 0.3], -1.0);
        let zero = vec![0.0, 0.0];
        let result = mobius_add(&a, &zero, -1.0);
        for (r, expected) in result.iter().zip(a.iter()) {
            assert!((r - expected).abs() < 1e-4, "a ⊕ 0 should equal a");
        }
    }

    #[test]
    fn test_mobius_scalar_mul_zero() {
        let v = euclidean_to_poincare(&[1.0, 2.0], -1.0);
        let result = mobius_scalar_mul(0.0, &v, -1.0, 1e-5);
        for r in &result {
            assert!(r.abs() < 1e-4, "0 ⊗ v should be ~0, got {}", r);
        }
    }

    #[test]
    fn test_mobius_scalar_mul_one() {
        let v = euclidean_to_poincare(&[1.0, 2.0], -1.0);
        let result = mobius_scalar_mul(1.0, &v, -1.0, 1e-5);
        for (r, expected) in result.iter().zip(v.iter()) {
            assert!((r - expected).abs() < 1e-4, "1 ⊗ v should equal v");
        }
    }

    #[test]
    fn test_hyperbolic_embedding_add_and_search() {
        let mut he = HyperbolicEmbedding::with_dimensions(3);

        he.add("root", &[0.0, 0.0, 0.0]);
        he.add("child_a", &[1.0, 0.0, 0.0]);
        he.add("child_b", &[0.0, 1.0, 0.0]);
        he.add("grandchild", &[1.0, 1.0, 0.0]);

        assert_eq!(he.len(), 4);

        // With k = 2, child_a's nearest neighbours are the root and grandchild.
        let nn = he.nearest_neighbors("child_a", 2);
        let names: Vec<&str> = nn.iter().map(|(name, _)| name.as_str()).collect();
        assert_eq!(names, ["root", "grandchild"]);

        // Considering every other node, grandchild (which shares child_a's
        // x-direction) must be closer to child_a than child_b is. Asserted
        // unconditionally so the check cannot be silently skipped.
        let all = he.nearest_neighbors("child_a", 10);
        assert_eq!(all.len(), 3);
        assert!(all.windows(2).all(|w| w[0].1 <= w[1].1), "not sorted: {:?}", all);
        let dist = |id: &str| -> f32 {
            all.iter()
                .find(|(name, _)| name == id)
                .map(|(_, d)| *d)
                .expect("id should be among the neighbours")
        };
        assert!(
            dist("grandchild") < dist("child_b"),
            "grandchild should be closer to child_a than child_b"
        );
    }

    #[test]
    fn test_search_finds_exact_match_first() {
        let mut he = HyperbolicEmbedding::with_dimensions(3);
        he.add("x", &[1.0, 0.0, 0.0]);
        he.add("y", &[0.0, 1.0, 0.0]);

        let hits = he.search(&[1.0, 0.0, 0.0], 1);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, "x");
        assert!(hits[0].1 < 1e-6, "exact match should be at ~0, got {}", hits[0].1);
    }

    #[test]
    fn test_add_replaces_existing_id() {
        let mut he = HyperbolicEmbedding::with_dimensions(2);
        he.add("a", &[0.5, 0.0]);
        he.add("b", &[0.0, 0.5]);
        he.add("a", &[0.0, 0.0]);

        assert_eq!(he.len(), 2);
        assert_eq!(he.ids(), ["a", "b"]);
        assert_eq!(he.get("a"), Some(&[0.0, 0.0][..]));
    }

    #[test]
    fn test_missing_ids() {
        let mut he = HyperbolicEmbedding::with_dimensions(2);
        he.add("a", &[0.5, 0.0]);

        assert!(he.get("missing").is_none());
        assert_eq!(he.hierarchical_distance("a", "missing"), f32::MAX);
        assert_eq!(he.depth("missing"), f32::MAX);
        assert!(he.nearest_neighbors("missing", 3).is_empty());
    }

    #[test]
    fn test_hyperbolic_embedding_depth() {
        let mut he = HyperbolicEmbedding::with_dimensions(2);

        he.add("root", &[0.0, 0.0]);
        he.add("level1", &[0.5, 0.0]);
        he.add("level2", &[1.0, 0.0]);

        let root_depth = he.depth("root");
        let l1_depth = he.depth("level1");
        let l2_depth = he.depth("level2");

        assert!(
            root_depth < l1_depth,
            "Root should be shallower: root={}, l1={}",
            root_depth,
            l1_depth
        );
        assert!(
            l1_depth < l2_depth,
            "Level1 should be shallower: l1={}, l2={}",
            l1_depth,
            l2_depth
        );
    }

    #[test]
    fn test_rank_by_depth() {
        let mut he = HyperbolicEmbedding::with_dimensions(2);

        he.add("leaf", &[2.0, 2.0]);
        he.add("root", &[0.0, 0.0]);
        he.add("mid", &[0.5, 0.5]);

        let ranked = he.rank_by_depth();
        assert_eq!(ranked[0].0, "root");
        assert_eq!(ranked[1].0, "mid");
        assert_eq!(ranked[2].0, "leaf");
    }

    #[test]
    fn test_batch_conversion() {
        let vectors = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![0.0, 0.0]];
        let results = batch_euclidean_to_poincare(&vectors, -1.0);
        assert_eq!(results.len(), 3);
        // Last should be zero
        assert_eq!(results[2], vec![0.0, 0.0]);
    }

    #[test]
    fn test_curvature_effect() {
        let v = [1.0, 1.0];

        let p1 = euclidean_to_poincare(&v, -1.0);
        let p2 = euclidean_to_poincare(&v, -2.0);

        let norm1: f32 = p1.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm2: f32 = p2.iter().map(|x| x * x).sum::<f32>().sqrt();

        // Higher curvature magnitude → smaller ball → smaller norm
        assert!(
            norm2 < norm1,
            "Higher curvature should produce smaller ball: {} vs {}",
            norm2,
            norm1
        );
    }

    #[test]
    #[should_panic(expected = "Curvature must be negative")]
    fn test_config_rejects_positive_curvature() {
        let _ = HyperbolicConfig::new(1.0, 8);
    }

    #[test]
    fn test_add_child_hierarchy() {
        let mut he = HyperbolicEmbedding::with_dimensions(3);

        // Create a simple hierarchy
        he.add("parent", &[1.0, 0.0, 0.0]);
        he.add_child("parent", "child", &[0.5, 0.5, 0.0]);

        assert_eq!(he.len(), 2);
        assert!(he.get("parent").is_some());
        assert!(he.get("child").is_some());

        // This offset points partly along the parent's direction, so the
        // Möbius-added child lands farther from the origin than its parent.
        let parent_depth = he.depth("parent");
        let child_depth = he.depth("child");
        assert!(
            child_depth > parent_depth,
            "child should be deeper: parent={}, child={}",
            parent_depth,
            child_depth
        );
    }
}