#[cfg(feature = "manifold")]
use ruvector_attention::{
ScaledDotProductAttention, HyperbolicAttention, HyperbolicAttentionConfig,
Attention,
};
#[cfg(feature = "manifold")]
use ruvector_verified::{
ProofEnvironment, ProofAttestation,
prove_dim_eq,
proof_store::create_attestation,
gated::{route_proof, ProofKind},
};
#[cfg(feature = "manifold")]
use crate::config::ManifoldConfig;
#[cfg(feature = "manifold")]
use crate::error::{GraphTransformerError, Result};
#[cfg(feature = "manifold")]
const EPS: f32 = 1e-7;
#[cfg(feature = "manifold")]
#[inline]
fn norm_sq(v: &[f32]) -> f32 {
v.iter().map(|&x| x * x).sum()
}
#[cfg(feature = "manifold")]
#[inline]
fn norm(v: &[f32]) -> f32 {
norm_sq(v).sqrt()
}
#[cfg(feature = "manifold")]
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
}
#[cfg(feature = "manifold")]
#[derive(Debug, Clone, PartialEq)]
pub enum ManifoldType {
PoincareBall { curvature: f32 },
Lorentz { curvature: f32 },
Sphere,
Product(Vec<ManifoldType>),
}
#[cfg(feature = "manifold")]
pub struct ProductManifoldAttention {
config: ManifoldConfig,
spherical_attention: ScaledDotProductAttention,
hyperbolic_attention: HyperbolicAttention,
euclidean_attention: ScaledDotProductAttention,
total_dim: usize,
beta_s: f32,
beta_h: f32,
beta_e: f32,
env: ProofEnvironment,
}
#[cfg(feature = "manifold")]
#[derive(Debug)]
pub struct ManifoldAttentionResult {
pub output: Vec<f32>,
pub curvatures: ManifoldCurvatures,
pub attestation: Option<ProofAttestation>,
}
#[cfg(feature = "manifold")]
#[derive(Debug, Clone)]
pub struct ManifoldCurvatures {
pub spherical: f32,
pub hyperbolic: f32,
pub euclidean: f32,
}
#[cfg(feature = "manifold")]
impl ProductManifoldAttention {
pub fn new(config: ManifoldConfig) -> Self {
let total_dim = config.spherical_dim + config.hyperbolic_dim + config.euclidean_dim;
let spherical_attention = ScaledDotProductAttention::new(config.spherical_dim);
let hyperbolic_config = HyperbolicAttentionConfig {
dim: config.hyperbolic_dim,
curvature: config.curvature,
adaptive_curvature: false,
temperature: 1.0,
frechet_max_iter: 100,
frechet_tol: 1e-6,
};
let hyperbolic_attention = HyperbolicAttention::new(hyperbolic_config);
let euclidean_attention = ScaledDotProductAttention::new(config.euclidean_dim);
Self {
config,
spherical_attention,
hyperbolic_attention,
euclidean_attention,
total_dim,
beta_s: 1.0,
beta_h: 1.0,
beta_e: 1.0,
env: ProofEnvironment::new(),
}
}
pub fn with_betas(config: ManifoldConfig, beta_s: f32, beta_h: f32, beta_e: f32) -> Self {
let mut attn = Self::new(config);
attn.beta_s = beta_s;
attn.beta_h = beta_h;
attn.beta_e = beta_e;
attn
}
fn verify_curvature_compatibility(
&mut self,
_q_s: &[f32],
q_h: &[f32],
) -> Result<ProofAttestation> {
let c = self.config.curvature.abs();
if c < EPS {
return Err(GraphTransformerError::InvariantViolation(
"hyperbolic curvature must be non-zero".into(),
));
}
let norm_h_sq = norm_sq(q_h);
if norm_h_sq >= 1.0 / c {
}
let decision = route_proof(ProofKind::Reflexivity, &self.env);
let dim_tag = self.total_dim as u32;
let proof_id = ruvector_verified::gated::verify_tiered(
&mut self.env,
dim_tag,
dim_tag,
decision.tier,
)?;
Ok(create_attestation(&self.env, proof_id))
}
pub fn compute(
&mut self,
query: &[f32],
keys: &[Vec<f32>],
values: &[Vec<f32>],
) -> Result<ManifoldAttentionResult> {
if query.len() != self.total_dim {
return Err(GraphTransformerError::DimensionMismatch {
expected: self.total_dim,
actual: query.len(),
});
}
let s_dim = self.config.spherical_dim;
let h_dim = self.config.hyperbolic_dim;
let q_s = &query[..s_dim];
let q_h = &query[s_dim..s_dim + h_dim];
let q_e = &query[s_dim + h_dim..];
let attestation = self.verify_curvature_compatibility(q_s, q_h).ok();
let k_s: Vec<&[f32]> = keys.iter().map(|k| &k[..s_dim]).collect();
let k_h: Vec<&[f32]> = keys.iter().map(|k| &k[s_dim..s_dim + h_dim]).collect();
let k_e: Vec<&[f32]> = keys.iter().map(|k| &k[s_dim + h_dim..]).collect();
let v_s: Vec<&[f32]> = values.iter().map(|v| &v[..s_dim]).collect();
let v_h: Vec<&[f32]> = values.iter().map(|v| &v[s_dim..s_dim + h_dim]).collect();
let v_e: Vec<&[f32]> = values.iter().map(|v| &v[s_dim + h_dim..]).collect();
let q_s_proj = project_to_sphere(q_s);
let k_s_proj: Vec<Vec<f32>> = k_s.iter().map(|k| project_to_sphere(k)).collect();
let k_s_refs: Vec<&[f32]> = k_s_proj.iter().map(|k| k.as_slice()).collect();
let out_s = self.spherical_attention.compute(&q_s_proj, &k_s_refs, &v_s)
.map_err(GraphTransformerError::Attention)?;
let out_h = self.hyperbolic_attention.compute(q_h, &k_h, &v_h)
.map_err(GraphTransformerError::Attention)?;
let out_e = self.euclidean_attention.compute(q_e, &k_e, &v_e)
.map_err(GraphTransformerError::Attention)?;
let beta_sum = self.beta_s + self.beta_h + self.beta_e;
let w_s = self.beta_s / beta_sum;
let w_h = self.beta_h / beta_sum;
let w_e = self.beta_e / beta_sum;
let mut output = Vec::with_capacity(self.total_dim);
output.extend(out_s.iter().map(|&x| w_s * x));
output.extend(out_h.iter().map(|&x| w_h * x));
output.extend(out_e.iter().map(|&x| w_e * x));
let curvatures = ManifoldCurvatures {
spherical: 1.0,
hyperbolic: self.config.curvature,
euclidean: 0.0,
};
Ok(ManifoldAttentionResult { output, curvatures, attestation })
}
pub fn total_dim(&self) -> usize {
self.total_dim
}
pub fn config(&self) -> &ManifoldConfig {
&self.config
}
pub fn manifold_type(&self) -> ManifoldType {
ManifoldType::Product(vec![
ManifoldType::Sphere,
ManifoldType::PoincareBall { curvature: self.config.curvature.abs() },
ManifoldType::PoincareBall { curvature: 0.0 }, ])
}
}
#[cfg(feature = "manifold")]
pub struct CurvatureAdaptiveRouter {
neg_threshold: f32,
pos_threshold: f32,
temperature: f32,
}
#[cfg(feature = "manifold")]
#[derive(Debug, Clone)]
pub struct RoutingWeights {
pub spherical: f32,
pub hyperbolic: f32,
pub euclidean: f32,
}
#[cfg(feature = "manifold")]
impl CurvatureAdaptiveRouter {
pub fn new() -> Self {
Self {
neg_threshold: -0.1,
pos_threshold: 0.1,
temperature: 5.0,
}
}
pub fn with_params(neg_threshold: f32, pos_threshold: f32, temperature: f32) -> Self {
Self {
neg_threshold,
pos_threshold,
temperature,
}
}
pub fn route(&self, ollivier_ricci_curvature: f32) -> RoutingWeights {
let kappa = ollivier_ricci_curvature;
let w_hyp = sigmoid(self.temperature * (self.neg_threshold - kappa));
let w_sph = sigmoid(self.temperature * (kappa - self.pos_threshold));
let w_euc = (-self.temperature * kappa * kappa / 2.0).exp();
let total = w_hyp + w_sph + w_euc;
RoutingWeights {
hyperbolic: w_hyp / total,
spherical: w_sph / total,
euclidean: w_euc / total,
}
}
pub fn route_batch(&self, curvatures: &[f32]) -> Vec<RoutingWeights> {
curvatures.iter().map(|&k| self.route(k)).collect()
}
pub fn estimate_ollivier_ricci(
&self,
node_i_features: &[f32],
node_j_features: &[f32],
neighbors_i: &[&[f32]],
neighbors_j: &[&[f32]],
) -> f32 {
let d_ij = euclidean_distance(node_i_features, node_j_features);
if d_ij < EPS {
return 1.0; }
let centroid_i = compute_centroid(neighbors_i);
let centroid_j = compute_centroid(neighbors_j);
let w1_approx = euclidean_distance(¢roid_i, ¢roid_j);
1.0 - w1_approx / d_ij
}
}
#[cfg(feature = "manifold")]
impl Default for CurvatureAdaptiveRouter {
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "manifold")]
pub struct GeodesicMessagePassing {
manifold: ManifoldType,
frechet_max_iter: usize,
frechet_tol: f32,
env: ProofEnvironment,
}
#[cfg(feature = "manifold")]
#[derive(Debug)]
pub struct MessagePassingResult {
pub node_messages: Vec<Vec<f32>>,
pub norms_preserved: bool,
pub attestation: Option<ProofAttestation>,
}
#[cfg(feature = "manifold")]
impl GeodesicMessagePassing {
pub fn new(manifold: ManifoldType) -> Self {
let curvature = match &manifold {
ManifoldType::PoincareBall { curvature } => *curvature,
_ => 1.0,
};
let max_iter = if curvature > 5.0 { 200 } else { 100 };
Self {
manifold,
frechet_max_iter: max_iter,
frechet_tol: 1e-6,
env: ProofEnvironment::new(),
}
}
pub fn with_frechet_params(
manifold: ManifoldType,
max_iter: usize,
tol: f32,
) -> Self {
Self {
manifold,
frechet_max_iter: max_iter,
frechet_tol: tol,
env: ProofEnvironment::new(),
}
}
pub fn parallel_transport_poincare(
&self,
v: &[f32],
from: &[f32],
to: &[f32],
c: f32,
) -> Vec<f32> {
let c = c.abs().max(EPS);
let lambda_from = conformal_factor(from, c);
let lambda_to = conformal_factor(to, c);
let b_plus_v = mobius_add_internal(to, v, c);
let a_plus_bv = mobius_add_internal(from, &b_plus_v, c);
let a_plus_b = mobius_add_internal(from, to, c);
let neg_ab: Vec<f32> = a_plus_b.iter().map(|&x| -x).collect();
let gyrated = mobius_add_internal(&neg_ab, &a_plus_bv, c);
let scale = lambda_from / lambda_to.max(EPS);
gyrated.iter().map(|&x| x * scale).collect()
}
pub fn parallel_transport_sphere(
&self,
v: &[f32],
from: &[f32],
to: &[f32],
) -> Vec<f32> {
let d = dot(from, to).clamp(-1.0, 1.0);
let angle = d.acos();
if angle.abs() < EPS {
return v.to_vec();
}
let sum: Vec<f32> = from.iter().zip(to.iter()).map(|(&a, &b)| a + b).collect();
let dot_sv = dot(&sum, v);
let coeff = dot_sv / (1.0 + d).max(EPS);
v.iter().zip(sum.iter()).map(|(&vi, &si)| vi - coeff * si).collect()
}
pub fn propagate(
&mut self,
node_features: &[Vec<f32>],
edges: &[(usize, usize)],
) -> Result<MessagePassingResult> {
let n = node_features.len();
let dim = if n > 0 { node_features[0].len() } else { 0 };
let mut adj: Vec<Vec<usize>> = vec![vec![]; n];
for &(u, v) in edges {
if u < n && v < n {
adj[u].push(v);
adj[v].push(u);
}
}
let mut node_messages = Vec::with_capacity(n);
let mut all_norms_preserved = true;
for i in 0..n {
if adj[i].is_empty() {
node_messages.push(node_features[i].clone());
continue;
}
let mut transported: Vec<Vec<f32>> = Vec::with_capacity(adj[i].len());
for &j in &adj[i] {
let msg = match &self.manifold {
ManifoldType::PoincareBall { curvature } => {
self.parallel_transport_poincare(
&node_features[j],
&node_features[j],
&node_features[i],
*curvature,
)
}
ManifoldType::Sphere => {
let from_proj = project_to_sphere(&node_features[j]);
let to_proj = project_to_sphere(&node_features[i]);
self.parallel_transport_sphere(
&node_features[j],
&from_proj,
&to_proj,
)
}
_ => {
node_features[j].clone()
}
};
let orig_norm = norm(&node_features[j]);
let trans_norm = norm(&msg);
if orig_norm > EPS && (trans_norm / orig_norm - 1.0).abs() > 0.1 {
all_norms_preserved = false;
}
transported.push(msg);
}
let aggregated = match &self.manifold {
ManifoldType::PoincareBall { curvature } => {
let refs: Vec<&[f32]> = transported.iter().map(|t| t.as_slice()).collect();
ruvector_attention::hyperbolic::frechet_mean(
&refs,
None,
*curvature,
self.frechet_max_iter,
self.frechet_tol,
)
}
ManifoldType::Sphere => {
spherical_frechet_mean(&transported, self.frechet_max_iter, self.frechet_tol)
}
_ => {
euclidean_mean(&transported)
}
};
node_messages.push(aggregated);
}
let attestation = if all_norms_preserved {
let dim_tag = dim as u32;
let proof_id = prove_dim_eq(&mut self.env, dim_tag, dim_tag)?;
Some(create_attestation(&self.env, proof_id))
} else {
None
};
Ok(MessagePassingResult {
node_messages,
norms_preserved: all_norms_preserved,
attestation,
})
}
}
#[cfg(feature = "manifold")]
pub struct RiemannianAdamOptimizer {
lr: f32,
beta1: f32,
beta2: f32,
adam_eps: f32,
m: Vec<f32>,
v: Vec<f32>,
t: u32,
manifold: ManifoldType,
env: ProofEnvironment,
}
#[cfg(feature = "manifold")]
#[derive(Debug)]
pub struct OptimizerStepResult {
pub params: Vec<f32>,
pub on_manifold: bool,
pub attestation: Option<ProofAttestation>,
}
#[cfg(feature = "manifold")]
impl RiemannianAdamOptimizer {
pub fn new(dim: usize, manifold: ManifoldType) -> Self {
Self {
lr: 0.001,
beta1: 0.9,
beta2: 0.999,
adam_eps: 1e-8,
m: vec![0.0; dim],
v: vec![0.0; dim],
t: 0,
manifold,
env: ProofEnvironment::new(),
}
}
pub fn with_params(
dim: usize,
manifold: ManifoldType,
lr: f32,
beta1: f32,
beta2: f32,
) -> Self {
Self {
lr,
beta1,
beta2,
adam_eps: 1e-8,
m: vec![0.0; dim],
v: vec![0.0; dim],
t: 0,
manifold,
env: ProofEnvironment::new(),
}
}
pub fn step(
&mut self,
params: &[f32],
euclidean_grad: &[f32],
) -> Result<OptimizerStepResult> {
if params.len() != euclidean_grad.len() || params.len() != self.m.len() {
return Err(GraphTransformerError::DimensionMismatch {
expected: self.m.len(),
actual: params.len(),
});
}
self.t += 1;
let dim = params.len();
let riemannian_grad = match &self.manifold {
ManifoldType::PoincareBall { curvature } => {
let c = curvature.abs().max(EPS);
let norm_sq_p = norm_sq(params);
let factor = (1.0 - c * norm_sq_p).max(EPS);
let scale = factor * factor / 4.0;
euclidean_grad.iter().map(|&g| scale * g).collect::<Vec<f32>>()
}
ManifoldType::Sphere => {
let dp = dot(euclidean_grad, params);
euclidean_grad.iter().zip(params.iter())
.map(|(&g, &p)| g - dp * p)
.collect::<Vec<f32>>()
}
_ => euclidean_grad.to_vec(),
};
for i in 0..dim {
self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * riemannian_grad[i];
self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * riemannian_grad[i] * riemannian_grad[i];
}
let bc1 = 1.0 - self.beta1.powi(self.t as i32);
let bc2 = 1.0 - self.beta2.powi(self.t as i32);
let update: Vec<f32> = (0..dim)
.map(|i| {
let m_hat = self.m[i] / bc1;
let v_hat = self.v[i] / bc2;
-self.lr * m_hat / (v_hat.sqrt() + self.adam_eps)
})
.collect();
let new_params = match &self.manifold {
ManifoldType::PoincareBall { curvature } => {
let c = curvature.abs().max(EPS);
let exp = poincare_exp_map(&update, params, c);
poincare_project(&exp, c)
}
ManifoldType::Sphere => {
let exp = sphere_exp_map(&update, params);
project_to_sphere(&exp)
}
_ => {
params.iter().zip(update.iter()).map(|(&p, &u)| p + u).collect()
}
};
let on_manifold = self.check_on_manifold(&new_params);
let attestation = if on_manifold {
let dim_tag = dim as u32;
let proof_id = prove_dim_eq(&mut self.env, dim_tag, dim_tag)?;
Some(create_attestation(&self.env, proof_id))
} else {
None
};
Ok(OptimizerStepResult {
params: new_params,
on_manifold,
attestation,
})
}
fn check_on_manifold(&self, params: &[f32]) -> bool {
match &self.manifold {
ManifoldType::PoincareBall { curvature } => {
let c = curvature.abs().max(EPS);
norm_sq(params) < 1.0 / c
}
ManifoldType::Sphere => {
(norm(params) - 1.0).abs() < 0.01
}
_ => true,
}
}
}
#[cfg(feature = "manifold")]
#[derive(Debug, Clone, PartialEq)]
pub enum LieGroupType {
SO3,
SE3,
U1,
}
#[cfg(feature = "manifold")]
pub struct LieGroupEquivariantAttention {
group: LieGroupType,
scalar_dim: usize,
vector_dim: usize,
total_dim: usize,
_env: ProofEnvironment,
}
#[cfg(feature = "manifold")]
#[derive(Debug)]
pub struct EquivariantAttentionResult {
pub output: Vec<f32>,
pub scalar_output: Vec<f32>,
pub vector_output: Vec<f32>,
}
#[cfg(feature = "manifold")]
impl LieGroupEquivariantAttention {
pub fn new(group: LieGroupType, scalar_dim: usize, vector_dim: usize) -> Self {
Self {
group,
scalar_dim,
vector_dim,
total_dim: scalar_dim + vector_dim,
_env: ProofEnvironment::new(),
}
}
pub fn compute(
&self,
query: &[f32],
keys: &[Vec<f32>],
values: &[Vec<f32>],
) -> Result<EquivariantAttentionResult> {
if query.len() != self.total_dim {
return Err(GraphTransformerError::DimensionMismatch {
expected: self.total_dim,
actual: query.len(),
});
}
let sd = self.scalar_dim;
let q_scalar = &query[..sd];
let _q_vector = &query[sd..];
let k_scalars: Vec<&[f32]> = keys.iter().map(|k| &k[..sd]).collect();
let v_scalars: Vec<&[f32]> = values.iter().map(|v| &v[..sd]).collect();
let v_vectors: Vec<&[f32]> = values.iter().map(|v| &v[sd..]).collect();
let weights = self.compute_invariant_weights(q_scalar, &k_scalars);
let scalar_out = weighted_sum(&weights, &v_scalars, sd);
let vec_dim = self.vector_dim;
let vector_out = weighted_sum(&weights, &v_vectors, vec_dim);
let mut output = Vec::with_capacity(self.total_dim);
output.extend_from_slice(&scalar_out);
output.extend_from_slice(&vector_out);
Ok(EquivariantAttentionResult {
output,
scalar_output: scalar_out,
vector_output: vector_out,
})
}
fn compute_invariant_weights(&self, query: &[f32], keys: &[&[f32]]) -> Vec<f32> {
if keys.is_empty() {
return vec![];
}
let scale = (self.scalar_dim as f32).sqrt();
let scores: Vec<f32> = keys.iter()
.map(|k| dot(query, k) / scale)
.collect();
softmax(&scores)
}
pub fn group(&self) -> &LieGroupType {
&self.group
}
pub fn total_dim(&self) -> usize {
self.total_dim
}
}
#[cfg(feature = "manifold")]
fn project_to_sphere(v: &[f32]) -> Vec<f32> {
let n = norm(v);
if n < EPS {
let mut result = vec![0.0; v.len()];
if !result.is_empty() {
result[0] = 1.0;
}
return result;
}
v.iter().map(|&x| x / n).collect()
}
#[cfg(feature = "manifold")]
#[inline]
fn sigmoid(x: f32) -> f32 {
1.0 / (1.0 + (-x).exp())
}
#[cfg(feature = "manifold")]
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(&x, &y)| (x - y).powi(2)).sum::<f32>().sqrt()
}
#[cfg(feature = "manifold")]
fn compute_centroid(points: &[&[f32]]) -> Vec<f32> {
if points.is_empty() {
return vec![];
}
let dim = points[0].len();
let n = points.len() as f32;
let mut centroid = vec![0.0f32; dim];
for p in points {
for (i, &val) in p.iter().enumerate() {
centroid[i] += val;
}
}
for c in &mut centroid {
*c /= n;
}
centroid
}
#[cfg(feature = "manifold")]
fn euclidean_mean(vecs: &[Vec<f32>]) -> Vec<f32> {
if vecs.is_empty() {
return vec![];
}
let dim = vecs[0].len();
let n = vecs.len() as f32;
let mut mean = vec![0.0f32; dim];
for v in vecs {
for (i, &val) in v.iter().enumerate() {
mean[i] += val;
}
}
for m in &mut mean {
*m /= n;
}
mean
}
#[cfg(feature = "manifold")]
fn spherical_frechet_mean(points: &[Vec<f32>], max_iter: usize, tol: f32) -> Vec<f32> {
if points.is_empty() {
return vec![];
}
if points.len() == 1 {
return project_to_sphere(&points[0]);
}
let dim = points[0].len();
let lr = 0.1;
let mut mean = project_to_sphere(&euclidean_mean(points));
for _ in 0..max_iter {
let mut grad = vec![0.0f32; dim];
for p in points {
let p_proj = project_to_sphere(p);
let log = sphere_log_map(&p_proj, &mean);
for (i, &val) in log.iter().enumerate() {
grad[i] += val;
}
}
let grad_norm = norm(&grad);
if grad_norm < tol {
break;
}
let step: Vec<f32> = grad.iter().map(|&g| lr * g / points.len() as f32).collect();
mean = sphere_exp_map(&step, &mean);
mean = project_to_sphere(&mean);
}
mean
}
#[cfg(feature = "manifold")]
fn sphere_log_map(q: &[f32], p: &[f32]) -> Vec<f32> {
let d = dot(p, q).clamp(-1.0, 1.0);
let angle = d.acos();
if angle.abs() < EPS {
return vec![0.0; p.len()];
}
let mut v: Vec<f32> = q.iter().zip(p.iter()).map(|(&qi, &pi)| qi - d * pi).collect();
let v_norm = norm(&v);
if v_norm < EPS {
return vec![0.0; p.len()];
}
for vi in &mut v {
*vi = *vi * angle / v_norm;
}
v
}
#[cfg(feature = "manifold")]
fn sphere_exp_map(v: &[f32], p: &[f32]) -> Vec<f32> {
let v_norm = norm(v);
if v_norm < EPS {
return p.to_vec();
}
let cos_t = v_norm.cos();
let sin_t = v_norm.sin();
p.iter().zip(v.iter())
.map(|(&pi, &vi)| cos_t * pi + sin_t * vi / v_norm)
.collect()
}
#[cfg(feature = "manifold")]
fn mobius_add_internal(u: &[f32], v: &[f32], c: f32) -> Vec<f32> {
let c = c.abs().max(EPS);
let norm_u_sq = norm_sq(u);
let norm_v_sq = norm_sq(v);
let dot_uv: f32 = dot(u, v);
let coef_u = 1.0 + 2.0 * c * dot_uv + c * norm_v_sq;
let coef_v = 1.0 - c * norm_u_sq;
let denom = 1.0 + 2.0 * c * dot_uv + c * c * norm_u_sq * norm_v_sq;
let result: Vec<f32> = u.iter().zip(v.iter())
.map(|(&ui, &vi)| (coef_u * ui + coef_v * vi) / denom.max(EPS))
.collect();
poincare_project(&result, c)
}
#[cfg(feature = "manifold")]
#[inline]
fn conformal_factor(x: &[f32], c: f32) -> f32 {
2.0 / (1.0 - c * norm_sq(x)).max(EPS)
}
#[cfg(feature = "manifold")]
fn poincare_exp_map(v: &[f32], p: &[f32], c: f32) -> Vec<f32> {
let sqrt_c = c.sqrt();
let norm_p_sq = norm_sq(p);
let lambda_p = 2.0 / (1.0 - c * norm_p_sq).max(EPS);
let v_norm = norm(v);
if v_norm < EPS {
return p.to_vec();
}
let arg = (sqrt_c * lambda_p * v_norm / 2.0).tanh();
let coef = arg / (sqrt_c * v_norm);
let transported: Vec<f32> = v.iter().map(|&vi| coef * vi).collect();
mobius_add_internal(p, &transported, c)
}
#[cfg(feature = "manifold")]
fn poincare_project(x: &[f32], c: f32) -> Vec<f32> {
let c = c.abs().max(EPS);
let max_norm = (1.0 / c).sqrt() - EPS;
let x_norm = norm(x);
if x_norm <= max_norm {
x.to_vec()
} else {
let scale = max_norm / x_norm.max(EPS);
x.iter().map(|&xi| scale * xi).collect()
}
}
#[cfg(feature = "manifold")]
fn weighted_sum(weights: &[f32], vecs: &[&[f32]], dim: usize) -> Vec<f32> {
let mut result = vec![0.0f32; dim];
for (&w, v) in weights.iter().zip(vecs.iter()) {
for (i, &val) in v.iter().enumerate() {
if i < dim {
result[i] += w * val;
}
}
}
result
}
#[cfg(feature = "manifold")]
fn softmax(scores: &[f32]) -> Vec<f32> {
if scores.is_empty() {
return vec![];
}
let max_s = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp: Vec<f32> = scores.iter().map(|&s| (s - max_s).exp()).collect();
let sum: f32 = exp.iter().sum();
if sum < EPS {
vec![1.0 / scores.len() as f32; scores.len()]
} else {
exp.iter().map(|&e| e / sum).collect()
}
}
#[cfg(feature = "manifold")]
pub fn spherical_geodesic(a: &[f32], b: &[f32]) -> f32 {
let d: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
d.clamp(-1.0, 1.0).acos()
}
#[cfg(feature = "manifold")]
pub fn hyperbolic_geodesic(a: &[f32], b: &[f32], curvature: f32) -> f32 {
let c = curvature.abs();
let diff_sq: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| (x - y).powi(2)).sum();
let norm_a_sq: f32 = a.iter().map(|&x| x * x).sum();
let norm_b_sq: f32 = b.iter().map(|&x| x * x).sum();
let denom = (1.0 - c * norm_a_sq) * (1.0 - c * norm_b_sq);
if denom.abs() < 1e-8 {
return f32::INFINITY;
}
let arg = 1.0 + 2.0 * c * diff_sq / denom;
(1.0 / c.sqrt()) * arg.max(1.0).acosh()
}
#[cfg(test)]
#[cfg(feature = "manifold")]
mod tests {
use super::*;
#[test]
fn test_product_manifold_attention_forward_4node() {
let config = ManifoldConfig {
spherical_dim: 4,
hyperbolic_dim: 4,
euclidean_dim: 4,
curvature: -1.0,
};
let mut attn = ProductManifoldAttention::new(config);
assert_eq!(attn.total_dim(), 12);
let query = vec![0.5; 12];
let keys = vec![
vec![0.3; 12],
vec![0.7; 12],
vec![0.1; 12],
];
let values = vec![
vec![1.0; 12],
vec![2.0; 12],
vec![0.5; 12],
];
let result = attn.compute(&query, &keys, &values);
assert!(result.is_ok(), "compute failed: {:?}", result.err());
let result = result.unwrap();
assert_eq!(result.output.len(), 12);
assert!(result.curvatures.spherical > 0.0);
assert!(result.curvatures.hyperbolic < 0.0);
assert!((result.curvatures.euclidean).abs() < 1e-6);
assert!(result.attestation.is_some());
}
#[test]
fn test_product_manifold_dimension_mismatch() {
let config = ManifoldConfig {
spherical_dim: 4,
hyperbolic_dim: 4,
euclidean_dim: 4,
curvature: -1.0,
};
let mut attn = ProductManifoldAttention::new(config);
let query = vec![0.5; 8]; let result = attn.compute(&query, &[], &[]);
assert!(result.is_err());
}
#[test]
fn test_product_manifold_with_betas() {
let config = ManifoldConfig {
spherical_dim: 4,
hyperbolic_dim: 4,
euclidean_dim: 4,
curvature: -1.0,
};
let mut attn = ProductManifoldAttention::with_betas(config, 2.0, 1.0, 0.5);
let query = vec![0.3; 12];
let keys = vec![vec![0.4; 12], vec![0.6; 12]];
let values = vec![vec![1.0; 12], vec![2.0; 12]];
let result = attn.compute(&query, &keys, &values).unwrap();
assert_eq!(result.output.len(), 12);
}
#[test]
fn test_product_manifold_type() {
let config = ManifoldConfig {
spherical_dim: 4,
hyperbolic_dim: 4,
euclidean_dim: 4,
curvature: -1.0,
};
let attn = ProductManifoldAttention::new(config);
let mt = attn.manifold_type();
match mt {
ManifoldType::Product(components) => {
assert_eq!(components.len(), 3);
assert_eq!(components[0], ManifoldType::Sphere);
}
_ => panic!("expected Product manifold type"),
}
}
#[test]
fn test_router_negative_curvature_routes_hyperbolic() {
let router = CurvatureAdaptiveRouter::new();
let weights = router.route(-0.5);
assert!(
weights.hyperbolic > weights.spherical,
"hyperbolic={} should exceed spherical={} for kappa=-0.5",
weights.hyperbolic,
weights.spherical,
);
assert!(
weights.hyperbolic > weights.euclidean,
"hyperbolic={} should exceed euclidean={} for kappa=-0.5",
weights.hyperbolic,
weights.euclidean,
);
}
#[test]
fn test_router_positive_curvature_routes_spherical() {
let router = CurvatureAdaptiveRouter::new();
let weights = router.route(0.5);
assert!(
weights.spherical > weights.hyperbolic,
"spherical={} should exceed hyperbolic={} for kappa=0.5",
weights.spherical,
weights.hyperbolic,
);
assert!(
weights.spherical > weights.euclidean,
"spherical={} should exceed euclidean={} for kappa=0.5",
weights.spherical,
weights.euclidean,
);
}
#[test]
fn test_router_zero_curvature_routes_euclidean() {
let router = CurvatureAdaptiveRouter::new();
let weights = router.route(0.0);
assert!(
weights.euclidean > weights.hyperbolic,
"euclidean={} should exceed hyperbolic={} for kappa=0.0",
weights.euclidean,
weights.hyperbolic,
);
assert!(
weights.euclidean > weights.spherical,
"euclidean={} should exceed spherical={} for kappa=0.0",
weights.euclidean,
weights.spherical,
);
}
#[test]
fn test_router_weights_sum_to_one() {
let router = CurvatureAdaptiveRouter::new();
for kappa in [-2.0, -0.5, -0.1, 0.0, 0.1, 0.5, 2.0] {
let w = router.route(kappa);
let sum = w.spherical + w.hyperbolic + w.euclidean;
assert!(
(sum - 1.0).abs() < 1e-5,
"weights for kappa={} sum to {} (should be 1.0)",
kappa,
sum,
);
}
}
#[test]
fn test_router_batch() {
let router = CurvatureAdaptiveRouter::new();
let curvatures = vec![-1.0, 0.0, 1.0];
let results = router.route_batch(&curvatures);
assert_eq!(results.len(), 3);
}
#[test]
fn test_router_ollivier_ricci_estimate() {
let router = CurvatureAdaptiveRouter::new();
let a = vec![0.0, 0.0];
let b = vec![1.0, 0.0];
let neighbors_a: Vec<&[f32]> = vec![&[0.1, 0.1], &[-0.1, 0.1]];
let neighbors_b: Vec<&[f32]> = vec![&[0.9, 0.1], &[1.1, -0.1]];
let kappa = router.estimate_ollivier_ricci(&a, &b, &neighbors_a, &neighbors_b);
assert!(kappa.is_finite());
}
#[test]
fn test_geodesic_message_passing_poincare() {
let manifold = ManifoldType::PoincareBall { curvature: 1.0 };
let mut gmp = GeodesicMessagePassing::new(manifold);
let features = vec![
vec![0.1, 0.2],
vec![0.3, 0.1],
vec![-0.1, 0.3],
];
let edges = vec![(0, 1), (1, 2), (0, 2)];
let result = gmp.propagate(&features, &edges);
assert!(result.is_ok(), "propagate failed: {:?}", result.err());
let result = result.unwrap();
assert_eq!(result.node_messages.len(), 3);
for msg in &result.node_messages {
assert_eq!(msg.len(), 2);
}
}
#[test]
fn test_geodesic_transport_norm_preservation() {
let manifold = ManifoldType::PoincareBall { curvature: 1.0 };
let gmp = GeodesicMessagePassing::new(manifold);
let v = vec![0.1, 0.05];
let from = vec![0.2, 0.1];
let to = vec![0.3, -0.1];
let transported = gmp.parallel_transport_poincare(&v, &from, &to, 1.0);
let orig_norm = norm(&v);
let trans_norm = norm(&transported);
assert!(
(trans_norm / orig_norm - 1.0).abs() < 0.5,
"norm ratio {}/{} = {} deviates too far from 1.0",
trans_norm,
orig_norm,
trans_norm / orig_norm,
);
}
#[test]
fn test_geodesic_message_passing_sphere() {
let manifold = ManifoldType::Sphere;
let mut gmp = GeodesicMessagePassing::new(manifold);
let features = vec![
vec![1.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0],
vec![0.0, 0.0, 1.0],
];
let edges = vec![(0, 1), (1, 2)];
let result = gmp.propagate(&features, &edges).unwrap();
assert_eq!(result.node_messages.len(), 3);
}
#[test]
fn test_geodesic_message_passing_euclidean() {
let manifold = ManifoldType::Lorentz { curvature: 1.0 }; let mut gmp = GeodesicMessagePassing::new(manifold);
let features = vec![
vec![1.0, 2.0],
vec![3.0, 4.0],
];
let edges = vec![(0, 1)];
let result = gmp.propagate(&features, &edges).unwrap();
assert_eq!(result.node_messages.len(), 2);
}
#[test]
fn test_riemannian_adam_poincare_stays_on_manifold() {
let manifold = ManifoldType::PoincareBall { curvature: 1.0 };
let mut opt = RiemannianAdamOptimizer::new(3, manifold);
let mut params = vec![0.1, 0.2, -0.1];
let grad = vec![0.5, -0.3, 0.1];
for _ in 0..10 {
let result = opt.step(¶ms, &grad).unwrap();
params = result.params.clone();
let nsq = norm_sq(¶ms);
assert!(
nsq < 1.0,
"params norm^2 = {} >= 1.0, left the Poincare ball",
nsq,
);
assert!(result.on_manifold);
assert!(result.attestation.is_some());
}
}
#[test]
fn test_riemannian_adam_sphere_stays_on_manifold() {
let manifold = ManifoldType::Sphere;
let mut opt = RiemannianAdamOptimizer::new(3, manifold);
let mut params = project_to_sphere(&[0.5, 0.5, 0.5]);
let grad = vec![0.1, -0.2, 0.05];
for _ in 0..10 {
let result = opt.step(¶ms, &grad).unwrap();
params = result.params.clone();
let n = norm(¶ms);
assert!(
(n - 1.0).abs() < 0.02,
"params norm = {} deviates from 1.0",
n,
);
assert!(result.on_manifold);
}
}
#[test]
fn test_riemannian_adam_euclidean() {
let manifold = ManifoldType::Lorentz { curvature: 1.0 };
let mut opt = RiemannianAdamOptimizer::new(2, manifold);
let params = vec![1.0, 2.0];
let grad = vec![0.1, 0.2];
let result = opt.step(¶ms, &grad).unwrap();
assert_eq!(result.params.len(), 2);
assert!(result.on_manifold); }
#[test]
fn test_riemannian_adam_dimension_mismatch() {
let manifold = ManifoldType::PoincareBall { curvature: 1.0 };
let mut opt = RiemannianAdamOptimizer::new(3, manifold);
let params = vec![0.1, 0.2]; let grad = vec![0.1, 0.2];
let result = opt.step(¶ms, &grad);
assert!(result.is_err());
}
#[test]
fn test_lie_group_equivariant_forward_so3() {
let attn = LieGroupEquivariantAttention::new(LieGroupType::SO3, 4, 6);
assert_eq!(attn.total_dim(), 10);
assert_eq!(*attn.group(), LieGroupType::SO3);
let query = vec![0.5; 10];
let keys = vec![vec![0.3; 10], vec![0.7; 10]];
let values = vec![vec![1.0; 10], vec![2.0; 10]];
let result = attn.compute(&query, &keys, &values);
assert!(result.is_ok(), "compute failed: {:?}", result.err());
let result = result.unwrap();
assert_eq!(result.output.len(), 10);
assert_eq!(result.scalar_output.len(), 4);
assert_eq!(result.vector_output.len(), 6);
}
#[test]
fn test_lie_group_equivariant_forward_se3() {
let attn = LieGroupEquivariantAttention::new(LieGroupType::SE3, 8, 12);
let query = vec![0.2; 20];
let keys = vec![vec![0.4; 20], vec![0.6; 20], vec![0.1; 20]];
let values = vec![vec![1.0; 20], vec![2.0; 20], vec![0.5; 20]];
let result = attn.compute(&query, &keys, &values).unwrap();
assert_eq!(result.output.len(), 20);
}
#[test]
fn test_lie_group_equivariant_forward_u1() {
let attn = LieGroupEquivariantAttention::new(LieGroupType::U1, 3, 3);
let query = vec![0.5; 6];
let keys = vec![vec![0.3; 6]];
let values = vec![vec![1.0; 6]];
let result = attn.compute(&query, &keys, &values).unwrap();
assert_eq!(result.output.len(), 6);
}
#[test]
fn test_lie_group_equivariant_dimension_mismatch() {
let attn = LieGroupEquivariantAttention::new(LieGroupType::SO3, 4, 6);
let query = vec![0.5; 5]; let result = attn.compute(&query, &[], &[]);
assert!(result.is_err());
}
#[test]
fn test_manifold_type_enum() {
let pb = ManifoldType::PoincareBall { curvature: 1.0 };
let lr = ManifoldType::Lorentz { curvature: 2.0 };
let sp = ManifoldType::Sphere;
let pr = ManifoldType::Product(vec![pb.clone(), sp.clone()]);
assert_eq!(pb, ManifoldType::PoincareBall { curvature: 1.0 });
assert_ne!(pb, lr);
match pr {
ManifoldType::Product(components) => assert_eq!(components.len(), 2),
_ => panic!("expected Product"),
}
}
#[test]
fn test_spherical_projection() {
let v = vec![3.0, 4.0];
let proj = project_to_sphere(&v);
let n: f32 = proj.iter().map(|&x| x * x).sum::<f32>().sqrt();
assert!((n - 1.0).abs() < 1e-5);
}
#[test]
fn test_spherical_geodesic() {
let a = vec![1.0, 0.0];
let b = vec![0.0, 1.0];
let dist = spherical_geodesic(&a, &b);
assert!((dist - std::f32::consts::FRAC_PI_2).abs() < 1e-5);
}
#[test]
fn test_spherical_geodesic_same_point() {
let a = vec![1.0, 0.0, 0.0];
let dist = spherical_geodesic(&a, &a);
assert!(dist.abs() < 1e-5);
}
#[test]
fn test_sigmoid_bounds() {
assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
assert!(sigmoid(10.0) > 0.99);
assert!(sigmoid(-10.0) < 0.01);
}
#[test]
fn test_poincare_project_inside() {
let x = vec![0.1, 0.2];
let proj = poincare_project(&x, 1.0);
assert_eq!(proj, x); }
#[test]
fn test_poincare_project_outside() {
let x = vec![0.8, 0.8]; let proj = poincare_project(&x, 1.0);
let nsq = norm_sq(&proj);
assert!(nsq < 1.0, "projected norm^2 = {} should be < 1.0", nsq);
}
}