prime-radiant 0.1.0

Universal coherence engine using sheaf Laplacian mathematics for AI safety, hallucination detection, and structural consistency verification in LLMs and distributed systems
//! Hyperbolic Coherence Configuration
//!
//! Configuration for hyperbolic coherence computation.

use serde::{Deserialize, Serialize};

/// Configuration for hyperbolic coherence computation.
///
/// Build one with [`Default::default`], [`HyperbolicCoherenceConfig::small`] or
/// [`HyperbolicCoherenceConfig::large`], then check it with
/// [`HyperbolicCoherenceConfig::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperbolicCoherenceConfig {
    /// State vector dimension.
    ///
    /// `validate` rejects `0`. Default: `64`.
    pub dimension: usize,

    /// Curvature of the hyperbolic space (must be negative).
    /// Typical values: -1.0 (unit curvature), -0.5 (flatter), -2.0 (more curved)
    ///
    /// `validate` rejects non-negative values. Default: `-1.0`.
    pub curvature: f32,

    /// Epsilon for numerical stability (projection boundary).
    ///
    /// `validate` rejects values `<= 0`. Default: `1e-5`.
    pub epsilon: f32,

    /// Maximum number of iterations for Frechet mean computation. Default: `100`.
    pub frechet_max_iters: usize,

    /// Convergence threshold for the Frechet mean iteration. Default: `1e-6`.
    pub frechet_tolerance: f32,

    /// Depth weighting function, evaluated by
    /// [`HyperbolicCoherenceConfig::depth_weight_fn`]. Default: `Logarithmic`.
    pub depth_weight_type: DepthWeightType,

    /// HNSW `M` parameter (max connections per node).
    ///
    /// Default: `16`; `small()` uses `8`, `large()` uses `32`.
    pub hnsw_m: usize,

    /// HNSW `ef_construction` parameter.
    ///
    /// Default: `200`; `small()` uses `100`, `large()` uses `400`.
    pub hnsw_ef_construction: usize,

    /// Enable sharding for large collections.
    ///
    /// Default: `false`; `large()` turns it on.
    pub enable_sharding: bool,

    /// Default shard curvature. Default: `-1.0`.
    ///
    /// NOTE(review): presumably the curvature assigned to shards when sharding is
    /// enabled — confirm against the sharding code.
    pub default_shard_curvature: f32,
}

impl Default for HyperbolicCoherenceConfig {
    fn default() -> Self {
        Self {
            dimension: 64,
            curvature: -1.0,
            epsilon: 1e-5,
            frechet_max_iters: 100,
            frechet_tolerance: 1e-6,
            depth_weight_type: DepthWeightType::Logarithmic,
            hnsw_m: 16,
            hnsw_ef_construction: 200,
            enable_sharding: false,
            default_shard_curvature: -1.0,
        }
    }
}

impl HyperbolicCoherenceConfig {
    /// Create a configuration for small collections (< 10K nodes).
    ///
    /// Uses a sparser HNSW graph (`hnsw_m = 8`, `hnsw_ef_construction = 100`)
    /// with sharding disabled; all remaining fields come from [`Default`].
    pub fn small() -> Self {
        Self {
            dimension: 64,
            curvature: -1.0,
            hnsw_m: 8,
            hnsw_ef_construction: 100,
            enable_sharding: false,
            ..Default::default()
        }
    }

    /// Create a configuration for large collections (> 100K nodes).
    ///
    /// Uses a denser HNSW graph (`hnsw_m = 32`, `hnsw_ef_construction = 400`)
    /// and enables sharding; all remaining fields come from [`Default`].
    pub fn large() -> Self {
        Self {
            dimension: 64,
            curvature: -1.0,
            hnsw_m: 32,
            hnsw_ef_construction: 400,
            enable_sharding: true,
            ..Default::default()
        }
    }

    /// Validate configuration.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message for the first violated constraint:
    /// - `curvature` is not a finite negative number (this includes NaN and `-inf`),
    /// - `dimension` is `0`,
    /// - `epsilon` is not a finite positive number (this includes NaN and `+inf`),
    /// - `enable_sharding` is set and `default_shard_curvature` is not a finite
    ///   negative number.
    pub fn validate(&self) -> Result<(), String> {
        // Checks are phrased as positive predicates and negated as a whole because
        // a bare `x >= 0.0` / `x <= 0.0` comparison is false for NaN and would let
        // NaN through.
        fn is_finite_negative(x: f32) -> bool {
            x.is_finite() && x < 0.0
        }
        fn is_finite_positive(x: f32) -> bool {
            x.is_finite() && x > 0.0
        }

        if !is_finite_negative(self.curvature) {
            return Err(format!(
                "Curvature must be a finite negative number, got {}",
                self.curvature
            ));
        }
        if self.dimension == 0 {
            return Err("Dimension must be positive".to_string());
        }
        if !is_finite_positive(self.epsilon) {
            return Err(format!(
                "Epsilon must be a finite positive number, got {}",
                self.epsilon
            ));
        }
        // Shard curvature only matters once sharding is switched on; checking it
        // unconditionally could reject configs that never use it.
        if self.enable_sharding && !is_finite_negative(self.default_shard_curvature) {
            return Err(format!(
                "Default shard curvature must be a finite negative number, got {}",
                self.default_shard_curvature
            ));
        }
        Ok(())
    }

    /// Compute depth weight using the configured [`DepthWeightType`].
    pub fn depth_weight_fn(&self, depth: f32) -> f32 {
        self.depth_weight_type.compute(depth)
    }
}

/// Type of depth weighting function.
///
/// Evaluated by [`DepthWeightType::compute`]; `Logarithmic` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DepthWeightType {
    /// Constant weight `1` (no depth scaling).
    Constant,
    /// Linear: `1 + depth`.
    Linear,
    /// Logarithmic: `1 + ln(max(depth, 1))`, so depths at or below 1 get weight 1.
    Logarithmic,
    /// Quadratic: `1 + depth^2`.
    Quadratic,
    /// Exponential: `min(e^(0.5 * depth), 10)` — fixed growth rate 0.5, capped at 10.
    Exponential,
}

impl Default for DepthWeightType {
    fn default() -> Self {
        Self::Logarithmic
    }
}

impl DepthWeightType {
    /// Compute depth weight
    pub fn compute(&self, depth: f32) -> f32 {
        match self {
            Self::Constant => 1.0,
            Self::Linear => 1.0 + depth,
            Self::Logarithmic => 1.0 + depth.max(1.0).ln(),
            Self::Quadratic => 1.0 + depth * depth,
            Self::Exponential => (depth * 0.5).exp().min(10.0), // Capped at 10x
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = HyperbolicCoherenceConfig::default();
        assert_eq!(config.curvature, -1.0);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_preset_configs_are_valid() {
        let small = HyperbolicCoherenceConfig::small();
        assert!(small.validate().is_ok());
        assert_eq!(small.hnsw_m, 8);
        assert_eq!(small.hnsw_ef_construction, 100);
        assert!(!small.enable_sharding);

        let large = HyperbolicCoherenceConfig::large();
        assert!(large.validate().is_ok());
        assert_eq!(large.hnsw_m, 32);
        assert_eq!(large.hnsw_ef_construction, 400);
        assert!(large.enable_sharding);
    }

    #[test]
    fn test_invalid_curvature() {
        let config = HyperbolicCoherenceConfig {
            curvature: 1.0, // Invalid - must be negative
            ..Default::default()
        };
        assert!(config.validate().is_err());

        // Zero curvature is flat space, not hyperbolic.
        let flat = HyperbolicCoherenceConfig {
            curvature: 0.0,
            ..Default::default()
        };
        assert!(flat.validate().is_err());
    }

    #[test]
    fn test_invalid_dimension_and_epsilon() {
        let zero_dim = HyperbolicCoherenceConfig {
            dimension: 0,
            ..Default::default()
        };
        assert!(zero_dim.validate().is_err());

        let zero_eps = HyperbolicCoherenceConfig {
            epsilon: 0.0,
            ..Default::default()
        };
        assert!(zero_eps.validate().is_err());

        let negative_eps = HyperbolicCoherenceConfig {
            epsilon: -1e-5,
            ..Default::default()
        };
        assert!(negative_eps.validate().is_err());
    }

    #[test]
    fn test_depth_weights() {
        assert_eq!(DepthWeightType::Constant.compute(5.0), 1.0);
        assert_eq!(DepthWeightType::Linear.compute(5.0), 6.0);

        // Use the std constant: a hand-typed literal trips Clippy's deny-by-default
        // `approx_constant` lint and carries more digits than f32 can hold.
        let log_weight = DepthWeightType::Logarithmic.compute(std::f32::consts::E);
        assert!((log_weight - 2.0).abs() < 0.01);
        // Depths at or below 1 are clamped, giving ln(1) + 1 = 1.
        assert_eq!(DepthWeightType::Logarithmic.compute(0.5), 1.0);

        assert_eq!(DepthWeightType::Quadratic.compute(3.0), 10.0);

        assert_eq!(DepthWeightType::Exponential.compute(0.0), 1.0);
        // Large depths saturate at the 10x cap.
        assert_eq!(DepthWeightType::Exponential.compute(100.0), 10.0);
    }

    #[test]
    fn test_default_depth_weight_type() {
        assert_eq!(DepthWeightType::default(), DepthWeightType::Logarithmic);
        let config = HyperbolicCoherenceConfig::default();
        assert_eq!(config.depth_weight_fn(1.0), 1.0);
    }
}