#![allow(missing_docs)]
use crate::BoxError;
use serde::{Deserialize, Serialize};
/// A spherical cap on the unit sphere in `R^d`: the set of unit vectors whose geodesic
/// (angular) distance from `center` is at most `angular_radius`.
///
/// [`SphericalCap::new`] normalizes `center` to unit length and requires `angular_radius`
/// in `(0, PI]`.
/// NOTE(review): the derived `Deserialize` fills the fields directly and does not run the
/// checks in `new`, so deserialized values are not guaranteed to satisfy them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SphericalCap {
    /// Center direction; unit length when built through `new`.
    center: Vec<f32>,
    /// Half-angle of the cap in radians.
    angular_radius: f32,
}
/// A relation between caps: it moves a cap's center by `angle` radians using `axis`, and it
/// multiplies the cap's angular radius by `angle_scale` (see `SphericalCapRelation::apply`).
///
/// NOTE(review): the derived `Deserialize` bypasses the validation done in
/// `SphericalCapRelation::new` (unit axis, finite angle, positive finite scale).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SphericalCapRelation {
    /// Direction used by the center transformation; unit length when built through `new`.
    axis: Vec<f32>,
    /// Rotation angle in radians.
    angle: f32,
    /// Positive multiplier applied to a cap's angular radius.
    angle_scale: f32,
}
/// A collection of entity caps and relations that share one ambient dimension.
///
/// `SphericalCapEmbedding::new` checks that every entity and relation has dimension `dim`.
/// The derived `Deserialize` does not perform that check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SphericalCapEmbedding {
    /// Entity caps.
    entities: Vec<SphericalCap>,
    /// Relations applicable to the entity caps.
    relations: Vec<SphericalCapRelation>,
    /// Shared ambient dimension of all entities and relations.
    dim: usize,
}
impl SphericalCap {
    /// Creates a cap from a center direction and a half-angle (`angular_radius`, radians).
    ///
    /// `center` need not be normalized; it is rescaled to unit length. The rescaling divides
    /// by the largest absolute component before squaring. A naive sum of squares can overflow
    /// to infinity for large finite components and collapse the center to all zeros.
    ///
    /// # Errors
    ///
    /// Returns [`BoxError::InvalidBounds`] when:
    /// * `angular_radius` is not finite or lies outside `(0, PI]`
    ///   (`dim: 0`, `min: 0.0`, `max: angular_radius`);
    /// * a component of `center` is not finite (`dim` is its index, `min`/`max` its value);
    /// * `center` has Euclidean norm below `1e-12`, including an empty `center`
    ///   (all fields zero).
    pub fn new(center: Vec<f32>, angular_radius: f32) -> Result<Self, BoxError> {
        if !angular_radius.is_finite()
            || angular_radius <= 0.0
            || angular_radius > std::f32::consts::PI
        {
            return Err(BoxError::InvalidBounds {
                dim: 0,
                min: 0.0,
                max: angular_radius as f64,
            });
        }
        for (i, &c) in center.iter().enumerate() {
            if !c.is_finite() {
                return Err(BoxError::InvalidBounds {
                    dim: i,
                    min: c as f64,
                    max: c as f64,
                });
            }
        }
        let center = Self::unit_direction(&center).ok_or(BoxError::InvalidBounds {
            dim: 0,
            min: 0.0,
            max: 0.0,
        })?;
        Ok(Self {
            center,
            angular_radius,
        })
    }

    /// Creates a cap whose half-angle θ satisfies `ln(tan(θ / 2)) = log_tan_half`.
    ///
    /// This is an unconstrained parameterization: `θ = 2·atan(exp(x))` maps every real `x`
    /// into `(0, π)`.
    ///
    /// # Errors
    ///
    /// Returns [`BoxError::InvalidBounds`] if `log_tan_half` is not finite. Otherwise it
    /// returns any error from [`SphericalCap::new`], e.g. when `log_tan_half` is so negative
    /// that θ rounds to zero in `f32`.
    pub fn from_log_tan_half(center: Vec<f32>, log_tan_half: f32) -> Result<Self, BoxError> {
        if !log_tan_half.is_finite() {
            return Err(BoxError::InvalidBounds {
                dim: 0,
                min: log_tan_half as f64,
                max: log_tan_half as f64,
            });
        }
        let theta = 2.0 * log_tan_half.exp().atan();
        Self::new(center, theta)
    }

    /// Ambient dimension `d`, i.e. the length of the center vector.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.center.len()
    }

    /// Center direction, normalized to unit length by `new`.
    pub fn center(&self) -> &[f32] {
        &self.center
    }

    /// Half-angle of the cap in radians.
    #[must_use]
    pub fn angular_radius(&self) -> f32 {
        self.angular_radius
    }

    /// `ln(tan(θ / 2))` for this cap's half-angle θ; the inverse of `from_log_tan_half`.
    #[must_use]
    pub fn log_tan_half(&self) -> f32 {
        // `f32::consts::PI` is slightly larger than π. In f32, `tan(PI / 2)` is therefore a
        // huge *negative* number and its `ln` is NaN. Evaluate in f64 with θ capped at π.
        let half = f64::from(self.angular_radius).min(std::f64::consts::PI) / 2.0;
        half.tan().ln() as f32
    }

    /// Fraction of the unit sphere's surface (sphere in `R^d`, `d = self.dim()`) that is
    /// covered by this cap. The result is in `[0, 1]`.
    ///
    /// For `d >= 2` the fraction is `∫_0^θ sin^{d-2}(t) dt / ∫_0^π sin^{d-2}(t) dt`:
    /// * `θ / π` for `d = 2`;
    /// * `(1 - cos θ) / 2` for `d = 3`;
    /// * `0.5` for a hemisphere (θ = π/2) in every dimension;
    /// * monotonically increasing in θ.
    ///
    /// For `d = 1` the "sphere" is the two points `±1`. The cap then covers one of them
    /// (`0.5`) unless it reaches the antipode at θ = `PI`.
    #[must_use]
    pub fn area_fraction(&self) -> f32 {
        match self.center.len() {
            // Not constructible through `new` (an empty center has zero norm).
            0 => 0.0,
            1 => {
                if self.angular_radius >= std::f32::consts::PI {
                    1.0
                } else {
                    0.5
                }
            }
            d => Self::cap_area_fraction(f64::from(self.angular_radius), d) as f32,
        }
    }

    /// Mutable access to the center vector.
    ///
    /// Writes through this slice are not re-normalized. Callers that change the direction
    /// should restore unit length themselves.
    pub fn center_mut(&mut self) -> &mut [f32] {
        &mut self.center
    }

    /// Sets the half-angle from its `ln(tan(θ / 2))` parameterization.
    ///
    /// The input is clamped to `[-10, 10]` and the resulting θ to `[0.001, PI - 0.001]`, so
    /// the cap stays valid. `±∞` is clamped like any other out-of-range value. A NaN input
    /// (e.g. from a diverged optimizer step) leaves the radius unchanged rather than storing
    /// NaN.
    pub fn set_log_tan_half(&mut self, log_tan_half: f32) {
        if log_tan_half.is_nan() {
            return;
        }
        let clamped = log_tan_half.clamp(-10.0, 10.0);
        self.angular_radius =
            (2.0 * clamped.exp().atan()).clamp(0.001, std::f32::consts::PI - 0.001);
    }

    /// Returns `v / ‖v‖`, or `None` when `‖v‖ < 1e-12` (including empty `v`).
    ///
    /// `v` is first divided by its largest absolute component. The sum of squares then lies
    /// in `[1, v.len()]` and cannot overflow. Expects finite components.
    fn unit_direction(v: &[f32]) -> Option<Vec<f32>> {
        let max_abs = v.iter().fold(0.0f32, |acc, x| acc.max(x.abs()));
        if max_abs == 0.0 {
            return None;
        }
        let scaled_norm = v
            .iter()
            .map(|x| {
                let y = x / max_abs;
                y * y
            })
            .sum::<f32>()
            .sqrt();
        // Treat directions with norm below 1e-12 as degenerate.
        if max_abs * scaled_norm < 1e-12 {
            return None;
        }
        Some(v.iter().map(|x| x / max_abs / scaled_norm).collect())
    }

    /// Evaluates `I_n(θ) / I_n(π)` where `I_n(θ) = ∫_0^θ sin^n(t) dt` and `n = dim - 2`.
    ///
    /// Uses the reduction formula `I_k = ((k - 1)·I_{k-2} - sin^{k-1}θ·cos θ) / k` and the
    /// matching `I_k(π) = (k - 1)/k · I_{k-2}(π)`. It starts from `I_0 = θ, I_0(π) = π` or
    /// `I_1 = 1 - cos θ, I_1(π) = 2` and costs `O(dim)`.
    ///
    /// The error is absolute rather than relative: extremely small fractions may come back as
    /// tiny rounding residue or as 0. Requires `dim >= 2`.
    fn cap_area_fraction(theta: f64, dim: usize) -> f64 {
        debug_assert!(dim >= 2);
        // Cap at π so that sin θ cannot turn negative for θ = f32::consts::PI.
        let theta = theta.clamp(0.0, std::f64::consts::PI);
        let n = dim - 2;
        let (sin_t, cos_t) = theta.sin_cos();
        // Tuple fields: k, I_k(θ), I_k(π), and sin^{k+1}θ (the power needed at step k + 2).
        let (mut k, mut partial, mut full, mut sin_pow) = if n % 2 == 0 {
            (0, theta, std::f64::consts::PI, sin_t)
        } else {
            (1, 1.0 - cos_t, 2.0, sin_t * sin_t)
        };
        while k < n {
            k += 2;
            let kf = k as f64;
            partial = ((kf - 1.0) * partial - sin_pow * cos_t) / kf;
            full *= (kf - 1.0) / kf;
            sin_pow *= sin_t * sin_t;
        }
        (partial / full).clamp(0.0, 1.0)
    }
}
impl SphericalCapRelation {
    /// Creates a relation from an axis direction, a rotation angle (radians), and a positive
    /// multiplier for the angular radius.
    ///
    /// `axis` need not be normalized; it is rescaled to unit length. The rescaling divides by
    /// the largest absolute component first, so large finite components cannot overflow the
    /// squared norm and collapse the axis to all zeros.
    ///
    /// # Errors
    ///
    /// Returns [`BoxError::InvalidBounds`] when:
    /// * `angle_scale` is not finite or is `<= 0` (`dim: 0`, `min: 0.0`, `max: angle_scale`);
    /// * a component of `axis` is not finite (`dim` is its index, `min`/`max` its value);
    /// * `angle` is not finite (`dim: 0`, `min`/`max`: `angle`);
    /// * `axis` has Euclidean norm below `1e-12`, including an empty `axis`
    ///   (all fields zero).
    pub fn new(axis: Vec<f32>, angle: f32, angle_scale: f32) -> Result<Self, BoxError> {
        if !angle_scale.is_finite() || angle_scale <= 0.0 {
            return Err(BoxError::InvalidBounds {
                dim: 0,
                min: 0.0,
                max: angle_scale as f64,
            });
        }
        for (i, &a) in axis.iter().enumerate() {
            if !a.is_finite() {
                return Err(BoxError::InvalidBounds {
                    dim: i,
                    min: a as f64,
                    max: a as f64,
                });
            }
        }
        if !angle.is_finite() {
            return Err(BoxError::InvalidBounds {
                dim: 0,
                min: angle as f64,
                max: angle as f64,
            });
        }
        let axis = Self::unit_direction(&axis).ok_or(BoxError::InvalidBounds {
            dim: 0,
            min: 0.0,
            max: 0.0,
        })?;
        Ok(Self {
            axis,
            angle,
            angle_scale,
        })
    }

    /// The relation that leaves caps unchanged: angle `0`, scale `1`, axis `e_0`.
    ///
    /// For `dim == 0` the axis is empty.
    #[must_use]
    pub fn identity(dim: usize) -> Self {
        let mut axis = vec![0.0; dim];
        if dim > 0 {
            axis[0] = 1.0;
        }
        Self {
            axis,
            angle: 0.0,
            angle_scale: 1.0,
        }
    }

    /// Applies the relation to `cap` and returns the transformed cap.
    ///
    /// The center is moved by `rotate_vector` using this relation's axis and angle. The
    /// angular radius is multiplied by `angle_scale` and clamped to `[0.001, PI]`. The result
    /// is rebuilt through [`SphericalCap::new`], which re-normalizes the center.
    ///
    /// # Errors
    ///
    /// Returns [`BoxError::DimensionMismatch`] (`expected: cap.dim()`, `actual`: axis length)
    /// when the dimensions differ, or any error from [`SphericalCap::new`].
    pub fn apply(&self, cap: &SphericalCap) -> Result<SphericalCap, BoxError> {
        if self.axis.len() != cap.center.len() {
            return Err(BoxError::DimensionMismatch {
                expected: cap.center.len(),
                actual: self.axis.len(),
            });
        }
        let new_center = rotate_vector(&cap.center, &self.axis, self.angle);
        let new_radius = cap.angular_radius * self.angle_scale;
        let new_radius = new_radius.clamp(0.001, std::f32::consts::PI);
        SphericalCap::new(new_center, new_radius)
    }

    /// Mutable access to the axis.
    ///
    /// Writes are not re-normalized, while `rotate_vector` assumes a unit axis. Callers that
    /// change the axis should restore unit length themselves.
    pub fn axis_mut(&mut self) -> &mut [f32] {
        &mut self.axis
    }

    /// Axis direction, normalized to unit length by `new`.
    pub fn axis(&self) -> &[f32] {
        &self.axis
    }

    /// Natural logarithm of the radius multiplier `angle_scale`.
    #[must_use]
    pub fn log_scale(&self) -> f32 {
        self.angle_scale.ln()
    }

    /// Sets `angle_scale = exp(log_scale)`, with `log_scale` clamped to `[-5, 5]`.
    ///
    /// A NaN input leaves the scale unchanged, so the relation keeps a positive finite scale.
    pub fn set_log_scale(&mut self, log_scale: f32) {
        if log_scale.is_nan() {
            return;
        }
        self.angle_scale = log_scale.clamp(-5.0, 5.0).exp();
    }

    /// Returns `v / ‖v‖`, or `None` when `‖v‖ < 1e-12` (including empty `v`).
    ///
    /// `v` is first divided by its largest absolute component. The sum of squares then lies
    /// in `[1, v.len()]` and cannot overflow. Expects finite components.
    fn unit_direction(v: &[f32]) -> Option<Vec<f32>> {
        let max_abs = v.iter().fold(0.0f32, |acc, x| acc.max(x.abs()));
        if max_abs == 0.0 {
            return None;
        }
        let scaled_norm = v
            .iter()
            .map(|x| {
                let y = x / max_abs;
                y * y
            })
            .sum::<f32>()
            .sqrt();
        // Treat directions with norm below 1e-12 as degenerate.
        if max_abs * scaled_norm < 1e-12 {
            return None;
        }
        Some(v.iter().map(|x| x / max_abs / scaled_norm).collect())
    }
}
impl SphericalCapEmbedding {
pub fn new(
entities: Vec<SphericalCap>,
relations: Vec<SphericalCapRelation>,
dim: usize,
) -> Result<Self, BoxError> {
for e in &entities {
if e.dim() != dim {
return Err(BoxError::DimensionMismatch {
expected: dim,
actual: e.dim(),
});
}
}
for r in &relations {
if r.axis.len() != dim {
return Err(BoxError::DimensionMismatch {
expected: dim,
actual: r.axis.len(),
});
}
}
Ok(Self {
entities,
relations,
dim,
})
}
#[must_use]
pub fn dim(&self) -> usize {
self.dim
}
#[must_use]
pub fn num_entities(&self) -> usize {
self.entities.len()
}
#[must_use]
pub fn num_relations(&self) -> usize {
self.relations.len()
}
pub fn entities(&self) -> &[SphericalCap] {
&self.entities
}
pub fn relations(&self) -> &[SphericalCapRelation] {
&self.relations
}
}
/// Great-circle (angular) distance, in radians, between the directions of `a` and `b`.
///
/// Uses Kahan's formula `θ = 2·atan2(‖a·|b| − b·|a|‖, ‖a·|b| + b·|a|‖)`. It stays accurate
/// for nearly identical and nearly antipodal points. In contrast, `acos(a·b)` in `f32`
/// cannot resolve angles below about 3.4e-4 rad, because the largest `f32` below 1 is
/// `1 − 2^-24`.
///
/// For unit vectors the result is the usual geodesic distance on the unit sphere. For other
/// non-zero vectors it is the angle between them. The result lies in `[0, π]`. Identical
/// inputs give exactly `0.0`, as does any input that is all zeros.
///
/// `a` and `b` must have the same length (checked only in debug builds).
pub fn geodesic_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    let (mut diff_sq, mut sum_sq) = (0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b) {
        // Cross-scaling gives both vectors the common length |a|·|b| without dividing.
        let (u, v) = (x * norm_b, y * norm_a);
        diff_sq += (u - v) * (u - v);
        sum_sq += (u + v) * (u + v);
    }
    2.0 * diff_sq.sqrt().atan2(sum_sq.sqrt())
}
/// Soft score for "`inner` lies inside `outer`".
///
/// The score is `stable_sigmoid(k · margin)` with `margin = θ_outer − d(c_inner, c_outer) −
/// θ_inner`, where `d` is [`geodesic_distance`] between the centers. A non-negative margin
/// is sufficient for geometric containment (triangle inequality). Larger `k` sharpens the
/// transition around `margin = 0`.
///
/// # Errors
///
/// Returns [`BoxError::DimensionMismatch`] (`expected: inner.dim()`, `actual: outer.dim()`)
/// when the caps have different dimensions.
pub fn containment_prob(
    inner: &SphericalCap,
    outer: &SphericalCap,
    k: f32,
) -> Result<f32, BoxError> {
    let (dim_in, dim_out) = (inner.dim(), outer.dim());
    if dim_in != dim_out {
        return Err(BoxError::DimensionMismatch {
            expected: dim_in,
            actual: dim_out,
        });
    }
    let centre_gap = geodesic_distance(inner.center(), outer.center());
    let slack = outer.angular_radius() - centre_gap - inner.angular_radius();
    Ok(crate::utils::stable_sigmoid(k * slack))
}
/// Angular gap between the boundaries of two caps.
///
/// The gap is `max(0, d(c_a, c_b) − θ_a − θ_b)`, so it is `0` whenever the caps touch or
/// overlap.
///
/// # Errors
///
/// Returns [`BoxError::DimensionMismatch`] (`expected: a.dim()`, `actual: b.dim()`) when the
/// caps have different dimensions.
pub fn surface_distance(a: &SphericalCap, b: &SphericalCap) -> Result<f32, BoxError> {
    match (a.dim(), b.dim()) {
        (da, db) if da != db => Err(BoxError::DimensionMismatch {
            expected: da,
            actual: db,
        }),
        _ => {
            let gap = geodesic_distance(a.center(), b.center())
                - a.angular_radius()
                - b.angular_radius();
            Ok(gap.max(0.0))
        }
    }
}
/// Heuristic overlap score in `[0, 1]` for two caps.
///
/// The score is the penetration depth `max(0, θ_a + θ_b − d(c_a, c_b))` divided by
/// `θ_a + θ_b`. Identical caps score `1`, and caps that do not touch score `0`. Returns
/// `0.0` when the radii sum to less than `1e-12`.
///
/// # Errors
///
/// Returns [`BoxError::DimensionMismatch`] (`expected: a.dim()`, `actual: b.dim()`) when the
/// caps have different dimensions.
pub fn overlap_prob(a: &SphericalCap, b: &SphericalCap) -> Result<f32, BoxError> {
    if a.dim() != b.dim() {
        return Err(BoxError::DimensionMismatch {
            expected: a.dim(),
            actual: b.dim(),
        });
    }
    let centre_gap = geodesic_distance(a.center(), b.center());
    let reach = a.angular_radius() + b.angular_radius();
    let score = if reach < 1e-12 {
        0.0
    } else {
        (reach - centre_gap).max(0.0) / reach
    };
    Ok(score)
}
/// Distance-style score for the triple `(head, relation, tail)`; lower is better.
///
/// The relation is applied to `head`, and the result is the [`surface_distance`] from the
/// transformed head to `tail`. The score is `0` when they touch or overlap.
///
/// # Errors
///
/// Propagates errors from [`SphericalCapRelation::apply`] and [`surface_distance`]
/// (dimension mismatches or an invalid transformed cap).
pub fn score_triple(
    head: &SphericalCap,
    relation: &SphericalCapRelation,
    tail: &SphericalCap,
) -> Result<f32, BoxError> {
    relation
        .apply(head)
        .and_then(|moved_head| surface_distance(&moved_head, tail))
}
/// Moves `v` by `angle` radians using `axis`; `SphericalCapRelation::apply` calls it to
/// transform cap centers.
///
/// Preconditions:
/// * `v` and `axis` must have the same length. This is checked only in debug builds, and
///   the `d == 3` branch indexes `axis[0..3]` directly.
/// * `axis` is assumed to be unit-norm. `SphericalCapRelation::new` normalizes it, but
///   `axis_mut` can break that.
///
/// Behavior:
/// * If `v` is (numerically) parallel to `axis`, a copy of `v` is returned unchanged.
/// * For `d == 3` this is Rodrigues' rotation formula. `v` rotates about `axis`, so its
///   component along `axis` and its norm are preserved.
/// * For other `d`, `v` is split into `v_par` (along `axis`) and `v_perp`. `v_perp` is then
///   turned by `angle` in the plane spanned by `v_perp` and `w`, where `w` is `axis`
///   Gram–Schmidt-orthogonalized against `v_perp`.
///
/// NOTE(review): for a unit `axis`, `v_perp` is already orthogonal to `axis`, so
/// `w_dot ≈ 0` and `w ≈ axis`. The `d != 3` result is then
/// `v_par + |v_perp|·(cos a·v̂_perp + sin a·axis)`. Its squared norm is
/// `|v|² + 2·(v·axis)·|v_perp|·sin a`, so it is not a rotation unless `v·axis == 0` or
/// `sin a == 0`. Unlike the `d == 3` branch, it moves `v` toward `axis` instead of around
/// it. `apply` re-normalizes the output via `SphericalCap::new`. Confirm that this is the
/// intended transformation for `d != 3`.
fn rotate_vector(v: &[f32], axis: &[f32], angle: f32) -> Vec<f32> {
    debug_assert_eq!(v.len(), axis.len());
    let d = v.len();
    // Decompose v into components parallel and perpendicular to the axis.
    let dot: f32 = v.iter().zip(axis.iter()).map(|(&x, &y)| x * y).sum();
    let v_par: Vec<f32> = axis.iter().map(|&a| a * dot).collect();
    let v_perp: Vec<f32> = v
        .iter()
        .zip(v_par.iter())
        .map(|(&vi, &pi)| vi - pi)
        .collect();
    let perp_norm = v_perp.iter().map(|x| x * x).sum::<f32>().sqrt();
    // v lies (numerically) on the axis: nothing to rotate.
    if perp_norm < 1e-12 {
        return v.to_vec();
    }
    let cos_a = angle.cos();
    let sin_a = angle.sin();
    if d == 3 {
        // Rodrigues: v·cos a + (axis × v)·sin a + axis·(axis·v)·(1 − cos a).
        let cross = [
            axis[1] * v[2] - axis[2] * v[1],
            axis[2] * v[0] - axis[0] * v[2],
            axis[0] * v[1] - axis[1] * v[0],
        ];
        (0..3)
            .map(|i| v[i] * cos_a + cross[i] * sin_a + axis[i] * dot * (1.0 - cos_a))
            .collect()
    } else {
        let v_perp_hat: Vec<f32> = v_perp.iter().map(|&x| x / perp_norm).collect();
        // Gram–Schmidt: remove the v̂_perp component from axis (≈0 for a unit axis).
        let w_dot: f32 = axis
            .iter()
            .zip(v_perp_hat.iter())
            .map(|(&a, &vp)| a * vp)
            .sum();
        let mut w: Vec<f32> = axis
            .iter()
            .zip(v_perp_hat.iter())
            .map(|(&a, &vp)| a - w_dot * vp)
            .collect();
        let w_norm = w.iter().map(|x| x * x).sum::<f32>().sqrt();
        // Degenerate second direction (e.g. a non-unit / zeroed axis): leave v unchanged.
        if w_norm < 1e-12 {
            return v.to_vec();
        }
        for x in &mut w {
            *x /= w_norm;
        }
        // v_par + |v_perp|·(cos a·v̂_perp + sin a·ŵ)
        v_par
            .iter()
            .zip(v_perp_hat.iter())
            .zip(w.iter())
            .map(|((&vp, &vph), &wi)| vp + perp_norm * (cos_a * vph + sin_a * wi))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    // Example-based unit tests for caps, relations, the distance/probability functions and
    // the embedding container. Float comparisons use absolute tolerances sized for f32.
    use super::*;
    use std::f32::consts::PI;

    // --- SphericalCap construction and accessors ---------------------------------------

    #[test]
    fn cap_new_valid() {
        let c = SphericalCap::new(vec![1.0, 0.0, 0.0], PI / 4.0).unwrap();
        assert_eq!(c.dim(), 3);
        assert!((c.angular_radius() - PI / 4.0).abs() < 1e-6);
        let norm: f32 = c.center().iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }
    #[test]
    fn cap_normalizes_center() {
        // (3, 4, 0) has norm 5; the stored center must be unit length.
        let c = SphericalCap::new(vec![3.0, 4.0, 0.0], 0.5).unwrap();
        let norm: f32 = c.center().iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }
    #[test]
    fn cap_from_log_tan_half() {
        // ln(tan(θ/2)) = 0  ⇔  θ = π/2.
        let c = SphericalCap::from_log_tan_half(vec![1.0, 0.0], 0.0).unwrap();
        assert!((c.angular_radius() - PI / 2.0).abs() < 1e-5);
    }
    #[test]
    fn cap_rejects_zero_radius() {
        assert!(SphericalCap::new(vec![1.0, 0.0], 0.0).is_err());
    }
    #[test]
    fn cap_rejects_too_large_radius() {
        assert!(SphericalCap::new(vec![1.0, 0.0], PI + 0.1).is_err());
    }
    #[test]
    fn cap_rejects_zero_vector() {
        assert!(SphericalCap::new(vec![0.0, 0.0], 0.5).is_err());
    }
    #[test]
    fn cap_rejects_non_finite() {
        assert!(SphericalCap::new(vec![f32::NAN, 0.0], 0.5).is_err());
        assert!(SphericalCap::new(vec![1.0, 0.0], f32::NAN).is_err());
    }

    // --- containment_prob ----------------------------------------------------------------

    #[test]
    fn containment_identical_is_half() {
        // margin = θ - 0 - θ = 0, and the sigmoid at 0 is 0.5.
        let a = SphericalCap::new(vec![1.0, 0.0, 0.0], PI / 4.0).unwrap();
        let p = containment_prob(&a, &a, 10.0).unwrap();
        assert!((p - 0.5).abs() < 1e-4);
    }
    #[test]
    fn containment_nested_is_near_one() {
        // Same center, margin = 1.0 - 0 - 0.1 = 0.9.
        let inner = SphericalCap::new(vec![1.0, 0.0, 0.0], 0.1).unwrap();
        let outer = SphericalCap::new(vec![1.0, 0.0, 0.0], 1.0).unwrap();
        let p = containment_prob(&inner, &outer, 10.0).unwrap();
        assert!(p > 0.99, "nested containment = {p}, expected > 0.99");
    }
    #[test]
    fn containment_disjoint_is_near_zero() {
        // Antipodal centers: margin = 0.1 - π - 0.1 = -π.
        let a = SphericalCap::new(vec![1.0, 0.0, 0.0], 0.1).unwrap();
        let b = SphericalCap::new(vec![-1.0, 0.0, 0.0], 0.1).unwrap();
        let p = containment_prob(&a, &b, 10.0).unwrap();
        assert!(p < 1e-4, "disjoint containment = {p}, expected ~0");
    }
    #[test]
    fn containment_dimension_mismatch() {
        let a = SphericalCap::new(vec![1.0, 0.0], 0.5).unwrap();
        let b = SphericalCap::new(vec![1.0, 0.0, 0.0], 0.5).unwrap();
        assert!(containment_prob(&a, &b, 1.0).is_err());
    }

    // --- surface_distance ----------------------------------------------------------------

    #[test]
    fn surface_distance_overlapping_is_zero() {
        // Centers π/2 apart, radii sum to 2 > π/2.
        let a = SphericalCap::new(vec![1.0, 0.0, 0.0], 1.0).unwrap();
        let b = SphericalCap::new(vec![0.0, 1.0, 0.0], 1.0).unwrap();
        let d = surface_distance(&a, &b).unwrap();
        assert!(d.abs() < 1e-6);
    }
    #[test]
    fn surface_distance_disjoint_is_positive() {
        // Antipodal centers: π - 0.1 - 0.1.
        let a = SphericalCap::new(vec![1.0, 0.0, 0.0], 0.1).unwrap();
        let b = SphericalCap::new(vec![-1.0, 0.0, 0.0], 0.1).unwrap();
        let d = surface_distance(&a, &b).unwrap();
        assert!((d - (PI - 0.2)).abs() < 1e-4);
    }
    #[test]
    fn surface_distance_identical_is_zero() {
        let a = SphericalCap::new(vec![1.0, 0.0, 0.0], 0.5).unwrap();
        let d = surface_distance(&a, &a).unwrap();
        assert!(d.abs() < 1e-6);
    }

    // --- overlap_prob --------------------------------------------------------------------

    #[test]
    fn overlap_identical_is_one() {
        let a = SphericalCap::new(vec![1.0, 0.0, 0.0], 0.5).unwrap();
        let p = overlap_prob(&a, &a).unwrap();
        assert!((p - 1.0).abs() < 1e-6);
    }
    #[test]
    fn overlap_disjoint_is_zero() {
        let a = SphericalCap::new(vec![1.0, 0.0, 0.0], 0.1).unwrap();
        let b = SphericalCap::new(vec![-1.0, 0.0, 0.0], 0.1).unwrap();
        let p = overlap_prob(&a, &b).unwrap();
        assert!(p.abs() < 1e-6);
    }
    #[test]
    fn overlap_symmetric() {
        let a = SphericalCap::new(vec![1.0, 0.0, 0.0], 0.5).unwrap();
        let b = SphericalCap::new(vec![0.0, 1.0, 0.0], 0.8).unwrap();
        let p_ab = overlap_prob(&a, &b).unwrap();
        let p_ba = overlap_prob(&b, &a).unwrap();
        assert!((p_ab - p_ba).abs() < 1e-6);
    }

    // --- area_fraction -------------------------------------------------------------------

    #[test]
    fn area_fraction_2d() {
        let c = SphericalCap::new(vec![1.0, 0.0], PI / 2.0).unwrap();
        let f = c.area_fraction();
        assert!((f - 0.5).abs() < 1e-6, "half circle = 0.5, got {f}");
    }
    #[test]
    fn area_fraction_3d_hemisphere() {
        let c = SphericalCap::new(vec![1.0, 0.0, 0.0], PI / 2.0).unwrap();
        let f = c.area_fraction();
        assert!((f - 0.5).abs() < 1e-6, "hemisphere = 0.5, got {f}");
    }

    // --- SphericalCapRelation ------------------------------------------------------------

    #[test]
    fn relation_identity_preserves_cap() {
        let c = SphericalCap::new(vec![1.0, 0.0, 0.0], 0.5).unwrap();
        let r = SphericalCapRelation::identity(3);
        let t = r.apply(&c).unwrap();
        for (i, (&a, &b)) in c.center().iter().zip(t.center().iter()).enumerate() {
            assert!((a - b).abs() < 1e-6, "center[{i}] changed: {a} -> {b}");
        }
        assert!((t.angular_radius() - c.angular_radius()).abs() < 1e-6);
    }
    #[test]
    fn relation_rotation() {
        // Rotating e_x by π/2 about e_z gives e_y; the radius is unchanged (scale 1).
        let c = SphericalCap::new(vec![1.0, 0.0, 0.0], 0.5).unwrap();
        let r = SphericalCapRelation::new(vec![0.0, 0.0, 1.0], PI / 2.0, 1.0).unwrap();
        let t = r.apply(&c).unwrap();
        assert!(t.center()[0].abs() < 1e-5);
        assert!((t.center()[1] - 1.0).abs() < 1e-5);
        assert!((t.angular_radius() - 0.5).abs() < 1e-6);
    }
    #[test]
    fn relation_scaling() {
        // Angle 0 leaves the center alone; scale 2 doubles the radius.
        let c = SphericalCap::new(vec![1.0, 0.0], 0.5).unwrap();
        let r = SphericalCapRelation::new(vec![0.0, 1.0], 0.0, 2.0).unwrap();
        let t = r.apply(&c).unwrap();
        assert!((t.angular_radius() - 1.0).abs() < 1e-5);
    }
    #[test]
    fn relation_rejects_zero_scale() {
        assert!(SphericalCapRelation::new(vec![1.0, 0.0], 0.0, 0.0).is_err());
    }
    #[test]
    fn relation_dimension_mismatch() {
        let c = SphericalCap::new(vec![1.0, 0.0], 0.5).unwrap();
        let r = SphericalCapRelation::new(vec![0.0, 0.0, 1.0], 0.0, 1.0).unwrap();
        assert!(r.apply(&c).is_err());
    }

    // --- score_triple --------------------------------------------------------------------

    #[test]
    fn score_triple_perfect_match() {
        let h = SphericalCap::new(vec![1.0, 0.0, 0.0], 0.5).unwrap();
        let r = SphericalCapRelation::identity(3);
        let t = SphericalCap::new(vec![1.0, 0.0, 0.0], 0.5).unwrap();
        let s = score_triple(&h, &r, &t).unwrap();
        assert!(s.abs() < 1e-6);
    }
    #[test]
    fn score_triple_mismatch() {
        // Antipodal head and tail: gap = π - 0.2 ≈ 2.94.
        let h = SphericalCap::new(vec![1.0, 0.0, 0.0], 0.1).unwrap();
        let r = SphericalCapRelation::identity(3);
        let t = SphericalCap::new(vec![-1.0, 0.0, 0.0], 0.1).unwrap();
        let s = score_triple(&h, &r, &t).unwrap();
        assert!(s > 2.0, "mismatch score = {s}, expected > 2");
    }

    // --- SphericalCapEmbedding -----------------------------------------------------------

    #[test]
    fn embedding_model_construction() {
        let entities = vec![
            SphericalCap::new(vec![1.0, 0.0, 0.0], 0.5).unwrap(),
            SphericalCap::new(vec![0.0, 1.0, 0.0], 0.3).unwrap(),
        ];
        let relations = vec![SphericalCapRelation::identity(3)];
        let model = SphericalCapEmbedding::new(entities, relations, 3).unwrap();
        assert_eq!(model.num_entities(), 2);
        assert_eq!(model.num_relations(), 1);
        assert_eq!(model.dim(), 3);
    }
    #[test]
    fn embedding_model_rejects_dim_mismatch() {
        // The entity is 2-D but the relation axis is 3-D.
        let entities = vec![SphericalCap::new(vec![1.0, 0.0], 0.5).unwrap()];
        let relations = vec![SphericalCapRelation::identity(3)];
        assert!(SphericalCapEmbedding::new(entities, relations, 2).is_err());
    }

    // --- geodesic_distance ---------------------------------------------------------------

    #[test]
    fn geodesic_identical_is_zero() {
        let a = [1.0, 0.0, 0.0];
        let d = geodesic_distance(&a, &a);
        assert!(d.abs() < 1e-6);
    }
    #[test]
    fn geodesic_orthogonal_is_pi_half() {
        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0];
        let d = geodesic_distance(&a, &b);
        assert!((d - PI / 2.0).abs() < 1e-6);
    }
    #[test]
    fn geodesic_opposite_is_pi() {
        let a = [1.0, 0.0, 0.0];
        let b = [-1.0, 0.0, 0.0];
        let d = geodesic_distance(&a, &b);
        assert!((d - PI).abs() < 1e-6);
    }

    // --- crate::utils::stable_sigmoid saturation (used by containment_prob) ---------------

    #[test]
    fn sigmoid_large_positive() {
        assert!((crate::utils::stable_sigmoid(100.0) - 1.0).abs() < 1e-4);
    }
    #[test]
    fn sigmoid_large_negative() {
        assert!(crate::utils::stable_sigmoid(-100.0).abs() < 1e-4);
    }
}
#[cfg(test)]
mod proptests {
    // Property-based tests (proptest) over random 4-dimensional caps and relations. They
    // check value ranges, symmetry and structural invariants of the public functions.
    use super::*;
    use proptest::prelude::*;
    use std::f32::consts::PI;

    // Random cap: center components in [-10, 10), radius in [0.01, π - 0.01). Inputs that
    // `SphericalCap::new` rejects (e.g. an all-zero center) are filtered out.
    fn arb_cap(dim: usize) -> impl Strategy<Value = SphericalCap> {
        (
            prop::collection::vec(-10.0f32..10.0, dim),
            0.01f32..(PI - 0.01),
        )
            .prop_filter_map("valid cap", move |(center, theta)| {
                SphericalCap::new(center, theta).ok()
            })
    }

    // Two independently drawn caps of the same dimension.
    fn arb_cap_pair(dim: usize) -> impl Strategy<Value = (SphericalCap, SphericalCap)> {
        (arb_cap(dim), arb_cap(dim))
    }

    // Random relation: axis components in [-5, 5), angle in [-π, π), scale in [0.1, 5.0).
    // Inputs that `SphericalCapRelation::new` rejects are filtered out.
    fn arb_relation(dim: usize) -> impl Strategy<Value = SphericalCapRelation> {
        (
            prop::collection::vec(-5.0f32..5.0, dim),
            -PI..PI,
            0.1f32..5.0,
        )
            .prop_filter_map("valid relation", move |(axis, angle, scale)| {
                SphericalCapRelation::new(axis, angle, scale).ok()
            })
    }

    proptest! {
        // containment_prob is a sigmoid output, so it must lie in [0, 1].
        #[test]
        fn prop_containment_in_unit_interval(
            (a, b) in arb_cap_pair(4)
        ) {
            let p = containment_prob(&a, &b, 10.0).unwrap();
            prop_assert!(p >= -1e-6, "containment_prob < 0: {p}");
            prop_assert!(p <= 1.0 + 1e-6, "containment_prob > 1: {p}");
        }
        // surface_distance is clamped at 0.
        #[test]
        fn prop_surface_distance_nonneg(
            (a, b) in arb_cap_pair(4)
        ) {
            let d = surface_distance(&a, &b).unwrap();
            prop_assert!(d >= -1e-6, "surface_distance < 0: {d}");
        }
        #[test]
        fn prop_overlap_in_unit_interval(
            (a, b) in arb_cap_pair(4)
        ) {
            let p = overlap_prob(&a, &b).unwrap();
            prop_assert!(p >= -1e-6, "overlap_prob < 0: {p}");
            prop_assert!(p <= 1.0 + 1e-6, "overlap_prob > 1: {p}");
        }
        #[test]
        fn prop_overlap_symmetric(
            (a, b) in arb_cap_pair(4)
        ) {
            let p_ab = overlap_prob(&a, &b).unwrap();
            let p_ba = overlap_prob(&b, &a).unwrap();
            prop_assert!(
                (p_ab - p_ba).abs() < 1e-5,
                "overlap should be symmetric: {p_ab} != {p_ba}"
            );
        }
        #[test]
        fn prop_surface_distance_symmetric(
            (a, b) in arb_cap_pair(4)
        ) {
            let d_ab = surface_distance(&a, &b).unwrap();
            let d_ba = surface_distance(&b, &a).unwrap();
            prop_assert!(
                (d_ab - d_ba).abs() < 1e-3,
                "surface_distance should be symmetric: {d_ab} != {d_ba}"
            );
        }
        // apply keeps the dimension and clamps the radius into (0, π].
        #[test]
        fn prop_relation_apply_preserves_dim(
            c in arb_cap(4),
            r in arb_relation(4)
        ) {
            let t = r.apply(&c).unwrap();
            prop_assert_eq!(t.dim(), c.dim());
            prop_assert!(t.angular_radius() > 0.0, "radius should be positive");
            prop_assert!(t.angular_radius() <= PI, "radius should be <= pi");
        }
        #[test]
        fn prop_score_triple_nonneg(
            h in arb_cap(4),
            r in arb_relation(4),
            t in arb_cap(4)
        ) {
            let s = score_triple(&h, &r, &t).unwrap();
            prop_assert!(s >= -1e-6, "score_triple < 0: {s}");
        }
        #[test]
        fn prop_area_fraction_in_unit_interval(
            c in arb_cap(4)
        ) {
            let f = c.area_fraction();
            prop_assert!(f >= 0.0, "area_fraction < 0: {f}");
            prop_assert!(f <= 1.0 + 1e-5, "area_fraction > 1: {f}");
        }
        // SphericalCap::new must always store a unit-length center.
        #[test]
        fn prop_center_is_unit_vector(
            c in arb_cap(4)
        ) {
            let norm: f32 = c.center().iter().map(|x| x * x).sum::<f32>().sqrt();
            prop_assert!((norm - 1.0).abs() < 1e-5, "center not unit: {norm}");
        }
        #[test]
        fn prop_geodesic_in_range(
            (a, b) in arb_cap_pair(4)
        ) {
            let d = geodesic_distance(a.center(), b.center());
            prop_assert!(d >= 0.0, "geodesic < 0: {d}");
            prop_assert!(d <= PI + 1e-5, "geodesic > pi: {d}");
        }
    }
}