use crate::BoxError;
use serde::{Deserialize, Serialize};
/// A point on the unit hypersphere, stored as Cartesian coordinates.
///
/// Values built through `SphericalPoint::new` are checked for finiteness and
/// L2-normalized. NOTE(review): `Deserialize` is derived directly on the
/// struct, so deserialized values bypass that validation — confirm whether
/// callers rely on unit norm for deserialized data.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SphericalPoint {
    // Cartesian coordinates; unit L2 norm when constructed via `new`.
    coords: Vec<f32>,
}
/// A relation modelled as a rotation on the unit hypersphere.
///
/// How `axis` is interpreted depends on the dimension; see `rotate`.
/// NOTE(review): the derived `Deserialize` skips the validation done in
/// `SphericalRelation::new`, so a deserialized relation may carry a non-unit
/// or non-finite axis/angle.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SphericalRelation {
    // Direction vector; L2-normalized when built via `new` (or via
    // `identity` with a non-zero dimension).
    axis: Vec<f32>,
    // Rotation angle, consumed through `sin`/`cos` (radians); finite when
    // built via `new`.
    angle: f32,
}
/// A set of entity points and relation rotations that share one ambient
/// dimension.
///
/// `SphericalEmbedding::new` checks that every entity and relation has
/// dimension `dim`. `PartialEq` is derived for consistency with
/// `SphericalPoint` and `SphericalRelation`, e.g. for round-trip checks.
/// NOTE(review): the derived `Deserialize` bypasses the dimension check in
/// `new` — confirm that deserialized models are trusted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SphericalEmbedding {
    /// Entity embeddings, each of dimension `dim` when built via `new`.
    entities: Vec<SphericalPoint>,
    /// Relation rotations, each of dimension `dim` when built via `new`.
    relations: Vec<SphericalRelation>,
    /// Shared ambient dimension.
    dim: usize,
}
impl SphericalPoint {
    /// Builds a point on the unit sphere from raw coordinates.
    ///
    /// The input is first checked for non-finite components, then rescaled
    /// to unit Euclidean length.
    ///
    /// # Errors
    ///
    /// Returns `BoxError::InvalidBounds` if any component is NaN or infinite
    /// (with `dim` set to the offending index), or if the vector is too close
    /// to zero to be normalized.
    pub fn new(coords: Vec<f32>) -> Result<Self, BoxError> {
        validate_finite(&coords)?;
        l2_normalize(coords).map(|unit| Self { coords: unit })
    }

    /// Number of coordinates (the ambient dimension) of this point.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.coords().len()
    }

    /// Borrowed view of the coordinates.
    pub fn coords(&self) -> &[f32] {
        self.coords.as_slice()
    }
}
impl SphericalRelation {
    /// Creates a relation from a direction vector and an angle (radians).
    ///
    /// The angle is validated first, then the axis is checked for finiteness
    /// and L2-normalized.
    ///
    /// # Errors
    ///
    /// Returns `BoxError::InvalidBounds` when `angle` is not finite (reported
    /// with `dim: 0` and the angle as both bounds), when any axis component is
    /// not finite, or when the axis is too close to zero to normalize.
    pub fn new(axis: Vec<f32>, angle: f32) -> Result<Self, BoxError> {
        if angle.is_finite() {
            validate_finite(&axis)?;
            let unit_axis = l2_normalize(axis)?;
            Ok(Self {
                axis: unit_axis,
                angle,
            })
        } else {
            let bad = f64::from(angle);
            Err(BoxError::InvalidBounds {
                dim: 0,
                min: bad,
                max: bad,
            })
        }
    }

    /// The no-op relation of dimension `dim`: axis `e_0` and angle `0`.
    ///
    /// For `dim == 0` the axis is empty.
    #[must_use]
    pub fn identity(dim: usize) -> Self {
        let axis = (0..dim)
            .map(|idx| if idx == 0 { 1.0 } else { 0.0 })
            .collect();
        Self { axis, angle: 0.0 }
    }

    /// Number of components in the axis (the ambient dimension).
    #[must_use]
    pub fn dim(&self) -> usize {
        self.axis().len()
    }

    /// Borrowed view of the axis vector.
    pub fn axis(&self) -> &[f32] {
        self.axis.as_slice()
    }

    /// The rotation angle in radians.
    #[must_use]
    pub fn angle(&self) -> f32 {
        let Self { angle, .. } = *self;
        angle
    }
}
impl SphericalEmbedding {
    /// Assembles an embedding model, checking every member against `dim`.
    ///
    /// # Errors
    ///
    /// Returns `BoxError::DimensionMismatch { expected: dim, actual }` for the
    /// first entity whose dimension differs from `dim`. If every entity
    /// matches, it returns the same error for the first mismatching relation.
    pub fn new(
        entities: Vec<SphericalPoint>,
        relations: Vec<SphericalRelation>,
        dim: usize,
    ) -> Result<Self, BoxError> {
        let mismatch = |actual: usize| BoxError::DimensionMismatch {
            expected: dim,
            actual,
        };
        if let Some(bad) = entities.iter().find(|e| e.dim() != dim) {
            return Err(mismatch(bad.dim()));
        }
        if let Some(bad) = relations.iter().find(|r| r.dim() != dim) {
            return Err(mismatch(bad.dim()));
        }
        Ok(Self {
            entities,
            relations,
            dim,
        })
    }

    /// Shared ambient dimension of all entities and relations.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Number of entity points.
    #[must_use]
    pub fn num_entities(&self) -> usize {
        self.entities.len()
    }

    /// Number of relation rotations.
    #[must_use]
    pub fn num_relations(&self) -> usize {
        self.relations.len()
    }

    /// Borrowed view of the entity points.
    pub fn entities(&self) -> &[SphericalPoint] {
        &self.entities
    }

    /// Borrowed view of the relation rotations.
    pub fn relations(&self) -> &[SphericalRelation] {
        &self.relations
    }
}
/// Applies `relation` to `point` and returns the rotated unit point.
///
/// Let `k` be the relation axis, `v` the point and `θ` the angle (radians).
///
/// * `d == 3`: Rodrigues' rotation about the fixed axis `k`.
/// * any other `d`: rotation within the 2-plane spanned by `k` and `v`.
///   Write `v = a·k + b·u`, where `a = k·v`, `b = |v − a·k|` and `u` is the
///   unit direction of `v − a·k`. The result is
///   `(a cosθ + b sinθ)·k + (b cosθ − a sinθ)·u`. This moves `v` along the
///   great circle through `v` and `k` by exactly `|θ|`, turning it toward `k`
///   for positive `θ`, and preserves the norm.
///   If `v` is (numerically) parallel to `k`, the plane is undefined and `v`
///   is returned unchanged.
///
/// NOTE(review): in the non-3D case the rotation plane depends on `v`, so the
/// map is not a single global isometry of the sphere. Relative angles between
/// different points are generally not preserved.
///
/// The result is re-normalized to absorb rounding drift. If that fails, the
/// input coordinates are returned.
///
/// # Errors
///
/// Returns `BoxError::DimensionMismatch` when the relation's dimension differs
/// from the point's.
pub fn rotate(
    point: &SphericalPoint,
    relation: &SphericalRelation,
) -> Result<SphericalPoint, BoxError> {
    let d = point.dim();
    if relation.dim() != d {
        return Err(BoxError::DimensionMismatch {
            expected: d,
            actual: relation.dim(),
        });
    }
    let v = &point.coords;
    let k = &relation.axis;
    let cos_a = relation.angle.cos();
    let sin_a = relation.angle.sin();
    let k_dot_v: f32 = k.iter().zip(v.iter()).map(|(a, b)| a * b).sum();

    let result: Vec<f32> = if d == 3 {
        // Rodrigues: v cosθ + (k × v) sinθ + k (k·v)(1 − cosθ).
        let cross = [
            k[1] * v[2] - k[2] * v[1],
            k[2] * v[0] - k[0] * v[2],
            k[0] * v[1] - k[1] * v[0],
        ];
        (0..3)
            .map(|i| v[i] * cos_a + cross[i] * sin_a + k[i] * k_dot_v * (1.0 - cos_a))
            .collect()
    } else {
        // Component of v orthogonal to k; its norm is b in the doc formula.
        let v_perp: Vec<f32> = v
            .iter()
            .zip(k.iter())
            .map(|(vc, kc)| vc - kc * k_dot_v)
            .collect();
        let perp_norm = v_perp.iter().map(|x| x * x).sum::<f32>().sqrt();
        if perp_norm < 1e-12 {
            // v ∥ k: no rotation plane is defined, so leave v unchanged.
            v.clone()
        } else {
            // Standard 2-D rotation of the (k, u) coordinates (a, b).
            let along_k = k_dot_v * cos_a + perp_norm * sin_a;
            let along_u = perp_norm * cos_a - k_dot_v * sin_a;
            k.iter()
                .zip(v_perp.iter())
                .map(|(kc, pc)| kc * along_k + (pc / perp_norm) * along_u)
                .collect()
        }
    };
    let coords = l2_normalize(result).unwrap_or_else(|_| point.coords.clone());
    Ok(SphericalPoint { coords })
}
/// Scores a `(head, relation, tail)` triple as the angle, in radians, between
/// `rotate(head, relation)` and `tail`.
///
/// For unit inputs, lower is better: `0.0` means the rotated head coincides
/// with the tail, and `π` means they are antipodal. The dot product is clamped
/// to `[-1, 1]` before `acos` to absorb rounding drift.
///
/// # Errors
///
/// Returns `BoxError::DimensionMismatch` when head and tail dimensions differ,
/// and propagates any error from `rotate`.
pub fn score_triple(
    head: &SphericalPoint,
    relation: &SphericalRelation,
    tail: &SphericalPoint,
) -> Result<f64, BoxError> {
    let (head_dim, tail_dim) = (head.dim(), tail.dim());
    if head_dim != tail_dim {
        return Err(BoxError::DimensionMismatch {
            expected: head_dim,
            actual: tail_dim,
        });
    }
    let moved = rotate(head, relation)?;
    let cosine = moved
        .coords
        .iter()
        .zip(&tail.coords)
        .map(|(a, b)| a * b)
        .sum::<f32>();
    let angle = cosine.clamp(-1.0, 1.0).acos();
    Ok(f64::from(angle))
}
/// Scores aligned triples `(heads[i], relations[i], tails[i])` with
/// `score_triple`, returning one score per index.
///
/// # Errors
///
/// Returns `BoxError::DimensionMismatch` when the three slices have different
/// lengths. `expected` is `heads.len()`, and `actual` is the length of the
/// slice that disagrees with it (`relations` is checked before `tails`).
/// Otherwise it propagates the first error from `score_triple`.
pub fn score_batch(
    heads: &[SphericalPoint],
    relations: &[SphericalRelation],
    tails: &[SphericalPoint],
) -> Result<Vec<f64>, BoxError> {
    let expected = heads.len();
    // Report the length that actually differs. Taking `min` of the other two
    // could yield a value equal to `expected`, giving a confusing error.
    let actual = if relations.len() != expected {
        relations.len()
    } else {
        tails.len()
    };
    if actual != expected {
        return Err(BoxError::DimensionMismatch { expected, actual });
    }
    heads
        .iter()
        .zip(relations)
        .zip(tails)
        .map(|((h, r), t)| score_triple(h, r, t))
        .collect()
}
/// Projects a raw vector onto the unit sphere.
///
/// This is a thin convenience wrapper around `SphericalPoint::new`. It
/// validates finiteness and L2-normalizes, and fails on the same inputs.
///
/// # Errors
///
/// Propagates `BoxError::InvalidBounds` from `SphericalPoint::new`.
pub fn project_to_sphere(v: Vec<f32>) -> Result<SphericalPoint, BoxError> {
    let projected = SphericalPoint::new(v)?;
    Ok(projected)
}
/// Rescales `v` to unit Euclidean (L2) length.
///
/// The norm is computed after dividing by the largest absolute component.
/// This way, very large finite inputs, whose squares would overflow `f32` to
/// infinity, still normalize correctly instead of collapsing to an all-zero
/// vector.
///
/// # Errors
///
/// Returns `BoxError::InvalidBounds { dim: 0, min: 0.0, max: 0.0 }` when `v`
/// is empty, has no non-zero (non-NaN) component, or has norm below `1e-12`.
fn l2_normalize(mut v: Vec<f32>) -> Result<Vec<f32>, BoxError> {
    let too_small = || BoxError::InvalidBounds {
        dim: 0,
        min: 0.0,
        max: 0.0,
    };
    // Largest magnitude. `f32::max` returns the non-NaN operand, so this is
    // 0.0 exactly when no component is a non-zero number.
    let scale = v.iter().fold(0.0f32, |acc, x| acc.max(x.abs()));
    if scale == 0.0 {
        return Err(too_small());
    }
    // Every scaled component lies in [-1, 1], so the sum cannot overflow.
    let scaled_norm = v
        .iter()
        .map(|x| {
            let s = x / scale;
            s * s
        })
        .sum::<f32>()
        .sqrt();
    // The product may round to +inf for huge inputs. That is harmless because
    // only the lower threshold is checked here.
    if scale * scaled_norm < 1e-12 {
        return Err(too_small());
    }
    for x in &mut v {
        // Divide in two steps so the full norm is never materialized.
        *x = *x / scale / scaled_norm;
    }
    Ok(v)
}
/// Checks that every component of `v` is finite (not NaN and not ±∞).
///
/// # Errors
///
/// Returns `BoxError::InvalidBounds` for the first non-finite component.
/// `dim` is its index, and both `min` and `max` carry its value widened to
/// `f64`.
fn validate_finite(v: &[f32]) -> Result<(), BoxError> {
    match v.iter().position(|x| !x.is_finite()) {
        None => Ok(()),
        Some(dim) => {
            let offending = f64::from(v[dim]);
            Err(BoxError::InvalidBounds {
                dim,
                min: offending,
                max: offending,
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Builds a validated unit point from a coordinate slice (panics on
    /// invalid input; test helper only).
    fn point(v: &[f32]) -> SphericalPoint {
        SphericalPoint::new(v.to_vec()).unwrap()
    }

    #[test]
    fn identity_rotation_score_is_zero() {
        let h = point(&[1.0, 0.0, 0.0]);
        let r = SphericalRelation::identity(3);
        let score = score_triple(&h, &r, &h).unwrap();
        assert!(
            score.abs() < 1e-6,
            "score(h, identity, h) = {score}, expected 0"
        );
    }

    #[test]
    fn score_is_nonnegative() {
        let h = point(&[1.0, 0.0, 0.0]);
        let t = point(&[0.0, 1.0, 0.0]);
        let r = SphericalRelation::identity(3);
        let score = score_triple(&h, &r, &t).unwrap();
        assert!(score >= -1e-9, "score should be non-negative, got {score}");
    }

    #[test]
    fn orthogonal_vectors_score_pi_over_2() {
        let h = point(&[1.0, 0.0, 0.0]);
        let t = point(&[0.0, 1.0, 0.0]);
        let r = SphericalRelation::identity(3);
        let score = score_triple(&h, &r, &t).unwrap();
        assert!(
            (score - PI / 2.0).abs() < 1e-5,
            "score(orthogonal, identity) = {score}, expected pi/2 = {}",
            PI / 2.0
        );
    }

    #[test]
    fn antipodal_vectors_score_pi() {
        let h = point(&[1.0, 0.0, 0.0]);
        let t = point(&[-1.0, 0.0, 0.0]);
        let r = SphericalRelation::identity(3);
        let score = score_triple(&h, &r, &t).unwrap();
        assert!(
            (score - PI).abs() < 1e-5,
            "score(antipodal, identity) = {score}, expected pi"
        );
    }

    #[test]
    fn projection_preserves_direction_sets_unit_norm() {
        let v = vec![3.0, 4.0, 0.0];
        let p = project_to_sphere(v).unwrap();
        let norm: f32 = p.coords().iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-6,
            "projected norm = {norm}, expected 1.0"
        );
        assert!(
            (p.coords()[0] - 0.6).abs() < 1e-6,
            "x = {}, expected 0.6",
            p.coords()[0]
        );
        assert!(
            (p.coords()[1] - 0.8).abs() < 1e-6,
            "y = {}, expected 0.8",
            p.coords()[1]
        );
    }

    #[test]
    fn rotation_preserves_unit_norm() {
        let h = point(&[1.0, 0.0, 0.0]);
        let r = SphericalRelation::new(vec![0.0, 0.0, 1.0], 1.0).unwrap();
        let rotated = rotate(&h, &r).unwrap();
        let norm: f32 = rotated.coords().iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "rotated norm = {norm}, expected 1.0"
        );
    }

    #[test]
    fn rotation_by_pi_over_2_moves_x_to_y() {
        let h = point(&[1.0, 0.0, 0.0]);
        let r = SphericalRelation::new(vec![0.0, 0.0, 1.0], std::f32::consts::FRAC_PI_2).unwrap();
        let rotated = rotate(&h, &r).unwrap();
        let t = point(&[0.0, 1.0, 0.0]);
        let score = score_triple(&h, &r, &t).unwrap();
        assert!(
            score.abs() < 1e-4,
            "rotating x by pi/2 around z should reach y, score = {score}"
        );
        assert!(
            (rotated.coords()[0]).abs() < 1e-5,
            "rotated x = {}, expected ~0",
            rotated.coords()[0]
        );
        assert!(
            (rotated.coords()[1] - 1.0).abs() < 1e-5,
            "rotated y = {}, expected ~1",
            rotated.coords()[1]
        );
    }

    #[test]
    fn batch_scoring_matches_individual() {
        let h1 = point(&[1.0, 0.0, 0.0]);
        let h2 = point(&[0.0, 1.0, 0.0]);
        let r = SphericalRelation::identity(3);
        let t1 = point(&[1.0, 0.0, 0.0]);
        let t2 = point(&[0.0, 0.0, 1.0]);
        let individual = vec![
            score_triple(&h1, &r, &t1).unwrap(),
            score_triple(&h2, &r, &t2).unwrap(),
        ];
        let batch = score_batch(&[h1, h2], &[r.clone(), r], &[t1, t2]).unwrap();
        for (i, (a, b)) in individual.iter().zip(batch.iter()).enumerate() {
            assert!((a - b).abs() < 1e-9, "batch[{i}] = {b}, individual = {a}");
        }
    }

    #[test]
    fn batch_length_mismatch_errors() {
        let h = point(&[1.0, 0.0, 0.0]);
        let r = SphericalRelation::identity(3);
        // Too few relations.
        assert!(score_batch(&[h.clone(), h.clone()], &[r.clone()], &[h.clone(), h.clone()]).is_err());
        // Too many tails.
        assert!(score_batch(&[h.clone()], &[r], &[h.clone(), h]).is_err());
    }

    #[test]
    fn dimension_mismatch_errors() {
        let h = point(&[1.0, 0.0, 0.0]);
        let t = point(&[1.0, 0.0]);
        let r = SphericalRelation::identity(3);
        assert!(score_triple(&h, &r, &t).is_err());
        let r_bad = SphericalRelation::identity(2);
        assert!(score_triple(&h, &r_bad, &h).is_err());
    }

    #[test]
    fn rotate_rejects_mismatched_relation() {
        let h = point(&[1.0, 0.0, 0.0]);
        assert!(rotate(&h, &SphericalRelation::identity(4)).is_err());
    }

    #[test]
    fn identity_rotation_is_noop_in_4d() {
        let h = point(&[0.5, -0.5, 0.5, 0.5]);
        let r = SphericalRelation::identity(4);
        let rotated = rotate(&h, &r).unwrap();
        for (a, b) in rotated.coords().iter().zip(h.coords()) {
            assert!((a - b).abs() < 1e-6, "rotated {a} vs original {b}");
        }
    }

    #[test]
    fn rejects_zero_vector() {
        assert!(SphericalPoint::new(vec![0.0, 0.0, 0.0]).is_err());
    }

    #[test]
    fn rejects_non_finite() {
        assert!(SphericalPoint::new(vec![f32::NAN, 1.0]).is_err());
        assert!(SphericalPoint::new(vec![f32::INFINITY, 1.0]).is_err());
        assert!(SphericalRelation::new(vec![1.0, 0.0], f32::NAN).is_err());
    }

    #[test]
    fn embedding_model_construction() {
        let entities = vec![point(&[1.0, 0.0, 0.0]), point(&[0.0, 1.0, 0.0])];
        let relations = vec![SphericalRelation::identity(3)];
        let model = SphericalEmbedding::new(entities, relations, 3);
        assert!(model.is_ok());
        let m = model.unwrap();
        assert_eq!(m.num_entities(), 2);
        assert_eq!(m.num_relations(), 1);
        assert_eq!(m.dim(), 3);
    }

    #[test]
    fn embedding_model_rejects_dim_mismatch() {
        let entities = vec![point(&[1.0, 0.0])];
        let relations = vec![SphericalRelation::identity(3)];
        assert!(SphericalEmbedding::new(entities, relations, 3).is_err());
    }

    #[test]
    fn embedding_model_rejects_relation_dim_mismatch() {
        let entities = vec![point(&[1.0, 0.0, 0.0])];
        let relations = vec![SphericalRelation::identity(2)];
        assert!(SphericalEmbedding::new(entities, relations, 3).is_err());
    }

    #[test]
    fn score_symmetry_with_identity() {
        let h = point(&[1.0, 0.0, 0.0]);
        let t = point(&[0.6, 0.8, 0.0]);
        let r = SphericalRelation::identity(3);
        let s1 = score_triple(&h, &r, &t).unwrap();
        let s2 = score_triple(&t, &r, &h).unwrap();
        assert!(
            (s1 - s2).abs() < 1e-6,
            "identity relation should give symmetric scores: {s1} vs {s2}"
        );
    }

    #[test]
    fn higher_dimensional_works() {
        let d = 50;
        let mut v1 = vec![0.0f32; d];
        let mut v2 = vec![0.0f32; d];
        v1[0] = 1.0;
        v2[1] = 1.0;
        let h = point(&v1);
        let t = point(&v2);
        let r = SphericalRelation::identity(d);
        let score = score_triple(&h, &r, &t).unwrap();
        assert!(
            (score - PI / 2.0).abs() < 1e-5,
            "50d orthogonal score = {score}, expected pi/2"
        );
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Strategy for random unit points of dimension `dim`.
    ///
    /// Raw components are drawn from `-10.0..10.0`. Vectors that
    /// `SphericalPoint::new` rejects are filtered out instead of failing the
    /// test.
    fn arb_point(dim: usize) -> impl Strategy<Value = SphericalPoint> {
        let raw_coords = prop::collection::vec(-10.0f32..10.0, dim);
        raw_coords.prop_filter_map("non-zero vector", |coords| SphericalPoint::new(coords).ok())
    }

    /// Strategy for random relations of dimension `dim`.
    ///
    /// Axis components come from `-10.0..10.0` and angles from `[-π, π)`.
    /// Inputs that `SphericalRelation::new` rejects are filtered out.
    fn arb_relation(dim: usize) -> impl Strategy<Value = SphericalRelation> {
        use std::f32::consts::PI;
        let raw_axis = prop::collection::vec(-10.0f32..10.0, dim);
        let angles = -PI..PI;
        (raw_axis, angles).prop_filter_map("non-zero axis", |(axis, angle)| {
            SphericalRelation::new(axis, angle).ok()
        })
    }

    proptest! {
        #[test]
        fn prop_score_nonnegative(
            h in arb_point(4),
            r in arb_relation(4),
            t in arb_point(4),
        ) {
            let score: f64 = score_triple(&h, &r, &t).unwrap();
            prop_assert!(score >= -1e-9, "score should be >= 0, got {score}");
        }

        #[test]
        fn prop_score_at_most_pi(
            h in arb_point(4),
            r in arb_relation(4),
            t in arb_point(4),
        ) {
            // Allow a little slack above π for float rounding.
            let upper = std::f64::consts::PI + 1e-6;
            let score: f64 = score_triple(&h, &r, &t).unwrap();
            prop_assert!(score <= upper, "score should be <= pi, got {score}");
        }

        #[test]
        fn prop_identity_self_score_zero(
            h in arb_point(4),
        ) {
            let identity = SphericalRelation::identity(4);
            let score: f64 = score_triple(&h, &identity, &h).unwrap();
            prop_assert!(score.abs() < 1e-3, "score(h, identity, h) = {score}, expected ~0");
        }

        #[test]
        fn prop_rotation_preserves_norm(
            h in arb_point(4),
            r in arb_relation(4),
        ) {
            let rotated = rotate(&h, &r).unwrap();
            let squared: f32 = rotated.coords().iter().map(|x| x * x).sum();
            let norm = squared.sqrt();
            prop_assert!((norm - 1.0).abs() < 1e-4, "rotated norm = {norm}, expected 1.0");
        }

        #[test]
        fn prop_projection_idempotent(
            v in prop::collection::vec(-10.0f32..10.0, 4)
                .prop_filter("non-zero", |v| v.iter().map(|x| x*x).sum::<f32>() > 1e-12)
        ) {
            let once = project_to_sphere(v).unwrap();
            let twice = project_to_sphere(once.coords().to_vec()).unwrap();
            for (a, b) in once.coords().iter().zip(twice.coords()) {
                prop_assert!((a - b).abs() < 1e-5, "projection should be idempotent: {a} vs {b}");
            }
        }
    }
}