use crate::error::{MathError, Result};
/// Geometry of a single manifold component together with its curvature value.
///
/// Sign convention (enforced by the `*_with` constructors): `0.0` for Euclidean, negative for
/// hyperbolic, positive for spherical. Building a variant literal directly, e.g.
/// `CurvatureType::Hyperbolic { curvature: 2.0 }`, bypasses that clamping.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CurvatureType {
    /// Flat geometry; [`CurvatureType::curvature`] reports `0.0`.
    Euclidean,
    /// Negatively curved geometry (`curvature <= -1e-6` when built via `hyperbolic_with`).
    Hyperbolic { curvature: f64 },
    /// Positively curved geometry (`curvature >= 1e-6` when built via `spherical_with`).
    Spherical { curvature: f64 },
}

impl CurvatureType {
    /// Smallest curvature magnitude the `*_with` constructors will produce.
    const MIN_CURVATURE_MAGNITUDE: f64 = 1e-6;

    /// Hyperbolic geometry with curvature `-1.0`.
    pub fn hyperbolic() -> Self {
        Self::hyperbolic_with(-1.0)
    }

    /// Hyperbolic geometry with the given curvature.
    ///
    /// Any value above `-1e-6` (zero, positive values, tiny negatives) is replaced by `-1e-6`.
    /// NaN is also mapped to `-1e-6`, because `f64::min` returns the non-NaN operand.
    pub fn hyperbolic_with(curvature: f64) -> Self {
        let clamped = f64::min(curvature, -Self::MIN_CURVATURE_MAGNITUDE);
        Self::Hyperbolic { curvature: clamped }
    }

    /// Spherical geometry with curvature `1.0`.
    pub fn spherical() -> Self {
        Self::spherical_with(1.0)
    }

    /// Spherical geometry with the given curvature.
    ///
    /// Any value below `1e-6` (zero, negative values, tiny positives) is replaced by `1e-6`.
    /// NaN is also mapped to `1e-6`, because `f64::max` returns the non-NaN operand.
    pub fn spherical_with(curvature: f64) -> Self {
        let clamped = f64::max(curvature, Self::MIN_CURVATURE_MAGNITUDE);
        Self::Spherical { curvature: clamped }
    }

    /// Curvature value of this geometry (`0.0` for [`CurvatureType::Euclidean`]).
    pub fn curvature(&self) -> f64 {
        match *self {
            Self::Euclidean => 0.0,
            Self::Hyperbolic { curvature } | Self::Spherical { curvature } => curvature,
        }
    }
}
/// Dimensions, curvatures and weights describing a product of a Euclidean, a hyperbolic and a
/// spherical component.
///
/// Each component occupies a contiguous index range: Euclidean first, then hyperbolic, then
/// spherical. The ranges are reported by [`component_ranges`](Self::component_ranges).
///
/// All fields are public, so values assigned directly bypass the clamping done by the `with_*`
/// builders. Call [`validate`](Self::validate) before relying on them.
#[derive(Debug, Clone)]
pub struct ProductManifoldConfig {
    /// Number of coordinates in the Euclidean component.
    pub euclidean_dim: usize,
    /// Number of coordinates in the hyperbolic component.
    pub hyperbolic_dim: usize,
    /// Curvature of the hyperbolic component (expected finite and negative).
    pub hyperbolic_curvature: f64,
    /// Number of coordinates in the spherical component.
    pub spherical_dim: usize,
    /// Curvature of the spherical component (expected finite and positive).
    pub spherical_curvature: f64,
    /// Weights in `(euclidean, hyperbolic, spherical)` order (expected finite and non-negative).
    /// How the weights are combined is up to the code consuming this config.
    pub component_weights: (f64, f64, f64),
}

impl ProductManifoldConfig {
    /// Smallest curvature magnitude the `with_*_curvature` builders will store.
    const MIN_CURVATURE_MAGNITUDE: f64 = 1e-6;

    /// Creates a config with the given component dimensions.
    ///
    /// Defaults: hyperbolic curvature `-1.0`, spherical curvature `1.0`, and weights of `1.0`
    /// for every component.
    pub fn new(euclidean_dim: usize, hyperbolic_dim: usize, spherical_dim: usize) -> Self {
        Self {
            euclidean_dim,
            hyperbolic_dim,
            hyperbolic_curvature: -1.0,
            spherical_dim,
            spherical_curvature: 1.0,
            component_weights: (1.0, 1.0, 1.0),
        }
    }

    /// Config with only a Euclidean component of dimension `dim`.
    pub fn euclidean(dim: usize) -> Self {
        Self::new(dim, 0, 0)
    }

    /// Config with only a hyperbolic component of dimension `dim`.
    pub fn hyperbolic(dim: usize) -> Self {
        Self::new(0, dim, 0)
    }

    /// Config with only a spherical component of dimension `dim`.
    pub fn spherical(dim: usize) -> Self {
        Self::new(0, 0, dim)
    }

    /// Config with Euclidean and hyperbolic components and no spherical component.
    pub fn euclidean_hyperbolic(euclidean_dim: usize, hyperbolic_dim: usize) -> Self {
        Self::new(euclidean_dim, hyperbolic_dim, 0)
    }

    /// Sets the hyperbolic curvature.
    ///
    /// Any value above `-1e-6` (zero, positive values, tiny negatives) is stored as `-1e-6`.
    /// NaN is stored as `-1e-6` too.
    pub fn with_hyperbolic_curvature(mut self, c: f64) -> Self {
        // `f64::min` returns the non-NaN operand, so NaN input is clamped as well.
        self.hyperbolic_curvature = c.min(-Self::MIN_CURVATURE_MAGNITUDE);
        self
    }

    /// Sets the spherical curvature.
    ///
    /// Any value below `1e-6` (zero, negative values, tiny positives) is stored as `1e-6`.
    /// NaN is stored as `1e-6` too.
    pub fn with_spherical_curvature(mut self, c: f64) -> Self {
        // `f64::max` returns the non-NaN operand, so NaN input is clamped as well.
        self.spherical_curvature = c.max(Self::MIN_CURVATURE_MAGNITUDE);
        self
    }

    /// Sets the per-component weights. Negative values and NaN are stored as `0.0`.
    pub fn with_weights(mut self, euclidean: f64, hyperbolic: f64, spherical: f64) -> Self {
        self.component_weights = (euclidean.max(0.0), hyperbolic.max(0.0), spherical.max(0.0));
        self
    }

    /// Sum of the three component dimensions.
    pub fn total_dim(&self) -> usize {
        self.euclidean_dim + self.hyperbolic_dim + self.spherical_dim
    }

    /// Checks that the config is internally consistent.
    ///
    /// The curvature checks apply even when the corresponding dimension is zero.
    ///
    /// # Errors
    ///
    /// Returns `MathError::invalid_parameter`, naming the offending field, when:
    /// - all three dimensions are zero (`"dimensions"`);
    /// - `hyperbolic_curvature` is not negative, is NaN, or is infinite;
    /// - `spherical_curvature` is not positive, is NaN, or is infinite;
    /// - any entry of `component_weights` is negative, NaN, or infinite.
    pub fn validate(&self) -> Result<()> {
        if self.total_dim() == 0 {
            return Err(MathError::invalid_parameter(
                "dimensions",
                "at least one component must have non-zero dimension",
            ));
        }

        // Comparisons with NaN are always false, so NaN must be rejected explicitly.
        // Otherwise it would slip past the sign check.
        let hyperbolic = self.hyperbolic_curvature;
        if hyperbolic.is_nan() || hyperbolic >= 0.0 {
            return Err(MathError::invalid_parameter(
                "hyperbolic_curvature",
                "must be negative",
            ));
        }
        if hyperbolic.is_infinite() {
            return Err(MathError::invalid_parameter(
                "hyperbolic_curvature",
                "must be finite",
            ));
        }

        let spherical = self.spherical_curvature;
        if spherical.is_nan() || spherical <= 0.0 {
            return Err(MathError::invalid_parameter(
                "spherical_curvature",
                "must be positive",
            ));
        }
        if spherical.is_infinite() {
            return Err(MathError::invalid_parameter(
                "spherical_curvature",
                "must be finite",
            ));
        }

        let (w_e, w_h, w_s) = self.component_weights;
        if [w_e, w_h, w_s]
            .iter()
            .any(|w| !w.is_finite() || *w < 0.0)
        {
            return Err(MathError::invalid_parameter(
                "component_weights",
                "must be finite and non-negative",
            ));
        }

        Ok(())
    }

    /// Index ranges of the Euclidean, hyperbolic and spherical components, in that order.
    ///
    /// The ranges are contiguous and together cover `0..self.total_dim()`. The range for a
    /// zero-dimensional component is empty.
    pub fn component_ranges(
        &self,
    ) -> (
        std::ops::Range<usize>,
        std::ops::Range<usize>,
        std::ops::Range<usize>,
    ) {
        let e_end = self.euclidean_dim;
        let h_end = e_end + self.hyperbolic_dim;
        let s_end = h_end + self.spherical_dim;
        (0..e_end, e_end..h_end, h_end..s_end)
    }
}

impl Default for ProductManifoldConfig {
    /// 64 Euclidean, 16 hyperbolic and 8 spherical dimensions, with the curvature and weight
    /// defaults of [`ProductManifoldConfig::new`].
    fn default() -> Self {
        Self::new(64, 16, 8)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_creation() {
        let config = ProductManifoldConfig::new(32, 16, 8);
        assert_eq!(config.euclidean_dim, 32);
        assert_eq!(config.hyperbolic_dim, 16);
        assert_eq!(config.spherical_dim, 8);
        assert_eq!(config.total_dim(), 56);
    }

    #[test]
    fn test_component_ranges() {
        let config = ProductManifoldConfig::new(10, 5, 3);
        let (e, h, s) = config.component_ranges();
        assert_eq!(e, 0..10);
        assert_eq!(h, 10..15);
        assert_eq!(s, 15..18);
    }

    #[test]
    fn test_validation() {
        let config = ProductManifoldConfig::new(0, 0, 0);
        assert!(config.validate().is_err());
        let config = ProductManifoldConfig::new(10, 5, 0);
        assert!(config.validate().is_ok());
    }

    /// `Default` uses 64/16/8 dimensions and passes validation.
    #[test]
    fn test_default_config() {
        let config = ProductManifoldConfig::default();
        assert_eq!(config.total_dim(), 88);
        assert!(config.validate().is_ok());
    }

    /// Single-component constructors put all dimensions in one range and leave the rest empty.
    #[test]
    fn test_single_component_constructors() {
        assert_eq!(
            ProductManifoldConfig::euclidean(4).component_ranges(),
            (0..4, 4..4, 4..4)
        );
        assert_eq!(
            ProductManifoldConfig::hyperbolic(4).component_ranges(),
            (0..0, 0..4, 4..4)
        );
        assert_eq!(
            ProductManifoldConfig::spherical(3).component_ranges(),
            (0..0, 0..0, 0..3)
        );
        assert_eq!(
            ProductManifoldConfig::euclidean_hyperbolic(2, 5).component_ranges(),
            (0..2, 2..7, 7..7)
        );
    }

    /// The `with_*` builders clamp out-of-range inputs instead of rejecting them.
    #[test]
    fn test_builders_clamp_out_of_range_inputs() {
        let config = ProductManifoldConfig::new(2, 2, 2)
            .with_hyperbolic_curvature(0.5)
            .with_spherical_curvature(-0.5)
            .with_weights(-1.0, 0.5, 2.0);
        assert_eq!(config.hyperbolic_curvature, -1e-6);
        assert_eq!(config.spherical_curvature, 1e-6);
        assert_eq!(config.component_weights, (0.0, 0.5, 2.0));
        assert!(config.validate().is_ok());
    }

    /// Curvatures assigned directly (bypassing the builders) are still sign-checked.
    #[test]
    fn test_validation_rejects_wrong_sign_curvature() {
        let mut config = ProductManifoldConfig::new(2, 2, 2);
        config.hyperbolic_curvature = 1.0;
        assert!(config.validate().is_err());

        let mut config = ProductManifoldConfig::new(2, 2, 2);
        config.spherical_curvature = 0.0;
        assert!(config.validate().is_err());
    }

    /// `CurvatureType` constructors apply the same sign convention and clamping.
    #[test]
    fn test_curvature_type() {
        assert_eq!(CurvatureType::Euclidean.curvature(), 0.0);
        assert_eq!(CurvatureType::hyperbolic().curvature(), -1.0);
        assert_eq!(CurvatureType::spherical().curvature(), 1.0);
        assert_eq!(CurvatureType::hyperbolic_with(-0.25).curvature(), -0.25);
        assert_eq!(CurvatureType::hyperbolic_with(2.0).curvature(), -1e-6);
        assert_eq!(CurvatureType::spherical_with(-2.0).curvature(), 1e-6);
    }
}