#![warn(missing_docs)]
pub mod dataset;
pub mod eval;
pub mod io;
pub mod query;
#[cfg(feature = "candle")]
pub mod train;
/// Crate-level error type.
///
/// Marked `#[non_exhaustive]`: downstream `match`es need a wildcard arm, and
/// new variants can be added without a breaking change.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// A dimension check failed: `actual` did not equal `expected`.
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch {
        /// The required dimension.
        expected: usize,
        /// The dimension that was actually supplied.
        actual: usize,
    },
    /// An underlying I/O failure. It is displayed exactly like the wrapped
    /// [`std::io::Error`] (`transparent`) and converts from it via `?` (`#[from]`).
    #[error(transparent)]
    Io(#[from] std::io::Error),
}
/// Common interface for knowledge-graph embedding models that score
/// `(head, relation, tail)` triples.
///
/// Scores follow a *lower is better* convention: the `top_k_*` helpers return
/// the `k` smallest scores in ascending order. Similarity-based models in this
/// crate ([`ComplEx`], [`DistMult`]) negate their raw similarity in
/// [`Scorer::score`] to fit this convention.
///
/// Only [`score`](Scorer::score) and [`num_entities`](Scorer::num_entities)
/// must be implemented. Every batch method has a default built on `score` and
/// may be overridden with a faster equivalent.
///
/// Entity and relation ids are dense indices. The models in this crate panic
/// when an id is out of range.
pub trait Scorer: Sync {
    /// Scores the triple `(head, relation, tail)`; lower means more plausible.
    fn score(&self, head: usize, relation: usize, tail: usize) -> f32;

    /// Number of entities; valid head/tail ids are `0..num_entities()`.
    fn num_entities(&self) -> usize;

    /// Scores `(head, relation, t)` for every entity `t`.
    ///
    /// Element `t` of the result equals `self.score(head, relation, t)`.
    fn score_all_tails(&self, head: usize, relation: usize) -> Vec<f32> {
        (0..self.num_entities())
            .map(|t| self.score(head, relation, t))
            .collect()
    }

    /// Scores `(h, relation, tail)` for every entity `h`.
    ///
    /// Element `h` of the result equals `self.score(h, relation, tail)`.
    fn score_all_heads(&self, relation: usize, tail: usize) -> Vec<f32> {
        (0..self.num_entities())
            .map(|h| self.score(h, relation, tail))
            .collect()
    }

    /// Returns the `k` lowest-scoring tails for `(head, relation, ?)` as
    /// `(tail, score)` pairs in ascending score order.
    ///
    /// Ties go to the smaller id, and NaN scores rank after every number.
    /// If there are fewer than `k` entities, all of them are returned.
    fn top_k_tails(&self, head: usize, relation: usize, k: usize) -> Vec<(usize, f32)> {
        top_k_ascending(self.score_all_tails(head, relation), k)
    }

    /// Returns the `k` lowest-scoring heads for `(?, relation, tail)` as
    /// `(head, score)` pairs in ascending score order.
    ///
    /// Tie and NaN handling is the same as [`Scorer::top_k_tails`].
    fn top_k_heads(&self, relation: usize, tail: usize, k: usize) -> Vec<(usize, f32)> {
        top_k_ascending(self.score_all_heads(relation, tail), k)
    }

    /// Scores `(head, r, tail)` for every relation `r` in `0..num_relations`.
    ///
    /// The trait does not track the relation count, so the caller supplies it.
    fn score_all_relations(&self, head: usize, tail: usize, num_relations: usize) -> Vec<f32> {
        (0..num_relations)
            .map(|r| self.score(head, r, tail))
            .collect()
    }

    /// Returns the `k` lowest-scoring relations for `(head, ?, tail)` as
    /// `(relation, score)` pairs in ascending score order.
    ///
    /// Tie and NaN handling is the same as [`Scorer::top_k_tails`].
    fn top_k_relations(
        &self,
        head: usize,
        tail: usize,
        num_relations: usize,
        k: usize,
    ) -> Vec<(usize, f32)> {
        top_k_ascending(self.score_all_relations(head, tail, num_relations), k)
    }
}

/// Selects the `k` lowest `(index, score)` pairs from `scores`, sorted ascending.
///
/// The comparator [`by_score_then_index`] is a total order, so the result is
/// deterministic even with ties or NaN scores. A non-total comparator can make
/// the standard sorts panic or return an unspecified order. Partial selection
/// avoids sorting all `n` scores when only `k` are kept.
fn top_k_ascending(scores: Vec<f32>, k: usize) -> Vec<(usize, f32)> {
    if k == 0 {
        return Vec::new();
    }
    let mut scored: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
    if k < scored.len() {
        // Afterwards `scored[..k]` holds the k best entries, in no particular order.
        scored.select_nth_unstable_by(k - 1, by_score_then_index);
        scored.truncate(k);
    }
    scored.sort_unstable_by(by_score_then_index);
    scored
}

/// Total order on `(index, score)` pairs: ascending score with NaN last, then
/// ascending index.
fn by_score_then_index(a: &(usize, f32), b: &(usize, f32)) -> std::cmp::Ordering {
    a.1.partial_cmp(&b.1)
        // `partial_cmp` is `None` only when a NaN is involved: NaN sorts after
        // every number, and two NaNs compare equal.
        .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
        .then_with(|| a.0.cmp(&b.0))
}
/// TransE translation model: a triple is scored by the distance
/// `‖h + r − t‖` (L1 or L2), so lower scores are more plausible.
///
/// Embeddings are stored row-major in flat buffers with `dim` floats per
/// entity or relation.
#[derive(Debug, Clone)]
pub struct TransE {
    /// Entity embeddings, `dim` consecutive floats per entity.
    entities: Vec<f32>,
    /// Relation embeddings, `dim` consecutive floats per relation.
    relations: Vec<f32>,
    /// Embedding dimension.
    dim: usize,
    /// Distance selector: `1` means L1; any other value means L2.
    norm: u32,
}
impl TransE {
    /// Creates a randomly initialised model that uses the L2 distance.
    ///
    /// Every coordinate is drawn uniformly from `-scale..scale`, where
    /// `scale = 6 / sqrt(dim)`.
    #[cfg(feature = "rand")]
    pub fn new(num_entities: usize, num_relations: usize, dim: usize) -> Self {
        let mut rng = rand::rng();
        let scale = 6.0_f32 / (dim as f32).sqrt();
        Self::from_vecs(
            init_vecs(&mut rng, num_entities, dim, scale),
            init_vecs(&mut rng, num_relations, dim, scale),
            dim,
        )
    }

    /// Builds a model from explicit embeddings, using the L2 distance.
    ///
    /// Equivalent to [`TransE::from_vecs_with_norm`] with `norm = 2`.
    ///
    /// # Panics
    ///
    /// If any entity or relation vector does not have length `dim`.
    pub fn from_vecs(entities: Vec<Vec<f32>>, relations: Vec<Vec<f32>>, dim: usize) -> Self {
        Self::from_vecs_with_norm(entities, relations, dim, 2)
    }

    /// Builds a model from explicit embeddings and a distance selector.
    ///
    /// `norm == 1` selects the L1 (Manhattan) distance. Every other value,
    /// conventionally `2`, selects the L2 (Euclidean) distance.
    ///
    /// # Panics
    ///
    /// If any entity or relation vector does not have length `dim`.
    pub fn from_vecs_with_norm(
        entities: Vec<Vec<f32>>,
        relations: Vec<Vec<f32>>,
        dim: usize,
        norm: u32,
    ) -> Self {
        assert_dims(&entities, dim, "entity");
        assert_dims(&relations, dim, "relation");
        Self {
            entities: flatten(&entities),
            relations: flatten(&relations),
            dim,
            norm,
        }
    }

    /// Entity embeddings as one row-major buffer of `dim` floats per entity.
    pub fn entities_flat(&self) -> &[f32] {
        &self.entities
    }

    /// Relation embeddings as one row-major buffer of `dim` floats per relation.
    pub fn relations_flat(&self) -> &[f32] {
        &self.relations
    }

    /// Entity embeddings copied into one `Vec` per entity.
    pub fn entity_vecs(&self) -> Vec<Vec<f32>> {
        unflatten(&self.entities, self.dim)
    }

    /// Relation embeddings copied into one `Vec` per relation.
    pub fn relation_vecs(&self) -> Vec<Vec<f32>> {
        unflatten(&self.relations, self.dim)
    }

    /// Embedding dimension.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// The stored distance selector (`1` = L1, anything else = L2).
    pub fn norm(&self) -> u32 {
        self.norm
    }

    /// Distance `‖h + r − t‖` for the triple. The result is `0` exactly when
    /// `h + r == t` coordinate-wise.
    ///
    /// # Panics
    ///
    /// If `head`, `relation`, or `tail` is out of range.
    pub fn score_triple(&self, head: usize, relation: usize, tail: usize) -> f32 {
        let h = row(&self.entities, head, self.dim);
        let r = row(&self.relations, relation, self.dim);
        let t = row(&self.entities, tail, self.dim);
        // Per-coordinate residual (h + r) - t, accumulated in f64 rather than f32.
        let residuals = h
            .iter()
            .zip(r)
            .zip(t)
            .map(|((&hi, &ri), &ti)| hi as f64 + ri as f64 - ti as f64);
        let dist = match self.norm {
            1 => residuals.fold(0.0_f64, |acc, d| acc + d.abs()),
            _ => residuals.fold(0.0_f64, |acc, d| acc + d * d).sqrt(),
        };
        dist as f32
    }
}
impl Scorer for TransE {
    fn score(&self, head: usize, relation: usize, tail: usize) -> f32 {
        self.score_triple(head, relation, tail)
    }

    fn num_entities(&self) -> usize {
        self.entities.len() / self.dim
    }

    // Batched tail scoring: `h + r` is formed once and every candidate tail is
    // measured against it.
    fn score_all_tails(&self, head: usize, relation: usize) -> Vec<f32> {
        let dim = self.dim;
        let h = row(&self.entities, head, dim);
        let r = row(&self.relations, relation, dim);
        let count = self.num_entities();
        let translated: Vec<f64> = h
            .iter()
            .zip(r)
            .map(|(&hv, &rv)| hv as f64 + rv as f64)
            .collect();
        (0..count)
            .map(|candidate| {
                let t = row(&self.entities, candidate, dim);
                let residuals = translated.iter().zip(t).map(|(&a, &tv)| a - tv as f64);
                transe_distance(residuals, self.norm)
            })
            .collect()
    }

    // Batched head scoring: `r - t` is formed once, and each candidate's
    // residual is `h + (r - t)`.
    fn score_all_heads(&self, relation: usize, tail: usize) -> Vec<f32> {
        let dim = self.dim;
        let r = row(&self.relations, relation, dim);
        let t = row(&self.entities, tail, dim);
        let count = self.num_entities();
        let offset: Vec<f64> = r
            .iter()
            .zip(t)
            .map(|(&rv, &tv)| rv as f64 - tv as f64)
            .collect();
        (0..count)
            .map(|candidate| {
                let h = row(&self.entities, candidate, dim);
                let residuals = h.iter().zip(&offset).map(|(&hv, &o)| hv as f64 + o);
                transe_distance(residuals, self.norm)
            })
            .collect()
    }
}

/// Folds TransE residuals into a distance: the L1 norm when `norm == 1`,
/// otherwise the L2 norm. Accumulates in `f64` and narrows to `f32` at the end.
fn transe_distance(residuals: impl Iterator<Item = f64>, norm: u32) -> f32 {
    if norm == 1 {
        residuals.fold(0.0_f64, |acc, d| acc + d.abs()) as f32
    } else {
        residuals.fold(0.0_f64, |acc, d| acc + d * d).sqrt() as f32
    }
}
/// RotatE model: each relation rotates the head embedding coordinate-wise in
/// the complex plane, and a triple is scored by the Euclidean distance
/// `‖h ∘ r − t‖` (lower is more plausible).
#[derive(Debug, Clone)]
pub struct RotatE {
    /// Entity embeddings, `2 * dim` floats per entity: the `dim` real parts
    /// followed by the `dim` imaginary parts.
    entities: Vec<f32>,
    /// Rotation angles in radians, `dim` per relation.
    relation_angles: Vec<f32>,
    /// Complex embedding dimension (number of complex coordinates).
    dim: usize,
    /// Hyperparameter supplied at construction. `new` uses it to scale the
    /// initial entity embeddings; the scoring code does not read it.
    gamma: f32,
}
impl RotatE {
    /// Creates a randomly initialised model.
    ///
    /// Entity coordinates (real and imaginary) are drawn uniformly from
    /// `-scale..scale` with `scale = gamma / sqrt(dim)`. Every relation angle
    /// is drawn uniformly from `-π..π`.
    #[cfg(feature = "rand")]
    pub fn new(num_entities: usize, num_relations: usize, dim: usize, gamma: f32) -> Self {
        use rand::Rng;
        use std::f32::consts::PI;
        let mut rng = rand::rng();
        let entity_scale = gamma / (dim as f32).sqrt();
        let entities = init_vecs(&mut rng, num_entities, dim * 2, entity_scale);
        let relation_angles: Vec<Vec<f32>> = (0..num_relations)
            .map(|_| (0..dim).map(|_| rng.random_range(-PI..PI)).collect())
            .collect();
        // Go through the same validated constructor as every other model.
        Self::from_vecs(entities, relation_angles, dim, gamma)
    }

    /// Builds a model from explicit embeddings.
    ///
    /// Each entity vector holds `dim` real parts followed by `dim` imaginary
    /// parts. Each relation vector holds `dim` rotation angles in radians.
    ///
    /// # Panics
    ///
    /// If an entity vector's length is not `2 * dim`, or a relation vector's
    /// length is not `dim`.
    pub fn from_vecs(
        entities: Vec<Vec<f32>>,
        relation_angles: Vec<Vec<f32>>,
        dim: usize,
        gamma: f32,
    ) -> Self {
        assert_dims(&entities, dim * 2, "entity (re+im)");
        assert_dims(&relation_angles, dim, "relation angle");
        Self {
            entities: flatten(&entities),
            relation_angles: flatten(&relation_angles),
            dim,
            gamma,
        }
    }

    /// Entity embeddings as one flat buffer of `2 * dim` floats per entity
    /// (real parts, then imaginary parts).
    pub fn entities_flat(&self) -> &[f32] {
        &self.entities
    }

    /// Relation angles as one flat buffer of `dim` radians per relation.
    pub fn relation_angles_flat(&self) -> &[f32] {
        &self.relation_angles
    }

    /// Entity embeddings copied into one `Vec` (length `2 * dim`) per entity.
    pub fn entity_vecs(&self) -> Vec<Vec<f32>> {
        unflatten(&self.entities, self.dim * 2)
    }

    /// Relation angles copied into one `Vec` (length `dim`) per relation.
    pub fn relation_angle_vecs(&self) -> Vec<Vec<f32>> {
        unflatten(&self.relation_angles, self.dim)
    }

    /// Complex embedding dimension.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// The `gamma` hyperparameter given at construction.
    pub fn gamma(&self) -> f32 {
        self.gamma
    }

    /// Euclidean distance `‖h ∘ r − t‖`, where `r_i = cos θ_i + i·sin θ_i`.
    ///
    /// # Panics
    ///
    /// If `head`, `relation`, or `tail` is out of range.
    pub fn score_triple(&self, head: usize, relation: usize, tail: usize) -> f32 {
        let dim = self.dim;
        let h = row(&self.entities, head, dim * 2);
        let angles = row(&self.relation_angles, relation, dim);
        let t = row(&self.entities, tail, dim * 2);
        let (h_re, h_im) = h.split_at(dim);
        let (t_re, t_im) = t.split_at(dim);
        let dist_sq = (0..dim).fold(0.0_f64, |acc, i| {
            let (sin, cos) = (angles[i] as f64).sin_cos();
            let (hr, hi) = (h_re[i] as f64, h_im[i] as f64);
            // Rotate h_i by the unit complex number cos θ + i·sin θ.
            let d_re = (hr * cos - hi * sin) - t_re[i] as f64;
            let d_im = (hr * sin + hi * cos) - t_im[i] as f64;
            acc + (d_re * d_re + d_im * d_im)
        });
        dist_sq.sqrt() as f32
    }
}
impl Scorer for RotatE {
    fn score(&self, head: usize, relation: usize, tail: usize) -> f32 {
        self.score_triple(head, relation, tail)
    }

    fn num_entities(&self) -> usize {
        self.entities.len() / (self.dim * 2)
    }

    // Batched tail scoring: rotate the head once, then compare h ∘ r with
    // every candidate tail.
    fn score_all_tails(&self, head: usize, relation: usize) -> Vec<f32> {
        let dim = self.dim;
        let stride = dim * 2;
        let h = row(&self.entities, head, stride);
        let angles = row(&self.relation_angles, relation, dim);
        let count = self.num_entities();
        let (rot_re, rot_im): (Vec<f64>, Vec<f64>) = (0..dim)
            .map(|i| {
                let (re, im) = (h[i] as f64, h[dim + i] as f64);
                let (sin, cos) = (angles[i] as f64).sin_cos();
                (re * cos - im * sin, re * sin + im * cos)
            })
            .unzip();
        (0..count)
            .map(|candidate| {
                let t = row(&self.entities, candidate, stride);
                let dist_sq = (0..dim).fold(0.0_f64, |acc, i| {
                    let d_re = rot_re[i] - t[i] as f64;
                    let d_im = rot_im[i] - t[dim + i] as f64;
                    acc + (d_re * d_re + d_im * d_im)
                });
                dist_sq.sqrt() as f32
            })
            .collect()
    }

    // Batched head scoring: since |r_i| = 1, |h∘r − t| = |h − t∘conj(r)|, so
    // the tail is rotated back once and compared with every candidate head.
    fn score_all_heads(&self, relation: usize, tail: usize) -> Vec<f32> {
        let dim = self.dim;
        let stride = dim * 2;
        let angles = row(&self.relation_angles, relation, dim);
        let t = row(&self.entities, tail, stride);
        let count = self.num_entities();
        let (back_re, back_im): (Vec<f64>, Vec<f64>) = (0..dim)
            .map(|i| {
                let (re, im) = (t[i] as f64, t[dim + i] as f64);
                let (sin, cos) = (angles[i] as f64).sin_cos();
                (re * cos + im * sin, im * cos - re * sin)
            })
            .unzip();
        (0..count)
            .map(|candidate| {
                let h = row(&self.entities, candidate, stride);
                let dist_sq = (0..dim).fold(0.0_f64, |acc, i| {
                    let d_re = h[i] as f64 - back_re[i];
                    let d_im = h[dim + i] as f64 - back_im[i];
                    acc + (d_re * d_re + d_im * d_im)
                });
                dist_sq.sqrt() as f32
            })
            .collect()
    }
}
/// ComplEx model: the raw score of a triple is `Re(Σ h_i · r_i · conj(t_i))`
/// over complex embeddings. [`Scorer::score`] negates it so that lower is better.
#[derive(Debug, Clone)]
pub struct ComplEx {
    /// Entity embeddings, `2 * dim` floats per entity: the `dim` real parts
    /// followed by the `dim` imaginary parts.
    entities: Vec<f32>,
    /// Relation embeddings, in the same real-then-imaginary layout.
    relations: Vec<f32>,
    /// Complex embedding dimension (number of complex coordinates).
    dim: usize,
}
impl ComplEx {
    /// Creates a randomly initialised model.
    ///
    /// Every real and imaginary coordinate is drawn uniformly from
    /// `-bound..bound` with `bound = sqrt(6 / dim)`.
    #[cfg(feature = "rand")]
    pub fn new(num_entities: usize, num_relations: usize, dim: usize) -> Self {
        let mut rng = rand::rng();
        let bound = (6.0_f32 / dim as f32).sqrt();
        let entities = init_vecs(&mut rng, num_entities, dim * 2, bound);
        let relations = init_vecs(&mut rng, num_relations, dim * 2, bound);
        Self::from_vecs(entities, relations, dim)
    }

    /// Builds a model from explicit embeddings laid out as `dim` real parts
    /// followed by `dim` imaginary parts.
    ///
    /// # Panics
    ///
    /// If any entity or relation vector's length is not `2 * dim`.
    pub fn from_vecs(entities: Vec<Vec<f32>>, relations: Vec<Vec<f32>>, dim: usize) -> Self {
        let width = dim * 2;
        assert_dims(&entities, width, "entity (re+im)");
        assert_dims(&relations, width, "relation (re+im)");
        let entities = flatten(&entities);
        let relations = flatten(&relations);
        Self {
            entities,
            relations,
            dim,
        }
    }

    /// Entity embeddings as one flat buffer of `2 * dim` floats per entity.
    pub fn entities_flat(&self) -> &[f32] {
        self.entities.as_slice()
    }

    /// Relation embeddings as one flat buffer of `2 * dim` floats per relation.
    pub fn relations_flat(&self) -> &[f32] {
        self.relations.as_slice()
    }

    /// Entity embeddings copied into one `Vec` (length `2 * dim`) per entity.
    pub fn entity_vecs(&self) -> Vec<Vec<f32>> {
        unflatten(&self.entities, self.dim * 2)
    }

    /// Relation embeddings copied into one `Vec` (length `2 * dim`) per relation.
    pub fn relation_vecs(&self) -> Vec<Vec<f32>> {
        unflatten(&self.relations, self.dim * 2)
    }

    /// Complex embedding dimension.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Raw trilinear score `Re(Σ h_i · r_i · conj(t_i))`; higher means the
    /// embeddings agree more. This is not negated.
    ///
    /// # Panics
    ///
    /// If `head`, `relation`, or `tail` is out of range.
    pub fn score_triple(&self, head: usize, relation: usize, tail: usize) -> f32 {
        let dim = self.dim;
        let stride = dim * 2;
        let (h_re, h_im) = row(&self.entities, head, stride).split_at(dim);
        let (r_re, r_im) = row(&self.relations, relation, stride).split_at(dim);
        let (t_re, t_im) = row(&self.entities, tail, stride).split_at(dim);
        let total = (0..dim).fold(0.0_f64, |acc, i| {
            let (hr, hi) = (h_re[i] as f64, h_im[i] as f64);
            let (rr, ri) = (r_re[i] as f64, r_im[i] as f64);
            // Complex product (h ∘ r)_i ...
            let prod_re = hr * rr - hi * ri;
            let prod_im = hr * ri + hi * rr;
            // ... then Re((h ∘ r)_i · conj(t_i)).
            acc + (prod_re * t_re[i] as f64 + prod_im * t_im[i] as f64)
        });
        total as f32
    }
}
impl Scorer for ComplEx {
    fn score(&self, head: usize, relation: usize, tail: usize) -> f32 {
        -self.score_triple(head, relation, tail)
    }

    fn num_entities(&self) -> usize {
        self.entities.len() / (self.dim * 2)
    }

    // Batched tail scoring: the product h ∘ r is shared by every candidate tail.
    fn score_all_tails(&self, head: usize, relation: usize) -> Vec<f32> {
        let dim = self.dim;
        let stride = dim * 2;
        let h = row(&self.entities, head, stride);
        let r = row(&self.relations, relation, stride);
        let count = self.num_entities();
        let (prod_re, prod_im): (Vec<f64>, Vec<f64>) = (0..dim)
            .map(|i| {
                let (hr, hi) = (h[i] as f64, h[dim + i] as f64);
                let (rr, ri) = (r[i] as f64, r[dim + i] as f64);
                (hr * rr - hi * ri, hr * ri + hi * rr)
            })
            .unzip();
        (0..count)
            .map(|candidate| {
                let t = row(&self.entities, candidate, stride);
                let total = (0..dim).fold(0.0_f64, |acc, i| {
                    acc + (prod_re[i] * t[i] as f64 + prod_im[i] * t[dim + i] as f64)
                });
                -(total as f32)
            })
            .collect()
    }

    // Batched head scoring: Re(Σ h_i · c_i) with c = r ∘ conj(t) shared by
    // every candidate head.
    fn score_all_heads(&self, relation: usize, tail: usize) -> Vec<f32> {
        let dim = self.dim;
        let stride = dim * 2;
        let r = row(&self.relations, relation, stride);
        let t = row(&self.entities, tail, stride);
        let count = self.num_entities();
        let (c_re, c_im): (Vec<f64>, Vec<f64>) = (0..dim)
            .map(|i| {
                let (rr, ri) = (r[i] as f64, r[dim + i] as f64);
                let (tr, ti) = (t[i] as f64, t[dim + i] as f64);
                (rr * tr + ri * ti, ri * tr - rr * ti)
            })
            .unzip();
        (0..count)
            .map(|candidate| {
                let h = row(&self.entities, candidate, stride);
                let total = (0..dim).fold(0.0_f64, |acc, i| {
                    acc + (h[i] as f64 * c_re[i] - h[dim + i] as f64 * c_im[i])
                });
                -(total as f32)
            })
            .collect()
    }
}
/// DistMult model: the raw score of a triple is `Σ h_i · r_i · t_i`, which is
/// symmetric in head and tail. [`Scorer::score`] negates it so that lower is
/// better.
#[derive(Debug, Clone)]
pub struct DistMult {
    /// Entity embeddings, `dim` consecutive floats per entity.
    entities: Vec<f32>,
    /// Relation embeddings, `dim` consecutive floats per relation.
    relations: Vec<f32>,
    /// Embedding dimension.
    dim: usize,
}
impl DistMult {
    /// Creates a randomly initialised model.
    ///
    /// Every coordinate is drawn uniformly from `-bound..bound` with
    /// `bound = sqrt(6 / dim)`.
    #[cfg(feature = "rand")]
    pub fn new(num_entities: usize, num_relations: usize, dim: usize) -> Self {
        let mut rng = rand::rng();
        let bound = (6.0_f32 / dim as f32).sqrt();
        let entities = init_vecs(&mut rng, num_entities, dim, bound);
        let relations = init_vecs(&mut rng, num_relations, dim, bound);
        Self::from_vecs(entities, relations, dim)
    }

    /// Builds a model from explicit embeddings.
    ///
    /// # Panics
    ///
    /// If any entity or relation vector does not have length `dim`.
    pub fn from_vecs(entities: Vec<Vec<f32>>, relations: Vec<Vec<f32>>, dim: usize) -> Self {
        assert_dims(&entities, dim, "entity");
        assert_dims(&relations, dim, "relation");
        let entities = flatten(&entities);
        let relations = flatten(&relations);
        Self {
            entities,
            relations,
            dim,
        }
    }

    /// Entity embeddings as one row-major buffer of `dim` floats per entity.
    pub fn entities_flat(&self) -> &[f32] {
        self.entities.as_slice()
    }

    /// Relation embeddings as one row-major buffer of `dim` floats per relation.
    pub fn relations_flat(&self) -> &[f32] {
        self.relations.as_slice()
    }

    /// Entity embeddings copied into one `Vec` per entity.
    pub fn entity_vecs(&self) -> Vec<Vec<f32>> {
        unflatten(&self.entities, self.dim)
    }

    /// Relation embeddings copied into one `Vec` per relation.
    pub fn relation_vecs(&self) -> Vec<Vec<f32>> {
        unflatten(&self.relations, self.dim)
    }

    /// Embedding dimension.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Raw trilinear score `Σ h_i · r_i · t_i`, accumulated in `f64`. This is
    /// not negated.
    ///
    /// # Panics
    ///
    /// If `head`, `relation`, or `tail` is out of range.
    pub fn score_triple(&self, head: usize, relation: usize, tail: usize) -> f32 {
        let dim = self.dim;
        let (h, r, t) = (
            row(&self.entities, head, dim),
            row(&self.relations, relation, dim),
            row(&self.entities, tail, dim),
        );
        h.iter()
            .zip(r)
            .zip(t)
            .fold(0.0_f64, |acc, ((&a, &b), &c)| {
                acc + a as f64 * b as f64 * c as f64
            }) as f32
    }
}
impl Scorer for DistMult {
    fn score(&self, head: usize, relation: usize, tail: usize) -> f32 {
        -self.score_triple(head, relation, tail)
    }

    fn num_entities(&self) -> usize {
        self.entities.len() / self.dim
    }

    // Batched tail scoring: the element-wise product h ∘ r is shared by every
    // candidate tail.
    fn score_all_tails(&self, head: usize, relation: usize) -> Vec<f32> {
        let dim = self.dim;
        let h = row(&self.entities, head, dim);
        let r = row(&self.relations, relation, dim);
        let count = self.num_entities();
        let weights: Vec<f64> = h
            .iter()
            .zip(r)
            .map(|(&a, &b)| a as f64 * b as f64)
            .collect();
        (0..count)
            .map(|candidate| {
                let t = row(&self.entities, candidate, dim);
                let total = weights
                    .iter()
                    .zip(t)
                    .fold(0.0_f64, |acc, (&w, &x)| acc + w * x as f64);
                -(total as f32)
            })
            .collect()
    }

    // Batched head scoring: the element-wise product r ∘ t is shared by every
    // candidate head.
    fn score_all_heads(&self, relation: usize, tail: usize) -> Vec<f32> {
        let dim = self.dim;
        let r = row(&self.relations, relation, dim);
        let t = row(&self.entities, tail, dim);
        let count = self.num_entities();
        let weights: Vec<f64> = r
            .iter()
            .zip(t)
            .map(|(&a, &b)| a as f64 * b as f64)
            .collect();
        (0..count)
            .map(|candidate| {
                let h = row(&self.entities, candidate, dim);
                let total = weights
                    .iter()
                    .zip(h)
                    .fold(0.0_f64, |acc, (&w, &x)| acc + w * x as f64);
                -(total as f32)
            })
            .collect()
    }
}
/// Borrows row `i` of a row-major flat buffer whose rows are `stride` long.
///
/// # Panics
///
/// If the row would extend past the end of `data`.
#[inline]
fn row(data: &[f32], i: usize, stride: usize) -> &[f32] {
    let start = i * stride;
    let end = (i + 1) * stride;
    &data[start..end]
}
/// Concatenates rows, in order, into one row-major flat buffer.
fn flatten(vecs: &[Vec<f32>]) -> Vec<f32> {
    // Standard-library equivalent of the manual reserve + extend_from_slice loop.
    vecs.concat()
}
/// Splits a row-major flat buffer into owned rows of `stride` floats each.
///
/// Trailing values that do not fill a whole row are dropped.
///
/// # Panics
///
/// If `stride` is zero.
fn unflatten(flat: &[f32], stride: usize) -> Vec<Vec<f32>> {
    let chunks = flat.chunks_exact(stride);
    let mut rows = Vec::with_capacity(chunks.len());
    for chunk in chunks {
        rows.push(chunk.to_vec());
    }
    rows
}
/// Asserts that every vector in `vecs` has length `expected`.
///
/// # Panics
///
/// On the first mismatch. The message names `label` and the offending index,
/// e.g. `entity embedding 0 has length 3, expected 2`.
fn assert_dims(vecs: &[Vec<f32>], expected: usize, label: &str) {
    vecs.iter().enumerate().for_each(|(i, v)| {
        assert_eq!(
            v.len(),
            expected,
            "{label} embedding {i} has length {}, expected {expected}",
            v.len()
        );
    });
}
/// Draws `count` vectors of length `len`, each coordinate sampled uniformly
/// from `-scale..scale`. Coordinates are drawn row by row, in order.
#[cfg(feature = "rand")]
fn init_vecs(rng: &mut impl rand::Rng, count: usize, len: usize, scale: f32) -> Vec<Vec<f32>> {
    let mut out = Vec::with_capacity(count);
    for _ in 0..count {
        let sample: Vec<f32> = (0..len).map(|_| rng.random_range(-scale..scale)).collect();
        out.push(sample);
    }
    out
}
/// Ensemble of [`Scorer`]s over the same entity set; its scores are the
/// arithmetic mean of the members' scores.
pub struct EnsembledScorer {
    /// Member models. `new` guarantees this is non-empty and that all members
    /// report the same `num_entities()`.
    models: Vec<Box<dyn Scorer>>,
}

impl EnsembledScorer {
    /// Builds an ensemble from `models`.
    ///
    /// # Panics
    ///
    /// If `models` is empty, or if any model reports a different
    /// `num_entities()` than the first one.
    pub fn new(models: Vec<Box<dyn Scorer>>) -> Self {
        assert!(!models.is_empty(), "ensemble requires at least one model");
        let expected_entities = models[0].num_entities();
        models
            .iter()
            .enumerate()
            .skip(1)
            .for_each(|(idx, member)| {
                assert_eq!(
                    member.num_entities(),
                    expected_entities,
                    "model {idx} has {} entities, expected {expected_entities}",
                    member.num_entities()
                );
            });
        Self { models }
    }
}
impl Scorer for EnsembledScorer {
    fn score(&self, head: usize, relation: usize, tail: usize) -> f32 {
        let total: f32 = self
            .models
            .iter()
            .map(|member| member.score(head, relation, tail))
            .sum();
        total / self.models.len() as f32
    }

    fn num_entities(&self) -> usize {
        self.models[0].num_entities()
    }

    fn score_all_tails(&self, head: usize, relation: usize) -> Vec<f32> {
        mean_of_model_scores(&self.models, self.num_entities(), |member| {
            member.score_all_tails(head, relation)
        })
    }

    fn score_all_heads(&self, relation: usize, tail: usize) -> Vec<f32> {
        mean_of_model_scores(&self.models, self.num_entities(), |member| {
            member.score_all_heads(relation, tail)
        })
    }
}

/// Element-wise mean of `per_model(member)` across every ensemble member,
/// returned as a vector of length `len`.
///
/// Members are summed in order and the sums are then divided by the member
/// count. Panics if any member returns more than `len` scores.
fn mean_of_model_scores(
    models: &[Box<dyn Scorer>],
    len: usize,
    per_model: impl Fn(&dyn Scorer) -> Vec<f32>,
) -> Vec<f32> {
    let divisor = models.len() as f32;
    let mut acc = vec![0.0_f32; len];
    for member in models {
        for (slot, value) in per_model(member.as_ref()).into_iter().enumerate() {
            acc[slot] += value;
        }
    }
    acc.iter_mut().for_each(|v| *v /= divisor);
    acc
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Real-valued fixture: 3 entities of dimension 2.
    fn real_entities() -> Vec<Vec<f32>> {
        vec![vec![0.5, -1.0], vec![2.0, 0.25], vec![-1.5, 3.0]]
    }

    /// Real-valued fixture: 2 relations of dimension 2.
    fn real_relations() -> Vec<Vec<f32>> {
        vec![vec![1.0, -0.5], vec![0.0, 2.0]]
    }

    /// Complex fixture: 3 entities of complex dimension 2, laid out as
    /// `[re_0, re_1, im_0, im_1]`.
    fn complex_entities() -> Vec<Vec<f32>> {
        vec![
            vec![0.5, -1.0, 0.3, 0.7],
            vec![2.0, 0.25, -0.4, 1.1],
            vec![-1.5, 3.0, 0.9, -0.2],
        ]
    }

    /// Checks every batched tail/head score against the point-wise `score`.
    fn assert_batch_matches_pointwise(model: &dyn Scorer, num_relations: usize) {
        let n = model.num_entities();
        for rel in 0..num_relations {
            for fixed in 0..n {
                let tails = model.score_all_tails(fixed, rel);
                let heads = model.score_all_heads(rel, fixed);
                assert_eq!(tails.len(), n);
                assert_eq!(heads.len(), n);
                for other in 0..n {
                    let want_tail = model.score(fixed, rel, other);
                    let want_head = model.score(other, rel, fixed);
                    assert!(
                        (tails[other] - want_tail).abs() <= 1e-4 * (1.0 + want_tail.abs()),
                        "tail ({fixed}, {rel}, {other}): batch {} vs score {want_tail}",
                        tails[other]
                    );
                    assert!(
                        (heads[other] - want_head).abs() <= 1e-4 * (1.0 + want_head.abs()),
                        "head ({other}, {rel}, {fixed}): batch {} vs score {want_head}",
                        heads[other]
                    );
                }
            }
        }
    }

    #[test]
    fn transe_manual_score() {
        let model = TransE::from_vecs(
            vec![vec![1.0, 0.0], vec![1.0, 1.0]],
            vec![vec![0.0, 1.0]],
            2,
        );
        let score = model.score_triple(0, 0, 1);
        assert!((score - 0.0).abs() < 1e-6, "expected 0, got {score}");
    }

    #[test]
    fn transe_manual_nonzero() {
        let model = TransE::from_vecs(
            vec![vec![3.0, 0.0], vec![0.0, 4.0]],
            vec![vec![0.0, 0.0]],
            2,
        );
        let score = model.score_triple(0, 0, 1);
        assert!((score - 5.0).abs() < 1e-5, "expected 5, got {score}");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn transe_scorer_trait() {
        let model = TransE::new(10, 3, 8);
        let scorer: &dyn Scorer = &model;
        assert_eq!(scorer.num_entities(), 10);
        let s = scorer.score(0, 0, 1);
        assert!(s.is_finite());
        assert!(s >= 0.0);
    }

    #[test]
    #[should_panic(expected = "entity embedding 0 has length 3, expected 2")]
    fn transe_rejects_bad_dims() {
        TransE::from_vecs(vec![vec![1.0, 2.0, 3.0]], vec![vec![1.0, 2.0]], 2);
    }

    #[test]
    fn rotate_identity_rotation() {
        let model = RotatE::from_vecs(
            vec![vec![1.0, 0.0], vec![1.0, 0.0]],
            vec![vec![0.0]], // rotation by 0 rad
            1,
            12.0,
        );
        let score = model.score_triple(0, 0, 1);
        assert!((score - 0.0).abs() < 1e-6, "expected 0, got {score}");
    }

    #[test]
    fn rotate_90_degrees() {
        use std::f32::consts::FRAC_PI_2;
        let model = RotatE::from_vecs(
            vec![
                vec![1.0, 0.0], // h = 1 + 0i
                vec![0.0, 1.0], // t = 0 + 1i
            ],
            vec![vec![FRAC_PI_2]], // rotate by π/2: 1 → i
            1,
            12.0,
        );
        let score = model.score_triple(0, 0, 1);
        assert!(score < 1e-5, "expected ~0, got {score}");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn rotate_scorer_trait() {
        let model = RotatE::new(10, 3, 8, 12.0);
        let scorer: &dyn Scorer = &model;
        assert_eq!(scorer.num_entities(), 10);
        let s = scorer.score(0, 0, 1);
        assert!(s.is_finite());
        assert!(s >= 0.0);
    }

    #[test]
    fn rotate_contiguous_layout_dim2() {
        use std::f32::consts::FRAC_PI_2;
        let model = RotatE::from_vecs(
            vec![
                vec![1.0, 0.0, 0.0, 0.0], // h = (1 + 0i, 0 + 0i)
                vec![0.0, 0.0, 1.0, 0.0], // t = (0 + 1i, 0 + 0i)
            ],
            vec![vec![FRAC_PI_2, 0.0]],
            2,
            12.0,
        );
        let score = model.score_triple(0, 0, 1);
        assert!(score < 1e-5, "expected ~0, got {score}");
    }

    #[test]
    fn complex_manual_score() {
        let model = ComplEx::from_vecs(vec![vec![1.0, 0.0]], vec![vec![1.0, 0.0]], 1);
        let score = model.score_triple(0, 0, 0);
        assert!((score - 1.0).abs() < 1e-6, "expected 1.0, got {score}");
    }

    #[test]
    fn complex_imaginary_parts() {
        let model = ComplEx::from_vecs(
            vec![
                vec![0.0, 1.0], // h = i
                vec![1.0, 0.0], // t = 1
            ],
            vec![vec![0.0, 1.0]], // r = i, so Re(i · i · conj(1)) = -1
            1,
        );
        let score = model.score_triple(0, 0, 1);
        assert!((score - (-1.0)).abs() < 1e-6, "expected -1.0, got {score}");
    }

    #[test]
    fn complex_scorer_negates() {
        let model = ComplEx::from_vecs(vec![vec![1.0, 0.0]], vec![vec![1.0, 0.0]], 1);
        let raw = model.score_triple(0, 0, 0);
        let via_scorer = model.score(0, 0, 0);
        assert!((via_scorer - (-raw)).abs() < 1e-6);
    }

    #[test]
    fn distmult_manual_score() {
        let model = DistMult::from_vecs(
            vec![vec![2.0, 3.0], vec![4.0, 5.0]],
            vec![vec![1.0, -1.0]],
            2,
        );
        let score = model.score_triple(0, 0, 1);
        assert!((score - (-7.0)).abs() < 1e-5, "expected -7.0, got {score}");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn distmult_symmetric() {
        let model = DistMult::new(10, 3, 16);
        let s1 = model.score_triple(0, 0, 1);
        let s2 = model.score_triple(1, 0, 0);
        assert!(
            (s1 - s2).abs() < 1e-5,
            "DistMult should be symmetric: {s1} vs {s2}"
        );
    }

    #[test]
    fn distmult_scorer_negates() {
        let model = DistMult::from_vecs(vec![vec![1.0], vec![2.0]], vec![vec![3.0]], 1);
        let raw = model.score_triple(0, 0, 1);
        let via_scorer = model.score(0, 0, 1);
        assert!((via_scorer - (-raw)).abs() < 1e-6);
    }

    #[cfg(feature = "rand")]
    #[test]
    fn score_all_tails_length() {
        let model = TransE::new(10, 3, 8);
        let scores = model.score_all_tails(0, 0);
        assert_eq!(scores.len(), 10);
        assert!(scores.iter().all(|s| s.is_finite()));
    }

    #[cfg(feature = "rand")]
    #[test]
    fn top_k_tails_sorted() {
        let model = TransE::new(20, 3, 8);
        let top = model.top_k_tails(0, 0, 5);
        assert_eq!(top.len(), 5);
        for w in top.windows(2) {
            assert!(w[0].1 <= w[1].1, "top_k should be sorted ascending");
        }
    }

    #[cfg(feature = "rand")]
    #[test]
    fn top_k_heads_sorted() {
        let model = TransE::new(20, 3, 8);
        let top = model.top_k_heads(0, 0, 5);
        assert_eq!(top.len(), 5);
        for w in top.windows(2) {
            assert!(w[0].1 <= w[1].1);
        }
    }

    #[test]
    fn top_k_handles_zero_and_oversized_k() {
        let model = DistMult::from_vecs(real_entities(), real_relations(), 2);
        assert!(model.top_k_tails(0, 0, 0).is_empty());
        let all = model.top_k_tails(0, 0, 10);
        assert_eq!(all.len(), 3);
        for w in all.windows(2) {
            assert!(w[0].1 <= w[1].1);
        }
    }

    #[test]
    fn transe_l1_vs_l2_differ() {
        let entities = vec![vec![3.0, 0.0], vec![0.0, 4.0]];
        let relations = vec![vec![0.0, 0.0]];
        let l1 = TransE::from_vecs_with_norm(entities.clone(), relations.clone(), 2, 1);
        let l2 = TransE::from_vecs_with_norm(entities, relations, 2, 2);
        let s1 = l1.score_triple(0, 0, 1);
        let s2 = l2.score_triple(0, 0, 1);
        assert!((s1 - 7.0).abs() < 1e-4, "L1 score should be 7, got {s1}");
        assert!((s2 - 5.0).abs() < 1e-4, "L2 score should be 5, got {s2}");
    }

    #[test]
    fn transe_l1_score_all_tails_consistent() {
        let model = TransE::from_vecs_with_norm(
            vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]],
            vec![vec![0.0, 0.0]],
            2,
            1,
        );
        let all = model.score_all_tails(0, 0);
        for (t, &score) in all.iter().enumerate() {
            let individual = model.score(0, 0, t);
            assert!(
                (score - individual).abs() < 1e-5,
                "L1 score_all_tails[{t}]={score} vs score()={individual}"
            );
        }
    }

    #[test]
    fn transe_batch_matches_pointwise() {
        for norm in [1, 2] {
            let model = TransE::from_vecs_with_norm(real_entities(), real_relations(), 2, norm);
            assert_batch_matches_pointwise(&model, 2);
        }
    }

    #[test]
    fn rotate_batch_matches_pointwise() {
        let model = RotatE::from_vecs(
            complex_entities(),
            vec![vec![0.3, -1.2], vec![2.5, 0.7]],
            2,
            12.0,
        );
        assert_batch_matches_pointwise(&model, 2);
    }

    #[test]
    fn complex_batch_matches_pointwise() {
        let model = ComplEx::from_vecs(
            complex_entities(),
            vec![vec![1.0, -0.5, 0.2, 0.8], vec![0.0, 2.0, -1.0, 0.5]],
            2,
        );
        assert_batch_matches_pointwise(&model, 2);
    }

    #[test]
    fn distmult_batch_matches_pointwise() {
        let model = DistMult::from_vecs(real_entities(), real_relations(), 2);
        assert_batch_matches_pointwise(&model, 2);
    }

    #[test]
    fn ensemble_averages_member_scores() {
        let transe = TransE::from_vecs(real_entities(), real_relations(), 2);
        let distmult = DistMult::from_vecs(real_entities(), real_relations(), 2);
        let expected: Vec<f32> = (0..3)
            .map(|t| (transe.score(0, 1, t) + distmult.score(0, 1, t)) / 2.0)
            .collect();
        let ensemble = EnsembledScorer::new(vec![
            Box::new(transe) as Box<dyn Scorer>,
            Box::new(distmult) as Box<dyn Scorer>,
        ]);
        assert_eq!(ensemble.num_entities(), 3);
        let batch = ensemble.score_all_tails(0, 1);
        for (t, &want) in expected.iter().enumerate() {
            assert!((ensemble.score(0, 1, t) - want).abs() < 1e-5);
            assert!((batch[t] - want).abs() < 1e-5);
        }
        assert_batch_matches_pointwise(&ensemble, 2);
    }

    #[test]
    #[should_panic(expected = "ensemble requires at least one model")]
    fn ensemble_rejects_empty() {
        let _ = EnsembledScorer::new(Vec::new());
    }

    #[cfg(feature = "rand")]
    #[test]
    fn relation_prediction_returns_correct_count() {
        let model = DistMult::new(10, 5, 8);
        let scores = model.score_all_relations(0, 1, 5);
        assert_eq!(scores.len(), 5);
        assert!(scores.iter().all(|s| s.is_finite()));
    }

    #[cfg(feature = "rand")]
    #[test]
    fn top_k_relations_sorted() {
        let model = ComplEx::new(10, 5, 8);
        let top = model.top_k_relations(0, 1, 5, 3);
        assert_eq!(top.len(), 3);
        for w in top.windows(2) {
            assert!(w[0].1 <= w[1].1);
        }
    }
}