mod adapter;
mod config;
mod depth;
mod energy;
pub use adapter::HyperbolicAdapter;
pub use config::HyperbolicCoherenceConfig;
pub use depth::{DepthComputer, HierarchyLevel};
pub use energy::{HyperbolicEnergy, WeightedResidual};
use std::collections::HashMap;
/// Identifier of a node tracked by [`HyperbolicCoherence`].
pub type NodeId = u64;
/// Identifier of an edge. NOTE(review): not referenced in this section of the file.
pub type EdgeId = u64;
/// Result alias used by this module; errors are [`HyperbolicCoherenceError`].
pub type Result<T> = std::result::Result<T, HyperbolicCoherenceError>;
/// Errors produced by the hyperbolic coherence engine.
///
/// Display text for each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, Clone, thiserror::Error)]
pub enum HyperbolicCoherenceError {
    /// An operation referenced a node id that has not been inserted.
    #[error("Node not found: {0}")]
    NodeNotFound(NodeId),
    /// A vector's length did not match the configured dimension.
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: usize, actual: usize },
    /// A curvature value was rejected; per the message it must be negative.
    /// NOTE(review): not constructed in this section, presumably raised by config/adapter validation.
    #[error("Invalid curvature: {0} (must be negative)")]
    InvalidCurvature(f32),
    /// Projection into the ball failed; the payload is the offending vector norm.
    /// NOTE(review): not constructed in this section, presumably raised by the adapter.
    #[error("Projection failed: vector norm {0} exceeds ball radius")]
    ProjectionFailed(f32),
    /// Error reported by the HNSW index, carried as a message string.
    /// NOTE(review): not constructed in this section, presumably raised by the adapter.
    #[error("HNSW error: {0}")]
    HnswError(String),
    /// An operation that needs at least one element received none
    /// (e.g. `frechet_mean` with no ids, or with no known ids).
    #[error("Empty collection")]
    EmptyCollection,
}
/// Coherence engine that stores node states projected into a hyperbolic ball
/// and weights edge residual energies by node depth.
#[derive(Debug)]
pub struct HyperbolicCoherence {
    /// Engine configuration (dimension, curvature, depth weighting).
    config: HyperbolicCoherenceConfig,
    /// Projects vectors into the ball and backs similarity search / Fréchet means.
    adapter: HyperbolicAdapter,
    /// Computes per-node depth and classifies it into hierarchy levels.
    depth: DepthComputer,
    /// Projected (in-ball) state of each node, keyed by node id.
    node_states: HashMap<NodeId, Vec<f32>>,
    /// Depth of each node, cached when the node is inserted or updated.
    node_depths: HashMap<NodeId, f32>,
}
impl HyperbolicCoherence {
    /// Creates a coherence engine from `config`.
    ///
    /// The adapter receives its own copy of the configuration, and the depth
    /// computer is built from `config.curvature`.
    pub fn new(config: HyperbolicCoherenceConfig) -> Self {
        let adapter = HyperbolicAdapter::new(config.clone());
        let depth = DepthComputer::new(config.curvature);
        Self {
            config,
            adapter,
            depth,
            node_states: HashMap::new(),
            node_depths: HashMap::new(),
        }
    }

    /// Creates a coherence engine using `HyperbolicCoherenceConfig::default()`.
    pub fn default_config() -> Self {
        Self::new(HyperbolicCoherenceConfig::default())
    }

    /// Returns `Ok(())` when `actual` equals the configured dimension.
    fn check_dimension(&self, actual: usize) -> Result<()> {
        let expected = self.config.dimension;
        if actual == expected {
            Ok(())
        } else {
            Err(HyperbolicCoherenceError::DimensionMismatch { expected, actual })
        }
    }

    /// Returns the cached depth of `node_id`, or `NodeNotFound`.
    fn depth_of(&self, node_id: NodeId) -> Result<f32> {
        self.node_depths
            .get(&node_id)
            .copied()
            .ok_or(HyperbolicCoherenceError::NodeNotFound(node_id))
    }

    /// Projects `state` into the ball, computes its depth and registers the node.
    ///
    /// The state length is validated for every insert, including the first.
    /// Previously the first node skipped the check, so a wrong-sized first state
    /// was accepted and all correctly-sized nodes after it were rejected.
    ///
    /// # Errors
    /// - `DimensionMismatch` if `state.len() != config.dimension`.
    /// - Any error returned by the adapter's projection or insert. On such an
    ///   error the local state and depth maps are left untouched.
    pub fn insert_node(&mut self, node_id: NodeId, state: Vec<f32>) -> Result<()> {
        self.check_dimension(state.len())?;
        let projected = self.adapter.project_to_ball(&state)?;
        let depth = self.depth.compute_depth(&projected);
        // Register with the adapter before touching local bookkeeping so a
        // failure there cannot leave a depth entry without a matching state.
        self.adapter.insert(node_id, projected.clone())?;
        self.node_depths.insert(node_id, depth);
        self.node_states.insert(node_id, projected);
        Ok(())
    }

    /// Replaces the state of an existing node, re-projecting it and refreshing its depth.
    ///
    /// # Errors
    /// - `NodeNotFound` if `node_id` was never inserted.
    /// - `DimensionMismatch` if `state.len() != config.dimension`.
    /// - Any error returned by the adapter's projection or update. On such an
    ///   error the previous local state and depth are kept.
    pub fn update_node(&mut self, node_id: NodeId, state: Vec<f32>) -> Result<()> {
        if !self.node_states.contains_key(&node_id) {
            return Err(HyperbolicCoherenceError::NodeNotFound(node_id));
        }
        self.check_dimension(state.len())?;
        let projected = self.adapter.project_to_ball(&state)?;
        let depth = self.depth.compute_depth(&projected);
        // Same ordering as `insert_node`: adapter first, then local maps.
        self.adapter.update(node_id, projected.clone())?;
        self.node_depths.insert(node_id, depth);
        self.node_states.insert(node_id, projected);
        Ok(())
    }

    /// Returns the projected state stored for `node_id`, if any.
    pub fn get_node(&self, node_id: NodeId) -> Option<&Vec<f32>> {
        self.node_states.get(&node_id)
    }

    /// Returns the depth cached for `node_id` at its last insert/update, if any.
    pub fn get_depth(&self, node_id: NodeId) -> Option<f32> {
        self.node_depths.get(&node_id).copied()
    }

    /// Computes the depth-weighted energy of a single edge.
    ///
    /// `weighted_energy = base_weight * ||residual||² * depth_weight`, where
    /// `depth_weight = config.depth_weight_fn(mean(source_depth, target_depth))`.
    ///
    /// # Errors
    /// `NodeNotFound` for the first of `source_id` / `target_id` that is unknown.
    pub fn weighted_edge_energy(
        &self,
        source_id: NodeId,
        target_id: NodeId,
        residual: &[f32],
        base_weight: f32,
    ) -> Result<WeightedResidual> {
        let source_depth = self.depth_of(source_id)?;
        let target_depth = self.depth_of(target_id)?;
        let avg_depth = (source_depth + target_depth) / 2.0;
        let depth_weight = self.config.depth_weight_fn(avg_depth);
        let residual_norm_sq: f32 = residual.iter().map(|x| x * x).sum();
        let weighted_energy = base_weight * residual_norm_sq * depth_weight;
        Ok(WeightedResidual {
            source_id,
            target_id,
            source_depth,
            target_depth,
            depth_weight,
            residual_norm_sq,
            base_weight,
            weighted_energy,
        })
    }

    /// Sums the depth-weighted energy over `edges` = `(source, target, residual, base_weight)`.
    ///
    /// Returns `HyperbolicEnergy::empty()` for an empty slice. Otherwise the
    /// result also reports the min/max depth seen across all edge endpoints.
    ///
    /// # Errors
    /// Propagates the first `NodeNotFound` from [`Self::weighted_edge_energy`].
    pub fn compute_total_energy(
        &self,
        edges: &[(NodeId, NodeId, Vec<f32>, f32)],
    ) -> Result<HyperbolicEnergy> {
        if edges.is_empty() {
            return Ok(HyperbolicEnergy::empty());
        }
        let mut edge_energies = Vec::with_capacity(edges.len());
        let mut total_energy = 0.0f32;
        // Neutral seeds for min/max so the result is correct whatever range
        // depths take. `edges` is non-empty, so both are always overwritten.
        let mut max_depth = f32::NEG_INFINITY;
        let mut min_depth = f32::INFINITY;
        for (source, target, residual, weight) in edges {
            let weighted = self.weighted_edge_energy(*source, *target, residual, *weight)?;
            total_energy += weighted.weighted_energy;
            max_depth = max_depth.max(weighted.source_depth.max(weighted.target_depth));
            min_depth = min_depth.min(weighted.source_depth.min(weighted.target_depth));
            edge_energies.push(weighted);
        }
        Ok(HyperbolicEnergy {
            total_energy,
            edge_energies,
            curvature: self.config.curvature,
            max_depth,
            min_depth,
            num_edges: edges.len(),
        })
    }

    /// Projects `query` into the ball and returns up to `k` nearest nodes from the adapter.
    ///
    /// # Errors
    /// - `DimensionMismatch` if `query.len() != config.dimension`.
    /// - Any error returned by the adapter's projection or search.
    pub fn find_similar(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>> {
        self.check_dimension(query.len())?;
        let projected = self.adapter.project_to_ball(query)?;
        self.adapter.search(&projected, k)
    }

    /// Classifies the cached depth of `node_id` into a [`HierarchyLevel`].
    ///
    /// # Errors
    /// `NodeNotFound` if `node_id` was never inserted.
    pub fn hierarchy_level(&self, node_id: NodeId) -> Result<HierarchyLevel> {
        let depth = self.depth_of(node_id)?;
        Ok(self.depth.classify_level(depth))
    }

    /// Computes the adapter's Fréchet mean of the stored states of `node_ids`.
    ///
    /// Unknown ids are skipped rather than treated as errors. Duplicate ids
    /// contribute their state once per occurrence.
    ///
    /// # Errors
    /// - `EmptyCollection` if `node_ids` is empty or none of them are known.
    /// - Any error returned by the adapter's `frechet_mean`.
    pub fn frechet_mean(&self, node_ids: &[NodeId]) -> Result<Vec<f32>> {
        if node_ids.is_empty() {
            return Err(HyperbolicCoherenceError::EmptyCollection);
        }
        let states: Vec<&Vec<f32>> = node_ids
            .iter()
            .filter_map(|id| self.node_states.get(id))
            .collect();
        if states.is_empty() {
            return Err(HyperbolicCoherenceError::EmptyCollection);
        }
        self.adapter.frechet_mean(&states)
    }

    /// Returns the engine configuration.
    pub fn config(&self) -> &HyperbolicCoherenceConfig {
        &self.config
    }

    /// Returns the number of stored nodes.
    pub fn num_nodes(&self) -> usize {
        self.node_states.len()
    }

    /// Returns the configured curvature.
    pub fn curvature(&self) -> f32 {
        self.config.curvature
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 4-dimensional, curvature -1.0 engine shared by every test.
    fn coherence_4d() -> HyperbolicCoherence {
        HyperbolicCoherence::new(HyperbolicCoherenceConfig {
            dimension: 4,
            curvature: -1.0,
            ..Default::default()
        })
    }

    #[test]
    fn test_basic_coherence() {
        let mut engine = coherence_4d();
        for (id, state) in [
            (1, vec![0.1, 0.1, 0.1, 0.1]),
            (2, vec![0.2, 0.2, 0.2, 0.2]),
            (3, vec![0.5, 0.5, 0.5, 0.5]),
        ] {
            engine.insert_node(id, state).unwrap();
        }
        assert_eq!(engine.num_nodes(), 3);

        // The node farther from the origin is expected to be deeper.
        let near_origin = engine.get_depth(1).unwrap();
        let far_out = engine.get_depth(3).unwrap();
        assert!(far_out > near_origin);
    }

    #[test]
    fn test_weighted_energy() {
        let mut engine = coherence_4d();
        engine.insert_node(1, vec![0.1, 0.1, 0.1, 0.1]).unwrap();
        engine.insert_node(2, vec![0.5, 0.5, 0.5, 0.5]).unwrap();

        let residual = vec![0.1, 0.1, 0.1, 0.1];
        let edge = engine.weighted_edge_energy(1, 2, &residual, 1.0).unwrap();
        assert!(edge.weighted_energy > 0.0);
        // Assumes the default depth weighting amplifies edges away from the root.
        assert!(edge.depth_weight > 1.0);
    }

    #[test]
    fn test_hierarchy_levels() {
        let mut engine = coherence_4d();
        engine.insert_node(1, vec![0.05, 0.05, 0.05, 0.05]).unwrap();
        engine.insert_node(2, vec![0.7, 0.7, 0.0, 0.0]).unwrap();

        let shallow = engine.hierarchy_level(1).unwrap();
        let deep = engine.hierarchy_level(2).unwrap();
        assert!(matches!(shallow, HierarchyLevel::Root | HierarchyLevel::High));
        assert!(matches!(
            deep,
            HierarchyLevel::Deep | HierarchyLevel::VeryDeep
        ));
    }
}