use crate::error::{RecsysError, RecsysResult};
use crate::handle::LcgRng;
/// Replaces every element of `x` with its hyperbolic tangent, in place.
///
/// An empty slice is left untouched.
fn tanh_vec(x: &mut [f32]) {
    x.iter_mut().for_each(|elem| *elem = elem.tanh());
}
/// Multiplies the row-major `d × d` matrix `w` by the vector `v`.
///
/// Row `r` of the matrix is `w[r * d..(r + 1) * d]`; the result holds one dot
/// product per row, so it always has `d` elements.
///
/// # Panics
/// Panics if `w` holds fewer than `d * d` elements. If `v` is shorter than `d`,
/// each row's dot product only covers the first `v.len()` columns.
fn matvec(w: &[f32], v: &[f32], d: usize) -> Vec<f32> {
    let mut product = Vec::with_capacity(d);
    for row in 0..d {
        let weights = &w[row * d..(row + 1) * d];
        let dot = weights
            .iter()
            .zip(v)
            .map(|(weight, value)| weight * value)
            .sum::<f32>();
        product.push(dot);
    }
    product
}
/// Returns the element-wise sum of `a` and `b`.
///
/// The output is as long as the shorter input; surplus elements of the longer
/// slice are ignored.
fn vec_add(a: &[f32], b: &[f32]) -> Vec<f32> {
    let mut sums = Vec::with_capacity(a.len().min(b.len()));
    for (lhs, rhs) in a.iter().zip(b) {
        sums.push(lhs + rhs);
    }
    sums
}
/// Returns the dot product of `a` and `b`.
///
/// Only the first `min(a.len(), b.len())` element pairs contribute.
fn vec_dot(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs);
    products.sum()
}
/// Overwrites `v` with its softmax.
///
/// The largest element is subtracted before exponentiating so that `exp` never
/// sees a positive argument, and `1e-10` is added to the normaliser before
/// dividing. An empty slice is a no-op.
fn softmax_inplace(v: &mut [f32]) {
    let mut peak = f32::NEG_INFINITY;
    for &value in v.iter() {
        peak = peak.max(value);
    }

    let mut total = 0.0_f32;
    for slot in v.iter_mut() {
        let shifted_exp = (*slot - peak).exp();
        *slot = shifted_exp;
        total += shifted_exp;
    }

    let scale = 1.0 / (total + 1e-10);
    v.iter_mut().for_each(|slot| *slot *= scale);
}
/// Size parameters for a [`Kgat`] model.
///
/// `Kgat::new` rejects a configuration in which any of these fields is zero.
#[derive(Debug, Clone)]
pub struct KgatConfig {
/// Length of every entity and relation embedding vector; each layer's
/// projection matrix is `embed_dim × embed_dim`.
pub embed_dim: usize,
/// Number of entities; valid head/tail indices are `0..n_entities`.
pub n_entities: usize,
/// Number of relation types; valid relation indices are `0..n_relations`.
pub n_relations: usize,
/// Number of propagation layers, i.e. the number of projection matrices.
pub n_layers: usize,
}
/// Attention-based propagation model over `(head, relation, tail)` triples.
///
/// All parameters are stored as flat, row-major `f32` buffers.
pub struct Kgat {
/// Configuration the model was built with.
pub cfg: KgatConfig,
/// Entity embeddings: `n_entities` rows of `embed_dim` values; entity `e`
/// occupies `entity_emb[e * embed_dim..(e + 1) * embed_dim]`.
pub entity_emb: Vec<f32>,
/// Relation embeddings: `n_relations` rows of `embed_dim` values, laid out
/// like `entity_emb`.
pub relation_emb: Vec<f32>,
/// One row-major `embed_dim × embed_dim` projection matrix per layer.
pub layer_w: Vec<Vec<f32>>,
}
impl Kgat {
pub fn new(cfg: KgatConfig, rng: &mut LcgRng) -> RecsysResult<Self> {
if cfg.embed_dim == 0 {
return Err(RecsysError::InvalidEmbeddingDim { d: 0 });
}
if cfg.n_entities == 0 {
return Err(RecsysError::InvalidConfig {
msg: "n_entities must be >= 1".into(),
});
}
if cfg.n_relations == 0 {
return Err(RecsysError::InvalidConfig {
msg: "n_relations must be >= 1".into(),
});
}
if cfg.n_layers == 0 {
return Err(RecsysError::InvalidConfig {
msg: "n_layers must be >= 1".into(),
});
}
let d = cfg.embed_dim;
let scale = (1.0 / d as f32).sqrt();
let w_scale = (2.0 / d as f32).sqrt();
let entity_emb: Vec<f32> = (0..cfg.n_entities * d)
.map(|_| rng.next_normal() * scale)
.collect();
let relation_emb: Vec<f32> = (0..cfg.n_relations * d)
.map(|_| rng.next_normal() * scale)
.collect();
let layer_w: Vec<Vec<f32>> = (0..cfg.n_layers)
.map(|_| (0..d * d).map(|_| rng.next_normal() * w_scale).collect())
.collect();
Ok(Self {
cfg,
entity_emb,
relation_emb,
layer_w,
})
}
pub fn attention_score(&self, head: usize, relation: usize, tail: usize) -> RecsysResult<f32> {
self.check_triple(head, relation, tail)?;
let w = self.layer_w.first().ok_or(RecsysError::Internal {
msg: "no projection layer".into(),
})?;
let d = self.cfg.embed_dim;
let entity_emb = &self.entity_emb;
let relation_emb = &self.relation_emb;
let e_h = &entity_emb[head * d..(head + 1) * d];
let e_t = &entity_emb[tail * d..(tail + 1) * d];
let e_r = &relation_emb[relation * d..(relation + 1) * d];
let mut left = matvec(w, &vec_add(e_h, e_r), d);
let mut right = matvec(w, e_t, d);
tanh_vec(&mut left);
tanh_vec(&mut right);
Ok(vec_dot(&left, &right))
}
fn layer_score(
&self,
layer: usize,
head: usize,
relation: usize,
tail: usize,
embeddings: &[f32],
) -> RecsysResult<f32> {
let d = self.cfg.embed_dim;
let w = self.layer_w.get(layer).ok_or(RecsysError::Internal {
msg: "layer index out of range".into(),
})?;
let e_h = &embeddings[head * d..(head + 1) * d];
let e_t = &embeddings[tail * d..(tail + 1) * d];
let e_r = &self.relation_emb[relation * d..(relation + 1) * d];
let mut left = matvec(w, &vec_add(e_h, e_r), d);
let mut right = matvec(w, e_t, d);
tanh_vec(&mut left);
tanh_vec(&mut right);
Ok(vec_dot(&left, &right))
}
pub fn propagate(
&self,
embeddings: &[f32],
triples: &[(usize, usize, usize)],
) -> RecsysResult<Vec<f32>> {
let d = self.cfg.embed_dim;
let n = self.cfg.n_entities;
if embeddings.len() != n * d {
return Err(RecsysError::DimensionMismatch {
expected: n * d,
got: embeddings.len(),
});
}
for &(h, r, t) in triples {
self.check_triple(h, r, t)?;
}
self.propagate_layer(0, embeddings, triples)
}
fn propagate_layer(
&self,
layer: usize,
embeddings: &[f32],
triples: &[(usize, usize, usize)],
) -> RecsysResult<Vec<f32>> {
let d = self.cfg.embed_dim;
let n = self.cfg.n_entities;
let mut head_edges: Vec<Vec<(usize, usize)>> = vec![Vec::new(); n];
for &(h, r, t) in triples {
head_edges[h].push((r, t));
}
let mut out = vec![0.0_f32; n * d];
for (h, edges) in head_edges.iter().enumerate() {
if edges.is_empty() {
continue;
}
let mut scores = Vec::with_capacity(edges.len());
for &(r, t) in edges {
scores.push(self.layer_score(layer, h, r, t, embeddings)?);
}
softmax_inplace(&mut scores);
for (idx, &(_, t)) in edges.iter().enumerate() {
let a = scores[idx];
let e_t = &embeddings[t * d..(t + 1) * d];
for k in 0..d {
out[h * d + k] += a * e_t[k];
}
}
}
Ok(out)
}
pub fn forward(&self, triples: &[(usize, usize, usize)]) -> RecsysResult<Vec<f32>> {
for &(h, r, t) in triples {
self.check_triple(h, r, t)?;
}
let d = self.cfg.embed_dim;
let n = self.cfg.n_entities;
let mut concatenated: Vec<f32> = Vec::with_capacity(n * d * (self.cfg.n_layers + 1));
concatenated.extend_from_slice(&self.entity_emb);
let mut current = self.entity_emb.clone();
for layer in 0..self.cfg.n_layers {
let next = self.propagate_layer(layer, ¤t, triples)?;
concatenated.extend_from_slice(&next);
current = next;
}
let total_d = d * (self.cfg.n_layers + 1);
let mut entity_major = vec![0.0_f32; n * total_d];
for layer in 0..(self.cfg.n_layers + 1) {
for e in 0..n {
let src = layer * n * d + e * d;
let dst = e * total_d + layer * d;
entity_major[dst..dst + d].copy_from_slice(&concatenated[src..src + d]);
}
}
Ok(entity_major)
}
pub fn score(&self, user: usize, item: usize, concatenated: &[f32]) -> RecsysResult<f32> {
let n = self.cfg.n_entities;
let total_d = self.cfg.embed_dim * (self.cfg.n_layers + 1);
if concatenated.len() != n * total_d {
return Err(RecsysError::DimensionMismatch {
expected: n * total_d,
got: concatenated.len(),
});
}
if user >= n {
return Err(RecsysError::ItemOutOfBounds { idx: user, n });
}
if item >= n {
return Err(RecsysError::ItemOutOfBounds { idx: item, n });
}
let u = &concatenated[user * total_d..(user + 1) * total_d];
let i = &concatenated[item * total_d..(item + 1) * total_d];
Ok(vec_dot(u, i))
}
#[must_use]
pub fn n_params(&self) -> usize {
self.entity_emb.len()
+ self.relation_emb.len()
+ self.layer_w.iter().map(Vec::len).sum::<usize>()
}
fn check_triple(&self, head: usize, relation: usize, tail: usize) -> RecsysResult<()> {
if head >= self.cfg.n_entities {
return Err(RecsysError::ItemOutOfBounds {
idx: head,
n: self.cfg.n_entities,
});
}
if tail >= self.cfg.n_entities {
return Err(RecsysError::ItemOutOfBounds {
idx: tail,
n: self.cfg.n_entities,
});
}
if relation >= self.cfg.n_relations {
return Err(RecsysError::ItemOutOfBounds {
idx: relation,
n: self.cfg.n_relations,
});
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Fixed-seed generator shared by most tests.
    fn make_rng() -> LcgRng {
        LcgRng::new(42)
    }

    /// Small configuration: 6 entities, 3 relations, 2 layers, 4-wide embeddings.
    fn default_cfg() -> KgatConfig {
        KgatConfig {
            embed_dim: 4,
            n_entities: 6,
            n_relations: 3,
            n_layers: 2,
        }
    }

    /// A valid triple yields a finite attention score.
    #[test]
    fn attention_score_is_finite() {
        let mut rng = make_rng();
        let model = Kgat::new(default_cfg(), &mut rng).expect("value should be present");
        let s = model
            .attention_score(0, 0, 1)
            .expect("attention_score should succeed");
        assert!(s.is_finite(), "score must be finite, got {s}");
    }

    /// `propagate` returns one `embed_dim`-wide row per entity.
    #[test]
    fn propagate_output_length() {
        let mut rng = make_rng();
        let cfg = default_cfg();
        let model = Kgat::new(cfg.clone(), &mut rng).expect("value should be present");
        let triples = vec![(0, 0, 1), (0, 1, 2), (1, 0, 3), (2, 1, 4)];
        let out = model
            .propagate(&model.entity_emb, &triples)
            .expect("propagate should succeed");
        assert_eq!(out.len(), cfg.n_entities * cfg.embed_dim);
    }

    /// An entity that appears in no triple keeps an all-zero output row.
    #[test]
    fn isolated_entity_zero_row() {
        let mut rng = make_rng();
        let cfg = default_cfg();
        let model = Kgat::new(cfg.clone(), &mut rng).expect("value should be present");
        let triples = vec![(0, 0, 1), (0, 1, 2), (1, 0, 3)];
        let out = model
            .propagate(&model.entity_emb, &triples)
            .expect("propagate should succeed");
        let d = cfg.embed_dim;
        let row = &out[5 * d..6 * d];
        for &v in row {
            assert!(v.abs() < 1e-7, "isolated row must be zero, got {v}");
        }
    }

    /// With a single outgoing edge the softmax weight is 1, so the head's row
    /// equals the tail's embedding.
    #[test]
    fn single_triple_aggregation_matches_hand_math() {
        let mut rng = make_rng();
        let cfg = default_cfg();
        let model = Kgat::new(cfg.clone(), &mut rng).expect("value should be present");
        let triples = vec![(0, 0, 1)];
        let out = model
            .propagate(&model.entity_emb, &triples)
            .expect("propagate should succeed");
        let d = cfg.embed_dim;
        let head_row = &out[0..d];
        let tail_row = &model.entity_emb[d..2 * d];
        for k in 0..d {
            assert!(
                (head_row[k] - tail_row[k]).abs() < 1e-5,
                "single-edge aggregation must equal tail (k={k})"
            );
        }
    }

    /// `forward` returns `embed_dim * (n_layers + 1)` values per entity.
    #[test]
    fn forward_output_length() {
        let mut rng = make_rng();
        let cfg = default_cfg();
        let model = Kgat::new(cfg.clone(), &mut rng).expect("value should be present");
        let triples = vec![(0, 0, 1), (0, 1, 2), (1, 0, 3), (2, 1, 4), (3, 2, 5)];
        let out = model.forward(&triples).expect("forward should succeed");
        assert_eq!(
            out.len(),
            cfg.n_entities * cfg.embed_dim * (cfg.n_layers + 1)
        );
    }

    /// Scoring two entities on a `forward` result yields a finite value.
    #[test]
    fn score_returns_finite() {
        let mut rng = make_rng();
        let cfg = default_cfg();
        let model = Kgat::new(cfg.clone(), &mut rng).expect("value should be present");
        let triples = vec![(0, 0, 1), (0, 1, 2), (1, 0, 3), (2, 1, 4)];
        let concat = model.forward(&triples).expect("forward should succeed");
        let s = model.score(0, 1, &concat).expect("score should succeed");
        assert!(s.is_finite(), "score must be finite, got {s}");
    }

    /// Two models built from the same seed produce the same `forward` output.
    #[test]
    fn deterministic_given_seed() {
        let mut rng_a = LcgRng::new(13);
        let mut rng_b = LcgRng::new(13);
        let model_a = Kgat::new(default_cfg(), &mut rng_a).expect("value should be present");
        let model_b = Kgat::new(default_cfg(), &mut rng_b).expect("value should be present");
        let triples = vec![(0, 0, 1), (0, 1, 2), (1, 0, 3)];
        let out_a = model_a.forward(&triples).expect("forward should succeed");
        let out_b = model_b.forward(&triples).expect("forward should succeed");
        assert_eq!(out_a.len(), out_b.len());
        for (a, b) in out_a.iter().zip(out_b.iter()) {
            assert!((a - b).abs() < 1e-6, "same seed must yield same output");
        }
    }

    /// Swapping the relation of a triple changes its attention score.
    #[test]
    fn changing_relation_changes_attention() {
        let mut rng = make_rng();
        let model = Kgat::new(default_cfg(), &mut rng).expect("value should be present");
        let s0 = model
            .attention_score(0, 0, 1)
            .expect("attention_score should succeed");
        let s1 = model
            .attention_score(0, 1, 1)
            .expect("attention_score should succeed");
        assert!(
            (s0 - s1).abs() > 1e-7,
            "different relations should yield different scores (got {s0}, {s1})"
        );
    }

    /// Aggregation flows tail -> head only: the tail of the sole triple has no
    /// outgoing edge and therefore keeps a zero row.
    #[test]
    fn asymmetric_head_tail_direction() {
        let mut rng = make_rng();
        let cfg = default_cfg();
        let model = Kgat::new(cfg.clone(), &mut rng).expect("value should be present");
        let triples = vec![(0, 0, 1)];
        let out = model
            .propagate(&model.entity_emb, &triples)
            .expect("propagate should succeed");
        let d = cfg.embed_dim;
        let tail_row = &out[d..2 * d];
        for &v in tail_row {
            assert!(
                v.abs() < 1e-7,
                "tail row must remain zero (no edge directed at it), got {v}"
            );
        }
    }

    /// An out-of-range head index is rejected.
    #[test]
    fn err_triple_head_out_of_range() {
        let mut rng = make_rng();
        let model = Kgat::new(default_cfg(), &mut rng).expect("value should be present");
        assert!(matches!(
            model.attention_score(999, 0, 1),
            Err(RecsysError::ItemOutOfBounds { .. })
        ));
    }

    /// An out-of-range relation index is rejected.
    #[test]
    fn err_triple_relation_out_of_range() {
        let mut rng = make_rng();
        let model = Kgat::new(default_cfg(), &mut rng).expect("value should be present");
        assert!(matches!(
            model.attention_score(0, 999, 1),
            Err(RecsysError::ItemOutOfBounds { .. })
        ));
    }

    /// An out-of-range tail index is rejected by `propagate`.
    #[test]
    fn err_triple_tail_out_of_range() {
        let mut rng = make_rng();
        let model = Kgat::new(default_cfg(), &mut rng).expect("value should be present");
        let triples = vec![(0, 0, 999)];
        assert!(matches!(
            model.propagate(&model.entity_emb, &triples),
            Err(RecsysError::ItemOutOfBounds { .. })
        ));
    }

    /// `embed_dim == 0` is rejected with `InvalidEmbeddingDim`.
    #[test]
    fn err_embed_dim_zero() {
        let mut rng = make_rng();
        let cfg = KgatConfig {
            embed_dim: 0,
            n_entities: 4,
            n_relations: 2,
            n_layers: 1,
        };
        assert!(matches!(
            Kgat::new(cfg, &mut rng),
            Err(RecsysError::InvalidEmbeddingDim { d: 0 })
        ));
    }

    /// `n_entities == 0` is rejected with `InvalidConfig`.
    #[test]
    fn err_n_entities_zero() {
        let mut rng = make_rng();
        let cfg = KgatConfig {
            embed_dim: 4,
            n_entities: 0,
            n_relations: 2,
            n_layers: 1,
        };
        assert!(matches!(
            Kgat::new(cfg, &mut rng),
            Err(RecsysError::InvalidConfig { .. })
        ));
    }

    /// `n_relations == 0` is rejected with `InvalidConfig`.
    #[test]
    fn err_n_relations_zero() {
        let mut rng = make_rng();
        let cfg = KgatConfig {
            embed_dim: 4,
            n_entities: 4,
            n_relations: 0,
            n_layers: 1,
        };
        assert!(matches!(
            Kgat::new(cfg, &mut rng),
            Err(RecsysError::InvalidConfig { .. })
        ));
    }

    /// `n_layers == 0` is rejected with `InvalidConfig`.
    #[test]
    fn err_n_layers_zero() {
        let mut rng = make_rng();
        let cfg = KgatConfig {
            embed_dim: 4,
            n_entities: 4,
            n_relations: 2,
            n_layers: 0,
        };
        assert!(matches!(
            Kgat::new(cfg, &mut rng),
            Err(RecsysError::InvalidConfig { .. })
        ));
    }

    /// `n_params` equals `n_entities*d + n_relations*d + n_layers*d*d`.
    #[test]
    fn n_params_positive_and_matches_closed_form() {
        let mut rng = make_rng();
        let cfg = default_cfg();
        let model = Kgat::new(cfg.clone(), &mut rng).expect("value should be present");
        let n = model.n_params();
        let d = cfg.embed_dim;
        let expected = cfg.n_entities * d + cfg.n_relations * d + cfg.n_layers * d * d;
        assert!(n > 0, "n_params must be > 0, got {n}");
        assert_eq!(n, expected, "n_params should match closed-form count");
    }

    /// A two-layer and a one-layer model built from the same seed differ in
    /// output width, agree on the blocks they share, and leave the sink
    /// entity's propagated blocks at zero.
    #[test]
    fn two_vs_one_layer_propagation_differ() {
        let mut rng_two = LcgRng::new(7);
        let mut rng_one = LcgRng::new(7);
        let cfg_two = KgatConfig {
            embed_dim: 4,
            n_entities: 5,
            n_relations: 2,
            n_layers: 2,
        };
        let cfg_one = KgatConfig {
            embed_dim: 4,
            n_entities: 5,
            n_relations: 2,
            n_layers: 1,
        };
        let model_two = Kgat::new(cfg_two, &mut rng_two).expect("new should succeed");
        let model_one = Kgat::new(cfg_one, &mut rng_one).expect("new should succeed");
        let triples = vec![(0, 0, 1), (1, 0, 2), (2, 1, 3), (3, 1, 4)];
        let out_two = model_two.forward(&triples).expect("forward should succeed");
        let out_one = model_one.forward(&triples).expect("forward should succeed");
        assert_ne!(
            out_two.len(),
            out_one.len(),
            "different depth → different layout"
        );

        // `forward` is entity-major: each entity owns `d * (n_layers + 1)` values.
        let d = 4_usize;
        let n = 5_usize;
        let total_two = d * 3;
        let total_one = d * 2;

        // Both models draw embeddings and the first projection in the same
        // order from the same seed, so the initial and first-layer blocks of
        // every entity must coincide.
        for e in 0..n {
            let shared_two = &out_two[e * total_two..e * total_two + total_one];
            let shared_one = &out_one[e * total_one..(e + 1) * total_one];
            for (a, b) in shared_two.iter().zip(shared_one.iter()) {
                assert!(
                    (a - b).abs() < 1e-6,
                    "shared blocks must agree for entity {e} (got {a}, {b})"
                );
            }
        }

        // Entity 4 is never a head, so every propagated block of its row is zero.
        let sink = n - 1;
        for &v in &out_two[sink * total_two + d..(sink + 1) * total_two] {
            assert!(v.abs() < 1e-7, "sink blocks must be zero, got {v}");
        }
        for &v in &out_one[sink * total_one + d..(sink + 1) * total_one] {
            assert!(v.abs() < 1e-7, "sink blocks must be zero, got {v}");
        }
    }

    /// When all tails share the same all-ones vector, the attention-weighted
    /// sum is all ones, i.e. the weights sum to one.
    #[test]
    fn softmax_normalisation_sums_to_one() {
        let mut rng = make_rng();
        let cfg = KgatConfig {
            embed_dim: 4,
            n_entities: 5,
            n_relations: 3,
            n_layers: 1,
        };
        let model = Kgat::new(cfg.clone(), &mut rng).expect("value should be present");
        let d = cfg.embed_dim;
        let mut embeddings = model.entity_emb.clone();
        for t in [1usize, 2, 3] {
            for k in 0..d {
                embeddings[t * d + k] = 1.0;
            }
        }
        let triples = vec![(0, 0, 1), (0, 1, 2), (0, 2, 3)];
        let out = model
            .propagate(&embeddings, &triples)
            .expect("propagate should succeed");
        let head_row = &out[0..d];
        for &v in head_row {
            assert!(
                (v - 1.0).abs() < 1e-5,
                "softmax-weighted average of identical ones-vectors must equal 1.0, got {v}"
            );
        }
    }

    /// Using different relations on the same head/tail pairs changes the
    /// head's propagated row.
    #[test]
    fn changing_relation_changes_propagated_row() {
        let mut rng = make_rng();
        let cfg = default_cfg();
        let model = Kgat::new(cfg.clone(), &mut rng).expect("value should be present");
        let d = cfg.embed_dim;
        let triples_a = vec![(0, 0, 1), (0, 0, 2)];
        let triples_b = vec![(0, 1, 1), (0, 2, 2)];
        let out_a = model
            .propagate(&model.entity_emb, &triples_a)
            .expect("propagate should succeed");
        let out_b = model
            .propagate(&model.entity_emb, &triples_b)
            .expect("propagate should succeed");
        let diff: f32 = out_a[0..d]
            .iter()
            .zip(out_b[0..d].iter())
            .map(|(&x, &y)| (x - y).abs())
            .sum();
        assert!(
            diff > 1e-5,
            "different relations should yield different rows (got diff {diff})"
        );
    }

    /// `score` rejects a concatenated buffer of the wrong length.
    #[test]
    fn err_concat_wrong_length_in_score() {
        let mut rng = make_rng();
        let model = Kgat::new(default_cfg(), &mut rng).expect("value should be present");
        let bad = vec![0.0_f32; 3];
        assert!(matches!(
            model.score(0, 1, &bad),
            Err(RecsysError::DimensionMismatch { .. })
        ));
    }

    /// `propagate` rejects an embeddings buffer of the wrong length.
    #[test]
    fn err_propagate_embeddings_wrong_length() {
        let mut rng = make_rng();
        let model = Kgat::new(default_cfg(), &mut rng).expect("value should be present");
        let bad = vec![0.0_f32; 7];
        let triples = vec![(0, 0, 1)];
        assert!(matches!(
            model.propagate(&bad, &triples),
            Err(RecsysError::DimensionMismatch { .. })
        ));
    }
}