use std::collections::HashMap;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{Rng, RngExt};
use crate::gnn::rgcn::KgScorer;
/// Builds a `rows x cols` matrix with Xavier/Glorot-uniform initialisation.
///
/// Each entry is `u * 2a - a`, where `a = sqrt(6 / (rows + cols))` and `u` is
/// one draw of `random::<f64>()`. Assuming that draw lies in `[0, 1)`, entries
/// fall in `[-a, a)`.
fn xavier_uniform(rows: usize, cols: usize) -> Array2<f64> {
    let bound = (6.0_f64 / (rows + cols) as f64).sqrt();
    let mut generator = scirs2_core::random::rng();
    Array2::from_shape_fn((rows, cols), |_| {
        let u: f64 = generator.random();
        u * 2.0 * bound - bound
    })
}
/// Lookup table holding one embedding vector per relation.
#[derive(Debug, Clone)]
pub struct RelationEmbedding {
    /// Embedding matrix; `new` builds it with `n_relations` rows and `dim` columns.
    pub table: Array2<f64>,
    /// Number of relations the table was created for.
    pub n_relations: usize,
    /// Embedding width requested at construction.
    pub dim: usize,
}
impl RelationEmbedding {
    /// Creates an `n_relations x dim` table filled by `xavier_uniform`.
    pub fn new(n_relations: usize, dim: usize) -> Self {
        Self {
            table: xavier_uniform(n_relations, dim),
            n_relations,
            dim,
        }
    }

    /// Returns an owned copy of row `r`, or `None` if `r` is out of range.
    ///
    /// `r` is checked against both `n_relations` and the real row count of
    /// `table`. Both fields are public and can drift apart, and indexing past
    /// the last row would panic instead of returning `None`.
    pub fn get(&self, r: usize) -> Option<Array1<f64>> {
        (r < self.n_relations && r < self.table.nrows()).then(|| self.table.row(r).to_owned())
    }
}
/// RotatE-style scorer. Each entity is a complex vector (separate real and
/// imaginary tables), and each relation is an element-wise rotation by a phase.
#[derive(Debug, Clone)]
pub struct RotatEScoring {
    /// Real parts of the entity embeddings, one row per entity.
    pub entity_re: Array2<f64>,
    /// Imaginary parts of the entity embeddings, one row per entity.
    pub entity_im: Array2<f64>,
    /// Rotation angle for each component, one row per relation.
    pub relation_phase: Array2<f64>,
    pub n_entities: usize,
    pub n_relations: usize,
    pub dim: usize,
}
impl RotatEScoring {
    /// Creates a randomly initialised model.
    ///
    /// - Entity real and imaginary tables come from `xavier_uniform`.
    /// - Each phase is `(2u - 1) * PI` for a draw `u` of `random::<f64>()`,
    ///   so presumably in `[-PI, PI)`.
    pub fn new(n_entities: usize, n_relations: usize, dim: usize) -> Self {
        let mut rng = scirs2_core::random::rng();
        let entity_re = xavier_uniform(n_entities, dim);
        let entity_im = xavier_uniform(n_entities, dim);
        let relation_phase = Array2::from_shape_fn((n_relations, dim), |_| {
            let u: f64 = rng.random();
            (u * 2.0 - 1.0) * std::f64::consts::PI
        });
        Self {
            entity_re,
            entity_im,
            relation_phase,
            n_entities,
            n_relations,
            dim,
        }
    }

    /// Scores the triple `(h, r, t)` as `-sqrt(sum_k |h_k * e^(i*phase_k) - t_k|^2)`.
    ///
    /// The sum runs over the first `dim` components. The score is never
    /// positive; `0.0` means the rotated head lands exactly on the tail.
    ///
    /// # Panics
    /// Panics if `h`, `t` or `r` is not a valid row index, or if `dim`
    /// exceeds the number of stored columns.
    pub fn score_triple(&self, h: usize, r: usize, t: usize) -> f64 {
        let (head_re, head_im) = (self.entity_re.row(h), self.entity_im.row(h));
        let (tail_re, tail_im) = (self.entity_re.row(t), self.entity_im.row(t));
        let theta = self.relation_phase.row(r);
        let squared_distance = (0..self.dim).fold(0.0_f64, |acc, k| {
            let (sin_r, cos_r) = theta[k].sin_cos();
            // Complex multiply head_k by (cos + i*sin), then subtract tail_k.
            let delta_re = head_re[k] * cos_r - head_im[k] * sin_r - tail_re[k];
            let delta_im = head_re[k] * sin_r + head_im[k] * cos_r - tail_im[k];
            acc + (delta_re * delta_re + delta_im * delta_im)
        });
        -squared_distance.sqrt()
    }
}
impl KgScorer for RotatEScoring {
fn score(&self, h: usize, r: usize, t: usize) -> f64 {
self.score_triple(h, r, t)
}
}
/// Edge index for a typed (heterogeneous) graph.
///
/// Edges are grouped two ways: by relation id, and by the
/// `(source node type, destination node type)` pair.
#[derive(Debug, Clone)]
pub struct HeterogeneousAdjacency {
    /// `by_relation[r]` lists every `(src, dst)` edge labelled with relation `r`.
    pub by_relation: Vec<Vec<(usize, usize)>>,
    /// Maps a `(src_type, dst_type)` pair to the source node of each matching
    /// edge. There is one entry per edge, so a node can appear several times.
    pub by_node_type: HashMap<(usize, usize), Vec<usize>>,
    pub n_relations: usize,
    /// Stored as given; this constructor does not validate `node_types` against it.
    pub n_node_types: usize,
}
impl HeterogeneousAdjacency {
    /// Builds both indices from `(src, relation, dst)` triples.
    ///
    /// - An edge whose relation is `>= n_relations` is skipped entirely, so
    ///   both indices describe the same edge set.
    /// - A node with no entry in `node_types` is treated as type `0`.
    pub fn from_typed_edges(
        n_relations: usize,
        n_node_types: usize,
        node_types: &[usize],
        typed_edges: &[(usize, usize, usize)],
    ) -> Self {
        let mut by_relation: Vec<Vec<(usize, usize)>> = vec![Vec::new(); n_relations];
        let mut by_node_type: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
        let node_type = |node: usize| node_types.get(node).copied().unwrap_or(0);
        for &(src, rel, dst) in typed_edges {
            // Out-of-range relations used to be dropped here but still recorded
            // in `by_node_type`, so the two indices disagreed. Skip them in both.
            if let Some(bucket) = by_relation.get_mut(rel) {
                bucket.push((src, dst));
                by_node_type
                    .entry((node_type(src), node_type(dst)))
                    .or_default()
                    .push(src);
            }
        }
        Self {
            by_relation,
            by_node_type,
            n_relations,
            n_node_types,
        }
    }

    /// Returns the `(src, dst)` edges of relation `r`, or an empty slice if `r` is unknown.
    pub fn edges_for_relation(&self, r: usize) -> &[(usize, usize)] {
        self.by_relation.get(r).map(Vec::as_slice).unwrap_or(&[])
    }

    /// Returns the source nodes of edges going from `src_type` to `dst_type`,
    /// or an empty slice if there are none.
    pub fn sources_by_type(&self, src_type: usize, dst_type: usize) -> &[usize] {
        self.by_node_type
            .get(&(src_type, dst_type))
            .map(Vec::as_slice)
            .unwrap_or(&[])
    }

    /// Total number of indexed edges, summed over all relations.
    pub fn n_edges(&self) -> usize {
        self.by_relation.iter().map(Vec::len).sum()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relation_embedding_shape() {
        let emb = RelationEmbedding::new(5, 16);
        assert_eq!(emb.table.nrows(), 5);
        assert_eq!(emb.table.ncols(), 16);
    }

    #[test]
    fn test_relation_embedding_get() {
        let emb = RelationEmbedding::new(3, 8);
        assert!(emb.get(0).is_some());
        assert!(emb.get(2).is_some());
        assert!(emb.get(3).is_none());
    }

    #[test]
    fn test_rotate_score_is_finite() {
        let scorer = RotatEScoring::new(5, 3, 8);
        let s = scorer.score_triple(0, 0, 1);
        assert!(s.is_finite(), "RotatE score must be finite");
    }

    #[test]
    fn test_rotate_self_score_is_highest() {
        let mut scorer = RotatEScoring::new(4, 2, 4);
        // With every phase at zero the relation is the identity rotation.
        // A head scored against itself then has distance 0, the best score.
        scorer.relation_phase.fill(0.0);
        let s_same = scorer.score_triple(0, 0, 0);
        let s_diff = scorer.score_triple(0, 0, 1);
        assert!(s_same.is_finite());
        assert!(s_diff.is_finite());
        assert_eq!(s_same, 0.0);
        assert!(s_same >= s_diff, "self score {s_same} < cross score {s_diff}");
    }

    #[test]
    fn test_rotate_scorer_trait_object() {
        let scorer: Box<dyn KgScorer> = Box::new(RotatEScoring::new(3, 2, 4));
        let s = scorer.score(0, 0, 1);
        assert!(s.is_finite());
    }

    #[test]
    fn test_heterogeneous_adjacency_by_relation() {
        let node_types = vec![0usize, 0, 1, 1];
        let edges = vec![(0usize, 0usize, 2usize), (1, 0, 3), (0, 1, 1)];
        let adj = HeterogeneousAdjacency::from_typed_edges(2, 2, &node_types, &edges);
        assert_eq!(adj.by_relation.len(), 2);
        assert_eq!(adj.edges_for_relation(0).len(), 2);
        assert_eq!(adj.edges_for_relation(1).len(), 1);
    }

    #[test]
    fn test_heterogeneous_adjacency_unknown_relation_is_empty() {
        let node_types = vec![0usize, 1];
        let edges = vec![(0usize, 0usize, 1usize)];
        let adj = HeterogeneousAdjacency::from_typed_edges(1, 2, &node_types, &edges);
        assert!(adj.edges_for_relation(7).is_empty());
    }

    #[test]
    fn test_heterogeneous_adjacency_by_node_type() {
        let node_types = vec![0usize, 0, 1, 1];
        let edges = vec![(0usize, 0usize, 2usize), (1, 0, 3)];
        let adj = HeterogeneousAdjacency::from_typed_edges(1, 2, &node_types, &edges);
        let srcs = adj.sources_by_type(0, 1);
        assert_eq!(srcs.len(), 2);
    }

    #[test]
    fn test_heterogeneous_adjacency_unknown_node_defaults_to_type_zero() {
        let node_types = vec![0usize];
        let edges = vec![(0usize, 0usize, 3usize)];
        let adj = HeterogeneousAdjacency::from_typed_edges(1, 1, &node_types, &edges);
        assert_eq!(adj.sources_by_type(0, 0), &[0usize][..]);
    }

    #[test]
    fn test_heterogeneous_adjacency_n_edges() {
        let node_types = vec![0usize; 5];
        let edges: Vec<(usize, usize, usize)> = (0..4).map(|i| (i, 0, i + 1)).collect();
        let adj = HeterogeneousAdjacency::from_typed_edges(1, 1, &node_types, &edges);
        assert_eq!(adj.n_edges(), 4);
    }
}