use std::collections::HashMap;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{Rng, RngExt};
use crate::error::{GraphError, Result};
/// Samples a `rows x cols` matrix with Xavier/Glorot uniform initialization:
/// entries are drawn uniformly from `[-limit, limit]` where
/// `limit = sqrt(6 / (fan_in + fan_out))`.
fn xavier_uniform(rows: usize, cols: usize) -> Array2<f64> {
    let fan_sum = (rows + cols) as f64;
    let limit = (6.0_f64 / fan_sum).sqrt();
    let mut rng = scirs2_core::random::rng();
    Array2::from_shape_fn((rows, cols), |_| {
        // Map a uniform sample in [0, 1) onto [-limit, limit).
        rng.random::<f64>() * 2.0 * limit - limit
    })
}
/// Element-wise ReLU over a 2-D array: negative entries (and NaN) become 0.0.
fn relu2(x: &Array2<f64>) -> Array2<f64> {
    x.mapv(|v| if v > 0.0 { v } else { 0.0 })
}
/// Hyper-parameters for building an [`Rgcn`] encoder.
#[derive(Debug, Clone)]
pub struct RgcnConfig {
    /// Output width of every layer (also the embedding dimension of the decoder).
    pub hidden_dim: usize,
    /// Number of shared basis matrices in the basis decomposition (must be >= 1).
    pub n_bases: usize,
    /// Number of stacked R-GCN layers.
    pub n_layers: usize,
    /// Dropout probability. NOTE(review): stored but not applied anywhere in this
    /// file — presumably consumed by training code elsewhere; confirm.
    pub dropout: f64,
    /// Whether each layer adds a learned self-loop transform of the node's own features.
    pub self_loop: bool,
}
impl Default for RgcnConfig {
fn default() -> Self {
Self {
hidden_dim: 64,
n_bases: 4,
n_layers: 2,
dropout: 0.1,
self_loop: true,
}
}
}
/// Basis decomposition of per-relation weight matrices (Schlichtkrull et al.):
/// each relation's weight is a learned linear combination of a small shared set
/// of basis matrices, keeping parameter count independent of relation count.
#[derive(Debug, Clone)]
pub struct RgcnBasisDecomposition {
    /// Shared basis matrices, each of shape `(out_dim, in_dim)`.
    pub basis_matrices: Vec<Array2<f64>>,
    /// Per-relation mixing coefficients; `coefficients[r].len() == basis_matrices.len()`.
    pub coefficients: Vec<Vec<f64>>,
    /// Input feature dimension (columns of each basis matrix).
    pub in_dim: usize,
    /// Output feature dimension (rows of each basis matrix).
    pub out_dim: usize,
}
impl RgcnBasisDecomposition {
    /// Creates a decomposition with `n_bases` Xavier-initialized basis matrices
    /// of shape `(out_dim, in_dim)` and small random mixing coefficients
    /// (uniform in `[0, 0.1)`) for each of the `n_relations` relations.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidParameter` when `n_bases == 0`.
    pub fn new(in_dim: usize, out_dim: usize, n_bases: usize, n_relations: usize) -> Result<Self> {
        if n_bases == 0 {
            return Err(GraphError::InvalidParameter {
                param: "n_bases".to_string(),
                value: "0".to_string(),
                expected: ">= 1".to_string(),
                context: "RgcnBasisDecomposition::new".to_string(),
            });
        }
        let basis_matrices: Vec<Array2<f64>> = (0..n_bases)
            .map(|_| xavier_uniform(out_dim, in_dim))
            .collect();
        let mut rng = scirs2_core::random::rng();
        let coefficients: Vec<Vec<f64>> = (0..n_relations)
            .map(|_| (0..n_bases).map(|_| rng.random::<f64>() * 0.1).collect())
            .collect();
        Ok(Self {
            basis_matrices,
            coefficients,
            in_dim,
            out_dim,
        })
    }

    /// Materializes the effective weight matrix for `relation`:
    /// `W_r = sum_b a_{rb} * B_b`, shape `(out_dim, in_dim)`.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidParameter` when `relation` is out of range.
    pub fn effective_weight(&self, relation: usize) -> Result<Array2<f64>> {
        let coeffs =
            self.coefficients
                .get(relation)
                .ok_or_else(|| GraphError::InvalidParameter {
                    param: "relation".to_string(),
                    value: relation.to_string(),
                    expected: format!("< {}", self.coefficients.len()),
                    context: "RgcnBasisDecomposition::effective_weight".to_string(),
                })?;
        let mut w = Array2::<f64>::zeros((self.out_dim, self.in_dim));
        // In-place accumulation: the previous `w = w + coeff * &basis` form
        // allocated two temporary matrices per basis; `scaled_add` does
        // `w += coeff * basis` without any temporaries.
        for (basis, &coeff) in self.basis_matrices.iter().zip(coeffs.iter()) {
            w.scaled_add(coeff, basis);
        }
        Ok(w)
    }
}
/// A single R-GCN message-passing layer: relation-specific linear transforms
/// (via basis decomposition), degree-normalized aggregation, an optional
/// self-loop transform, a bias, and a ReLU nonlinearity.
#[derive(Debug, Clone)]
pub struct RgcnLayer {
    /// Shared-basis parameterization of the per-relation weights.
    pub basis_decomp: RgcnBasisDecomposition,
    /// Optional `(out_dim, in_dim)` weight for the node's own features (W_0).
    pub self_loop_weight: Option<Array2<f64>>,
    /// Additive bias of length `out_dim`, initialized to zeros.
    pub bias: Array1<f64>,
    /// Number of relation types this layer aggregates over.
    pub n_relations: usize,
    /// Output feature dimension.
    pub out_dim: usize,
}
impl RgcnLayer {
    /// Builds a layer mapping `in_dim`-dimensional features to `out_dim`,
    /// with `n_bases` shared bases over `n_relations` relations and an
    /// optional Xavier-initialized self-loop weight. Bias starts at zero.
    ///
    /// # Errors
    /// Propagates `GraphError::InvalidParameter` from the basis decomposition
    /// (e.g. `n_bases == 0`).
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        n_relations: usize,
        n_bases: usize,
        self_loop: bool,
    ) -> Result<Self> {
        let basis_decomp = RgcnBasisDecomposition::new(in_dim, out_dim, n_bases, n_relations)?;
        let self_loop_weight = self_loop.then(|| xavier_uniform(out_dim, in_dim));
        Ok(Self {
            basis_decomp,
            self_loop_weight,
            bias: Array1::zeros(out_dim),
            n_relations,
            out_dim,
        })
    }

    /// Forward pass: for each relation r, each edge `(src, dst)` sends the
    /// message `W_r h_src / deg_r(dst)` to `dst`; the self-loop term
    /// `W_0 h_i` and the bias are added, then ReLU is applied.
    ///
    /// `adj_by_relation[r]` lists directed `(src, dst)` edges of relation r;
    /// relations beyond `self.n_relations` and edges with out-of-range
    /// endpoints are ignored.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidParameter` when `node_feats.ncols()` does
    /// not match the layer's expected input dimension.
    pub fn forward(
        &self,
        node_feats: &Array2<f64>,
        adj_by_relation: &[Vec<(usize, usize)>],
    ) -> Result<Array2<f64>> {
        let n_nodes = node_feats.nrows();
        let in_dim = node_feats.ncols();
        if in_dim != self.basis_decomp.in_dim {
            return Err(GraphError::InvalidParameter {
                param: "node_feats".to_string(),
                value: format!("in_dim={}", in_dim),
                expected: format!("in_dim={}", self.basis_decomp.in_dim),
                context: "RgcnLayer::forward".to_string(),
            });
        }
        let mut combined = Array2::<f64>::zeros((n_nodes, self.out_dim));
        for (r, edges) in adj_by_relation.iter().enumerate().take(self.n_relations) {
            let w_r = self.basis_decomp.effective_weight(r)?;
            // Per-relation in-degree, counting ONLY edges that will actually
            // deliver a message (both endpoints in range). The previous
            // version also counted edges whose source was out of range even
            // though those edges were skipped below, which deflated the
            // aggregated messages for their targets.
            let mut in_deg: Vec<usize> = vec![0usize; n_nodes];
            for &(src, dst) in edges {
                if src < n_nodes && dst < n_nodes {
                    in_deg[dst] += 1;
                }
            }
            for &(src, dst) in edges {
                if src >= n_nodes || dst >= n_nodes {
                    continue;
                }
                let h_j = node_feats.row(src);
                let msg = w_r.dot(&h_j);
                // max(1) guards the (now impossible for delivered edges)
                // zero-degree case so we never divide by zero.
                let deg = in_deg[dst].max(1) as f64;
                let mut row = combined.row_mut(dst);
                row.zip_mut_with(&msg, |acc, &m| *acc += m / deg);
            }
        }
        // Self-connection: every node also receives W_0 h_i, unnormalized.
        if let Some(ref w0) = self.self_loop_weight {
            for i in 0..n_nodes {
                let h_i = node_feats.row(i);
                let self_msg = w0.dot(&h_i);
                let mut row = combined.row_mut(i);
                row.zip_mut_with(&self_msg, |acc, &v| *acc += v);
            }
        }
        for mut row in combined.rows_mut() {
            row.zip_mut_with(&self.bias, |v, &b| *v += b);
        }
        Ok(relu2(&combined))
    }
}
/// A stack of [`RgcnLayer`]s applied sequentially; the encoder half of the
/// link-prediction model.
#[derive(Debug, Clone)]
pub struct Rgcn {
    /// Layers in application order; layer 0 maps the raw input dimension.
    pub layers: Vec<RgcnLayer>,
}
impl Rgcn {
    /// Builds `config.n_layers` layers: the first maps `in_dim` to
    /// `config.hidden_dim`, the rest map `hidden_dim` to `hidden_dim`.
    ///
    /// # Errors
    /// Propagates any layer-construction error (e.g. `n_bases == 0`).
    pub fn new(in_dim: usize, n_relations: usize, config: &RgcnConfig) -> Result<Self> {
        let layers = (0..config.n_layers)
            .map(|idx| {
                let input_width = if idx == 0 { in_dim } else { config.hidden_dim };
                RgcnLayer::new(
                    input_width,
                    config.hidden_dim,
                    n_relations,
                    config.n_bases,
                    config.self_loop,
                )
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { layers })
    }

    /// Runs every layer in order, feeding each layer's output to the next.
    ///
    /// # Errors
    /// Propagates the first layer error (e.g. a feature-dimension mismatch).
    pub fn forward(
        &self,
        node_feats: &Array2<f64>,
        adj_by_relation: &[Vec<(usize, usize)>],
    ) -> Result<Array2<f64>> {
        self.layers
            .iter()
            .try_fold(node_feats.clone(), |hidden, layer| {
                layer.forward(&hidden, adj_by_relation)
            })
    }
}
/// DistMult triple decoder: score(h, r, t) = <e_h, w_r, e_t> (trilinear dot).
#[derive(Debug, Clone)]
pub struct DistMultScoring {
    /// Entity embedding table, shape `(n_entities, dim)`.
    pub entity_embeds: Array2<f64>,
    /// Relation embedding table, shape `(n_relations, dim)`.
    pub relation_embeds: Array2<f64>,
}
impl DistMultScoring {
    /// Allocates Xavier-initialized entity and relation embedding tables.
    pub fn new(n_entities: usize, n_relations: usize, dim: usize) -> Self {
        let entity_embeds = xavier_uniform(n_entities, dim);
        let relation_embeds = xavier_uniform(n_relations, dim);
        Self {
            entity_embeds,
            relation_embeds,
        }
    }

    /// Trilinear DistMult score for the triple `(h, r, t)`:
    /// `sum_k e_h[k] * w_r[k] * e_t[k]`. Because the relation factor is a
    /// plain element-wise weight, the score is symmetric in `h` and `t`.
    /// Panics if any index is out of range of its embedding table.
    pub fn score(&self, h: usize, r: usize, t: usize) -> f64 {
        let head = self.entity_embeds.row(h);
        let rel = self.relation_embeds.row(r);
        let tail = self.entity_embeds.row(t);
        (0..head.len()).map(|k| head[k] * rel[k] * tail[k]).sum()
    }
}
/// Common interface for knowledge-graph triple scorers; `Send + Sync` so
/// scorers can be shared across threads.
pub trait KgScorer: Send + Sync {
    /// Plausibility score for the triple `(head, relation, tail)`;
    /// higher means more plausible.
    fn score(&self, h: usize, r: usize, t: usize) -> f64;
}
impl KgScorer for DistMultScoring {
    /// Delegates to the inherent [`DistMultScoring::score`]; the fully
    /// qualified call avoids recursing into this trait method.
    fn score(&self, h: usize, r: usize, t: usize) -> f64 {
        DistMultScoring::score(self, h, r, t)
    }
}
/// End-to-end link predictor: R-GCN encoder producing node embeddings plus a
/// DistMult decoder for scoring triples.
#[derive(Debug, Clone)]
pub struct RgcnLinkPredictor {
    /// Graph encoder producing contextual node embeddings.
    pub encoder: Rgcn,
    /// Triple decoder; also holds fallback entity embeddings used before `encode`.
    pub decoder: DistMultScoring,
    /// Encoder output cached by [`RgcnLinkPredictor::encode`]; `None` until then.
    pub node_embeddings: Option<Array2<f64>>,
    /// Embedding dimension (equals the encoder's hidden dimension).
    pub dim: usize,
    /// Number of relation types.
    pub n_relations: usize,
}
impl RgcnLinkPredictor {
    /// Builds the encoder/decoder pair. The decoder's embedding dimension is
    /// tied to `config.hidden_dim`; no node embeddings are cached yet.
    ///
    /// # Errors
    /// Propagates encoder-construction errors (e.g. `n_bases == 0`).
    pub fn new(
        in_dim: usize,
        n_entities: usize,
        n_relations: usize,
        config: &RgcnConfig,
    ) -> Result<Self> {
        let encoder = Rgcn::new(in_dim, n_relations, config)?;
        let decoder = DistMultScoring::new(n_entities, n_relations, config.hidden_dim);
        Ok(Self {
            encoder,
            decoder,
            node_embeddings: None,
            dim: config.hidden_dim,
            n_relations,
        })
    }

    /// Runs the encoder and caches the resulting node embeddings for use by
    /// [`RgcnLinkPredictor::score_triple`].
    ///
    /// # Errors
    /// Propagates encoder forward-pass errors.
    pub fn encode(
        &mut self,
        node_feats: &Array2<f64>,
        adj_by_relation: &[Vec<(usize, usize)>],
    ) -> Result<()> {
        let embeddings = self.encoder.forward(node_feats, adj_by_relation)?;
        self.node_embeddings = Some(embeddings);
        Ok(())
    }

    /// DistMult score for `(h, r, t)` using the cached encoder embeddings
    /// when available, otherwise the decoder's own entity embeddings.
    /// Panics if an index is out of range of the relevant table.
    pub fn score_triple(&self, h: usize, r: usize, t: usize) -> f64 {
        match self.node_embeddings.as_ref() {
            None => self.decoder.score(h, r, t),
            Some(embeds) => {
                let head = embeds.row(h);
                let rel = self.decoder.relation_embeds.row(r);
                let tail = embeds.row(t);
                let mut total = 0.0;
                for k in 0..head.len() {
                    total += head[k] * rel[k] * tail[k];
                }
                total
            }
        }
    }
}
impl KgScorer for RgcnLinkPredictor {
    /// Delegates to [`RgcnLinkPredictor::score_triple`] (encoder embeddings if
    /// cached, decoder embeddings otherwise).
    fn score(&self, h: usize, r: usize, t: usize) -> f64 {
        self.score_triple(h, r, t)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Identity-like feature matrix: 1.0 on the main diagonal, zeros elsewhere.
    fn eye_feats(n: usize, dim: usize) -> Array2<f64> {
        let mut m = Array2::<f64>::zeros((n, dim));
        for i in 0..n.min(dim) {
            m[(i, i)] = 1.0;
        }
        m
    }

    /// Single relation forming a directed ring 0 -> 1 -> ... -> n-1 -> 0.
    fn single_relation_adj(n: usize) -> Vec<Vec<(usize, usize)>> {
        let edges: Vec<(usize, usize)> = (0..n).map(|i| (i, (i + 1) % n)).collect();
        vec![edges]
    }

    #[test]
    fn test_basis_decomp_single_basis_recovers_weight() {
        // With one basis and coefficient forced to 1, W_r must equal B_0 exactly.
        let mut decomp = RgcnBasisDecomposition::new(4, 4, 1, 1).expect("decomp");
        decomp.coefficients[0][0] = 1.0;
        let w = decomp.effective_weight(0).expect("w");
        let diff = (&w - &decomp.basis_matrices[0]).mapv(|v| v.abs()).sum();
        assert!(
            diff < 1e-10,
            "effective_weight should equal basis[0] when a=1"
        );
    }

    #[test]
    fn test_rgcn_layer_output_shape() {
        let feats = eye_feats(5, 8);
        let adj = single_relation_adj(5);
        let layer = RgcnLayer::new(8, 16, 1, 2, true).expect("layer");
        let out = layer.forward(&feats, &adj).expect("forward");
        assert_eq!(out.nrows(), 5);
        assert_eq!(out.ncols(), 16);
    }

    #[test]
    fn test_rgcn_layer_isolated_node_self_loop() {
        let feats = eye_feats(3, 4);
        // One relation with no edges: every node relies solely on its self-loop.
        let adj: Vec<Vec<(usize, usize)>> = vec![vec![]];
        let layer = RgcnLayer::new(4, 4, 1, 1, true).expect("layer");
        let out = layer.forward(&feats, &adj).expect("forward");
        let row_norm: f64 = out.row(0).iter().map(|x| x * x).sum::<f64>().sqrt();
        // The old assertion `row_norm >= 0.0` was vacuously true for any
        // non-NaN value; `is_finite` actually checks the claimed property.
        assert!(row_norm.is_finite(), "isolated node output must be finite");
        assert_eq!(out.nrows(), 3);
    }

    #[test]
    fn test_rgcn_stacked_output_shape() {
        let config = RgcnConfig {
            hidden_dim: 8,
            n_bases: 2,
            n_layers: 3,
            ..Default::default()
        };
        let feats = eye_feats(6, 4);
        let adj = single_relation_adj(6);
        let rgcn = Rgcn::new(4, 1, &config).expect("rgcn");
        let out = rgcn.forward(&feats, &adj).expect("forward");
        assert_eq!(out.nrows(), 6);
        assert_eq!(out.ncols(), 8);
    }

    #[test]
    fn test_distmult_symmetry() {
        // DistMult is symmetric in head/tail because the relation factor is
        // a plain element-wise weight.
        let dm = DistMultScoring::new(4, 2, 8);
        let s1 = dm.score(0, 0, 1);
        let s2 = dm.score(1, 0, 0);
        assert!((s1 - s2).abs() < 1e-10, "DistMult should be symmetric");
    }

    #[test]
    fn test_rgcn_link_predictor_encode() {
        let config = RgcnConfig::default();
        let mut predictor = RgcnLinkPredictor::new(4, 5, 2, &config).expect("predictor");
        let feats = eye_feats(5, 4);
        let adj: Vec<Vec<(usize, usize)>> = vec![vec![(0, 1), (1, 2)], vec![(2, 3)]];
        predictor.encode(&feats, &adj).expect("encode");
        assert!(predictor.node_embeddings.is_some());
        let embeds = predictor.node_embeddings.as_ref().expect("embeds");
        assert_eq!(embeds.nrows(), 5);
        assert_eq!(embeds.ncols(), config.hidden_dim);
        let s = predictor.score_triple(0, 0, 1);
        assert!(s.is_finite());
    }
}