use super::time_encoding::{concat, matvec, scaled_dot_product, softmax, xavier_init, relu_vec};
use super::types::{TgatConfig, TgnnGraph, TemporalPrediction};
use crate::error::{GraphError, Result};
/// One multi-head temporal attention layer (TGAT-style).
///
/// Each head consumes the concatenation of a node-feature vector and a
/// time encoding, so `input_dim = node_feat_dim + time_dim`; the layer's
/// output is the concatenation of all heads, `output_dim = num_heads * head_dim`.
#[derive(Debug, Clone)]
pub struct TgatLayer {
    /// Number of attention heads.
    pub num_heads: usize,
    /// Dimensionality of each head's query/key/value vectors.
    pub head_dim: usize,
    /// Dimensionality of the time encoding appended to node features.
    pub time_dim: usize,
    /// Per-head input size: node feature dim + `time_dim`.
    pub input_dim: usize,
    // Per-head projection matrices; one `head_dim x input_dim` matrix per head.
    w_q: Vec<Vec<Vec<f64>>>,
    w_k: Vec<Vec<Vec<f64>>>,
    w_v: Vec<Vec<Vec<f64>>>,
    // Output projection (`output_dim x output_dim`) and bias applied to the
    // concatenated head outputs.
    w_o: Vec<Vec<f64>>,
    b_o: Vec<f64>,
    /// Layer output size: `num_heads * head_dim`.
    pub output_dim: usize,
}
impl TgatLayer {
    /// Builds a layer with Xavier-initialized projection matrices.
    ///
    /// `seed` makes initialization deterministic; each head and each of the
    /// Q/K/V families gets a distinct derived seed.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidParameter` when `num_heads` or `head_dim`
    /// is zero.
    pub fn new(
        node_feat_dim: usize,
        time_dim: usize,
        num_heads: usize,
        head_dim: usize,
        seed: u64,
    ) -> Result<Self> {
        if num_heads == 0 || head_dim == 0 {
            return Err(GraphError::InvalidParameter {
                param: "num_heads/head_dim".to_string(),
                value: format!("{}/{}", num_heads, head_dim),
                expected: "both > 0".to_string(),
                context: "TgatLayer::new".to_string(),
            });
        }
        let input_dim = node_feat_dim + time_dim;
        let output_dim = num_heads * head_dim;
        let mut w_q = Vec::with_capacity(num_heads);
        let mut w_k = Vec::with_capacity(num_heads);
        let mut w_v = Vec::with_capacity(num_heads);
        for h in 0..num_heads {
            // Offsets 0/1000/2000 keep the Q, K and V seed ranges disjoint
            // (assuming num_heads < 1000).
            w_q.push(xavier_init(head_dim, input_dim, seed.wrapping_add(h as u64)));
            w_k.push(xavier_init(head_dim, input_dim, seed.wrapping_add(1000 + h as u64)));
            w_v.push(xavier_init(head_dim, input_dim, seed.wrapping_add(2000 + h as u64)));
        }
        let w_o = xavier_init(output_dim, output_dim, seed.wrapping_add(3000));
        let b_o = vec![0.0f64; output_dim];
        Ok(TgatLayer {
            num_heads,
            head_dim,
            time_dim,
            input_dim,
            w_q,
            w_k,
            w_v,
            w_o,
            b_o,
            output_dim,
        })
    }

    /// Computes a node's new embedding by attending over its temporal
    /// neighbors.
    ///
    /// `h_self` is the node's current embedding; `neighbors` holds
    /// `(neighbor_embedding, event_time)` pairs. Each neighbor's time
    /// encoding is derived from `query_time` and its event time via
    /// `encode_delta`; the query node itself is encoded at delta 0.
    ///
    /// Returns a vector of length `output_dim`. A residual connection is
    /// added only when `h_self` already has that length; with no neighbors
    /// the result is all zeros.
    pub fn forward_node(
        &self,
        h_self: &[f64],
        neighbors: &[(Vec<f64>, f64)],
        query_time: f64,
        time_enc: &super::time_encoding::TimeEncode,
    ) -> Vec<f64> {
        // Nothing to attend over: return a zero embedding. This guard is
        // checked first so the query-side encoding below is not computed and
        // then thrown away (the previous version did that work unconditionally).
        if neighbors.is_empty() {
            return vec![0.0f64; self.output_dim];
        }
        // Query input: self features concatenated with the time encoding at
        // delta 0 (the node is "now" relative to itself).
        let phi_self = time_enc.encode(0.0);
        let q_input = concat(h_self, &phi_self);
        // Key/value inputs: neighbor features concatenated with the encoding
        // of the time gap between the query and the neighbor event.
        let kv_inputs: Vec<Vec<f64>> = neighbors
            .iter()
            .map(|(h_nbr, t_nbr)| {
                let phi = time_enc.encode_delta(query_time, *t_nbr);
                concat(h_nbr, &phi)
            })
            .collect();
        let mut head_outputs: Vec<f64> = Vec::with_capacity(self.output_dim);
        for head in 0..self.num_heads {
            let q = matvec(&self.w_q[head], &q_input);
            let keys: Vec<Vec<f64>> = kv_inputs
                .iter()
                .map(|kv| matvec(&self.w_k[head], kv))
                .collect();
            let values: Vec<Vec<f64>> = kv_inputs
                .iter()
                .map(|kv| matvec(&self.w_v[head], kv))
                .collect();
            // Scaled dot-product attention over the neighbor set; softmax
            // turns logits into a convex combination of the value vectors.
            let logits = scaled_dot_product(&q, &keys);
            let alphas = softmax(&logits);
            let mut attended = vec![0.0f64; self.head_dim];
            for (alpha, val) in alphas.iter().zip(values.iter()) {
                for (a, v) in attended.iter_mut().zip(val.iter()) {
                    *a += alpha * v;
                }
            }
            // Heads are concatenated, not averaged.
            head_outputs.extend(attended);
        }
        // Output projection + bias, then ReLU.
        let projected = matvec(&self.w_o, &head_outputs);
        let mut out: Vec<f64> = projected
            .iter()
            .zip(self.b_o.iter())
            .map(|(p, b)| p + b)
            .collect();
        out = relu_vec(&out);
        // Residual connection, only when dimensions line up (e.g. stacked
        // layers where input dim == output dim). Note it is applied after the
        // ReLU, so outputs can go negative only via the residual term.
        if h_self.len() == out.len() {
            for (o, s) in out.iter_mut().zip(h_self.iter()) {
                *o += s;
            }
        }
        out
    }
}
/// A stack of [`TgatLayer`]s sharing a single time encoder.
#[derive(Debug, Clone)]
pub struct TgatModel {
    /// Attention layers applied in order; each consumes the previous
    /// layer's embeddings as node features.
    pub layers: Vec<TgatLayer>,
    /// Shared time encoder used by every layer.
    pub time_enc: super::time_encoding::TimeEncode,
    /// Configuration the model was built from. NOTE(review): `dropout` is
    /// stored but never applied anywhere in this file — confirm intent.
    pub config: TgatConfig,
    /// Embedding size produced per node: `num_heads * head_dim`.
    pub output_dim: usize,
}
impl TgatModel {
    /// Builds a model with `config.num_layers` attention layers.
    ///
    /// When `node_feat_dim` is zero, the first layer is sized for
    /// `config.head_dim`-dimensional placeholder features (matching the
    /// zero-feature fallback in [`Self::forward`]). Layers after the first
    /// consume the previous layer's `num_heads * head_dim` output.
    ///
    /// Note: the first layer is built unconditionally, so `num_layers == 0`
    /// still yields a one-layer model.
    ///
    /// # Errors
    /// Propagates errors from `TimeEncode::new` and `TgatLayer::new`.
    pub fn new(config: &TgatConfig, node_feat_dim: usize) -> Result<Self> {
        let eff_feat_dim = if node_feat_dim == 0 {
            config.head_dim
        } else {
            node_feat_dim
        };
        let time_enc = super::time_encoding::TimeEncode::new(config.time_dim)?;
        let mut layers = Vec::with_capacity(config.num_layers);
        let output_dim = config.num_heads * config.head_dim;
        let first_layer = TgatLayer::new(
            eff_feat_dim,
            config.time_dim,
            config.num_heads,
            config.head_dim,
            12345, // fixed base seed keeps weight initialization deterministic
        )?;
        let first_output = first_layer.output_dim;
        layers.push(first_layer);
        for layer_idx in 1..config.num_layers {
            // Deeper layers take the previous layer's concatenated-head
            // output as node features; each gets a distinct derived seed.
            let layer = TgatLayer::new(
                first_output,
                config.time_dim,
                config.num_heads,
                config.head_dim,
                12345 + layer_idx as u64 * 999,
            )?;
            layers.push(layer);
        }
        Ok(TgatModel {
            layers,
            time_enc,
            config: config.clone(),
            output_dim,
        })
    }

    /// Runs all layers over the graph at `query_time`, returning one
    /// embedding (length `output_dim`) per node.
    ///
    /// Only edges returned by `graph.neighbors_before(i, query_time)`
    /// contribute, giving causal masking with respect to the query time.
    /// Node features shorter/longer than the effective feature dimension
    /// are zero-padded / truncated.
    ///
    /// NOTE(review): the effective feature dimension here comes from the
    /// *graph*, while `new` sized the first layer from its own argument —
    /// confirm callers keep the two in sync.
    pub fn forward(&self, graph: &TgnnGraph, query_time: f64) -> Result<Vec<Vec<f64>>> {
        let n = graph.n_nodes;
        if n == 0 {
            return Ok(Vec::new());
        }
        let eff_feat_dim = if graph.node_feat_dim == 0 {
            self.config.head_dim
        } else {
            graph.node_feat_dim
        };
        // Initial embeddings: raw node features, padded or truncated to
        // eff_feat_dim (empty features become all-zero vectors).
        let mut current_embeddings: Vec<Vec<f64>> = (0..n)
            .map(|i| {
                let feat = graph.node_feat(i);
                if feat.is_empty() {
                    vec![0.0f64; eff_feat_dim]
                } else if feat.len() == eff_feat_dim {
                    feat.to_vec()
                } else {
                    let mut v = vec![0.0f64; eff_feat_dim];
                    let copy_len = feat.len().min(eff_feat_dim);
                    v[..copy_len].copy_from_slice(&feat[..copy_len]);
                    v
                }
            })
            .collect();
        for layer in &self.layers {
            // Read from `current_embeddings`, write into `next_embeddings`.
            // The previous version cloned the whole embedding table per layer;
            // borrowing the current table directly avoids that O(n * d) copy.
            let mut next_embeddings = Vec::with_capacity(n);
            for i in 0..n {
                let nbr_tuples = graph.neighbors_before(i, query_time);
                let neighbors: Vec<(Vec<f64>, f64)> = nbr_tuples
                    .iter()
                    .map(|(j, t_edge, _edge_feat)| {
                        // Out-of-range neighbor ids fall back to a zero vector
                        // of the current embedding width (n > 0 here, so
                        // index 0 exists).
                        let h_nbr = current_embeddings
                            .get(*j)
                            .cloned()
                            .unwrap_or_else(|| vec![0.0f64; current_embeddings[0].len()]);
                        (h_nbr, *t_edge)
                    })
                    .collect();
                let new_h = layer.forward_node(
                    &current_embeddings[i],
                    &neighbors,
                    query_time,
                    &self.time_enc,
                );
                next_embeddings.push(new_h);
            }
            current_embeddings = next_embeddings;
        }
        Ok(current_embeddings)
    }

    /// Convenience wrapper: forward pass plus packaging each embedding into
    /// a [`TemporalPrediction`] tagged with its node index and the query time.
    pub fn predict(&self, graph: &TgnnGraph, query_time: f64) -> Result<Vec<TemporalPrediction>> {
        let embeddings = self.forward(graph, query_time)?;
        Ok(embeddings
            .into_iter()
            .enumerate()
            .map(|(i, emb)| TemporalPrediction::new(i, emb, query_time))
            .collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::time_encoding::TimeEncode;
    use super::super::types::{TgatConfig, TgnnEdge, TgnnGraph};

    /// Five nodes in a chain (t = 1..=4) plus one edge far in the future.
    fn simple_graph() -> TgnnGraph {
        let mut graph = TgnnGraph::with_zero_features(5, 4);
        for (src, dst, t) in [(0, 1, 1.0), (1, 2, 2.0), (2, 3, 3.0), (3, 4, 4.0)] {
            graph.add_edge(TgnnEdge::no_feat(src, dst, t));
        }
        // "Future" edge, past the query times used by the tests below.
        graph.add_edge(TgnnEdge::no_feat(0, 4, 15.0));
        graph
    }

    /// All tests share time_dim = 8 and dropout = 0.0.
    fn make_config(num_heads: usize, head_dim: usize, num_layers: usize) -> TgatConfig {
        TgatConfig {
            num_heads,
            time_dim: 8,
            head_dim,
            num_layers,
            dropout: 0.0,
        }
    }

    #[test]
    fn test_tgat_output_shape() {
        let config = make_config(2, 8, 1);
        let model = TgatModel::new(&config, 4).expect("model creation");
        let embeddings = model.forward(&simple_graph(), 10.0).expect("forward pass");
        assert_eq!(embeddings.len(), 5, "must produce one embedding per node");
        let expected_dim = config.num_heads * config.head_dim;
        for emb in &embeddings {
            assert_eq!(emb.len(), expected_dim, "each embedding has wrong dim");
        }
    }

    #[test]
    fn test_tgat_causal_masking() {
        let model = TgatModel::new(&make_config(1, 8, 1), 4).expect("model");
        // Two graphs identical except for one edge at t = 15, which lies
        // after the query time of 10 and so must be invisible.
        let mut with_future = TgnnGraph::with_zero_features(5, 4);
        with_future.add_edge(TgnnEdge::no_feat(0, 1, 1.0));
        with_future.add_edge(TgnnEdge::no_feat(0, 4, 15.0));
        let mut without_future = TgnnGraph::with_zero_features(5, 4);
        without_future.add_edge(TgnnEdge::no_feat(0, 1, 1.0));
        let emb_with = model.forward(&with_future, 10.0).expect("forward");
        let emb_without = model.forward(&without_future, 10.0).expect("forward");
        for (ef, en) in emb_with.iter().zip(emb_without.iter()) {
            for (a, b) in ef.iter().zip(en.iter()) {
                assert!(
                    (a - b).abs() < 1e-10,
                    "future edge must not influence embeddings"
                );
            }
        }
    }

    #[test]
    fn test_tgat_attention_softmax_sums_one() {
        let layer = TgatLayer::new(4, 8, 1, 8, 42).expect("layer");
        let time_enc = TimeEncode::new(8).expect("enc");
        let h_self = vec![1.0, 0.0, 0.0, 0.0];
        let neighbors = vec![
            (vec![0.0, 1.0, 0.0, 0.0], 1.0_f64),
            (vec![0.0, 0.0, 1.0, 0.0], 2.0_f64),
            (vec![0.0, 0.0, 0.0, 1.0], 3.0_f64),
        ];
        // Recompute the attention weights exactly as forward_node does.
        let q = matvec(&layer.w_q[0], &concat(&h_self, &time_enc.encode(0.0)));
        let keys: Vec<Vec<f64>> = neighbors
            .iter()
            .map(|(h_nbr, t_nbr)| {
                let kv = concat(h_nbr, &time_enc.encode_delta(10.0, *t_nbr));
                matvec(&layer.w_k[0], &kv)
            })
            .collect();
        let alphas = softmax(&scaled_dot_product(&q, &keys));
        let total: f64 = alphas.iter().sum();
        assert!(
            (total - 1.0).abs() < 1e-10,
            "attention weights must sum to 1, got {}",
            total
        );
        assert!(
            alphas.iter().all(|&a| a >= 0.0),
            "attention weight must be non-negative"
        );
    }

    #[test]
    fn test_tgat_with_no_neighbors() {
        let model = TgatModel::new(&make_config(2, 8, 1), 4).expect("model");
        let mut graph = TgnnGraph::with_zero_features(3, 4);
        // The only edge is at t = 5, after the query time of 0.5, so every
        // node sees zero temporal neighbors.
        graph.add_edge(TgnnEdge::no_feat(0, 1, 5.0));
        let embeddings = model.forward(&graph, 0.5).expect("forward");
        for emb in &embeddings {
            let norm = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
            assert!(
                norm < 1e-10,
                "node with no neighbors and zero features should produce ~zero embedding, got norm={}",
                norm
            );
        }
    }

    #[test]
    fn test_tgat_multi_head_concat() {
        let model = TgatModel::new(&make_config(4, 6, 1), 4).expect("model");
        let embeddings = model.forward(&simple_graph(), 5.0).expect("forward");
        let expected_dim = 4 * 6;
        for emb in &embeddings {
            assert_eq!(emb.len(), expected_dim, "multi-head concat size wrong");
        }
    }

    #[test]
    fn test_tgat_two_layers() {
        let model = TgatModel::new(&make_config(2, 8, 2), 4).expect("model");
        assert_eq!(model.layers.len(), 2);
        let embeddings = model.forward(&simple_graph(), 10.0).expect("2-layer forward");
        assert_eq!(embeddings.len(), 5);
    }
}