use serde::{Deserialize, Serialize};
/// Configuration parameters for hyperbolic coherence.
///
/// Construct via `Default::default()`, the `small()` / `large()` presets, or
/// struct-update syntax, and check hand-built values with `validate()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperbolicCoherenceConfig {
    /// Dimensionality; `validate()` rejects `0`.
    pub dimension: usize,
    /// Curvature of the space; `validate()` requires a negative value.
    pub curvature: f32,
    /// Small constant that `validate()` requires to be positive.
    /// NOTE(review): presumably a numerical-stability guard — confirm at call sites.
    pub epsilon: f32,
    /// Iteration cap, presumably for a Fréchet-mean solver (inferred from the name — TODO confirm).
    pub frechet_max_iters: usize,
    /// Convergence tolerance, presumably for the same Fréchet-mean solver — TODO confirm.
    pub frechet_tolerance: f32,
    /// Scheme used by `depth_weight_fn` to map a depth to a weight.
    pub depth_weight_type: DepthWeightType,
    /// Presumably the HNSW `M` parameter (graph connectivity) — verify against the index builder.
    pub hnsw_m: usize,
    /// Presumably the HNSW `ef_construction` build-time search width — verify against the index builder.
    pub hnsw_ef_construction: usize,
    /// Whether sharding is enabled; the `large()` preset turns it on, `default()` and `small()` leave it off.
    pub enable_sharding: bool,
    /// Curvature value associated with shards.
    /// NOTE(review): presumably applied to newly created shards — confirm against the sharding code.
    pub default_shard_curvature: f32,
}
impl Default for HyperbolicCoherenceConfig {
    /// Baseline configuration: `dimension = 64`, curvature `-1.0` (for both the
    /// main and the default shard curvature), logarithmic depth weighting,
    /// `hnsw_m = 16`, `hnsw_ef_construction = 200`, and sharding disabled.
    fn default() -> Self {
        // Both curvature fields start from the same value.
        let unit_curvature: f32 = -1.0;

        Self {
            // Geometry.
            dimension: 64,
            curvature: unit_curvature,
            epsilon: 1e-5,
            // Fréchet-related settings.
            frechet_max_iters: 100,
            frechet_tolerance: 1e-6,
            // Depth weighting.
            depth_weight_type: DepthWeightType::Logarithmic,
            // HNSW-related settings.
            hnsw_m: 16,
            hnsw_ef_construction: 200,
            // Sharding.
            enable_sharding: false,
            default_shard_curvature: unit_curvature,
        }
    }
}
impl HyperbolicCoherenceConfig {
    /// Preset with a sparser HNSW setting (`hnsw_m = 8`,
    /// `hnsw_ef_construction = 100`) and sharding disabled.
    ///
    /// Every field not listed here comes from `Default::default()`.
    pub fn small() -> Self {
        Self {
            dimension: 64,
            curvature: -1.0,
            hnsw_m: 8,
            hnsw_ef_construction: 100,
            enable_sharding: false,
            ..Default::default()
        }
    }

    /// Preset with a denser HNSW setting (`hnsw_m = 32`,
    /// `hnsw_ef_construction = 400`) and sharding enabled.
    ///
    /// Every field not listed here comes from `Default::default()`.
    pub fn large() -> Self {
        Self {
            dimension: 64,
            curvature: -1.0,
            hnsw_m: 32,
            hnsw_ef_construction: 400,
            enable_sharding: true,
            ..Default::default()
        }
    }

    /// Checks that the configuration is internally consistent.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message for the first violated rule, checked in
    /// this order:
    /// - `curvature` must be negative and finite;
    /// - `dimension` must be non-zero;
    /// - `epsilon` must be positive and finite;
    /// - when `enable_sharding` is set, `default_shard_curvature` must be
    ///   negative and finite.
    pub fn validate(&self) -> Result<(), String> {
        if self.curvature >= 0.0 {
            return Err(format!(
                "Curvature must be negative, got {}",
                self.curvature
            ));
        }
        // `NaN >= 0.0` is false, so NaN (and -inf) slip past the sign check
        // above; reject them explicitly.
        if !self.curvature.is_finite() {
            return Err(format!("Curvature must be finite, got {}", self.curvature));
        }
        if self.dimension == 0 {
            return Err("Dimension must be positive".to_string());
        }
        if self.epsilon <= 0.0 {
            return Err("Epsilon must be positive".to_string());
        }
        // Same NaN hole as above (`NaN <= 0.0` is false); +inf is also rejected.
        if !self.epsilon.is_finite() {
            return Err(format!("Epsilon must be finite, got {}", self.epsilon));
        }
        // The shard curvature is only consulted when sharding is on, so only
        // enforce the same negative-curvature rule in that case.
        if self.enable_sharding {
            let shard = self.default_shard_curvature;
            if !(shard.is_finite() && shard < 0.0) {
                return Err(format!(
                    "Default shard curvature must be finite and negative when sharding is enabled, got {}",
                    shard
                ));
            }
        }
        Ok(())
    }

    /// Returns the weight for `depth` under the configured `depth_weight_type`.
    ///
    /// Thin wrapper over `DepthWeightType::compute`.
    pub fn depth_weight_fn(&self, depth: f32) -> f32 {
        self.depth_weight_type.compute(depth)
    }
}
/// How a depth value is turned into a weight; see `DepthWeightType::compute`.
///
/// Every variant yields a weight of `1` at `depth == 0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DepthWeightType {
    /// Weight is always `1`, regardless of depth.
    Constant,
    /// Weight is `1 + depth` (not clamped, so negative depths give weights below `1`).
    Linear,
    /// Weight is `1 + ln(max(depth, 1))`; depths below `1` all get weight `1`.
    Logarithmic,
    /// Weight is `1 + depth * depth`.
    Quadratic,
    /// Weight is `exp(depth / 2)`, capped at `10`.
    Exponential,
}
impl Default for DepthWeightType {
fn default() -> Self {
Self::Logarithmic
}
}
impl DepthWeightType {
pub fn compute(&self, depth: f32) -> f32 {
match self {
Self::Constant => 1.0,
Self::Linear => 1.0 + depth,
Self::Logarithmic => 1.0 + depth.max(1.0).ln(),
Self::Quadratic => 1.0 + depth * depth,
Self::Exponential => (depth * 0.5).exp().min(10.0), }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = HyperbolicCoherenceConfig::default();
        assert_eq!(config.curvature, -1.0);
        assert!(config.validate().is_ok());
    }

    /// Both presets must produce configurations that pass validation.
    #[test]
    fn test_presets_validate() {
        let small = HyperbolicCoherenceConfig::small();
        assert_eq!(small.hnsw_m, 8);
        assert!(!small.enable_sharding);
        assert!(small.validate().is_ok());

        let large = HyperbolicCoherenceConfig::large();
        assert_eq!(large.hnsw_m, 32);
        assert!(large.enable_sharding);
        assert!(large.validate().is_ok());
    }

    #[test]
    fn test_invalid_curvature() {
        let config = HyperbolicCoherenceConfig {
            curvature: 1.0,
            ..Default::default()
        };
        assert!(config.validate().is_err());

        // Zero curvature is flat space, not hyperbolic.
        let flat = HyperbolicCoherenceConfig {
            curvature: 0.0,
            ..Default::default()
        };
        assert!(flat.validate().is_err());
    }

    #[test]
    fn test_invalid_dimension_and_epsilon() {
        let no_dim = HyperbolicCoherenceConfig {
            dimension: 0,
            ..Default::default()
        };
        assert!(no_dim.validate().is_err());

        let zero_eps = HyperbolicCoherenceConfig {
            epsilon: 0.0,
            ..Default::default()
        };
        assert!(zero_eps.validate().is_err());

        let neg_eps = HyperbolicCoherenceConfig {
            epsilon: -1e-3,
            ..Default::default()
        };
        assert!(neg_eps.validate().is_err());
    }

    #[test]
    fn test_depth_weights() {
        assert_eq!(DepthWeightType::Constant.compute(5.0), 1.0);
        assert_eq!(DepthWeightType::Linear.compute(5.0), 6.0);
        // Use the std constant: a hand-typed `2.718281828` trips clippy's
        // deny-by-default `approx_constant` lint.
        let log_weight = DepthWeightType::Logarithmic.compute(std::f32::consts::E);
        assert!((log_weight - 2.0).abs() < 0.01);
        // Depths below 1 are clamped, giving the base weight.
        assert_eq!(DepthWeightType::Logarithmic.compute(0.5), 1.0);
        assert_eq!(DepthWeightType::Quadratic.compute(3.0), 10.0);
        assert_eq!(DepthWeightType::Exponential.compute(0.0), 1.0);
        // Exponential weighting is capped at 10.
        assert_eq!(DepthWeightType::Exponential.compute(100.0), 10.0);
    }
}