use core::fmt;
use core::hash::{Hash, Hasher};
use core::marker::PhantomData;
use crate::edge::EdgeIndex;
use crate::node::NodeIndex;
use crate::tensor::dense::DenseTensor;
use crate::tensor::traits::TensorBase;
#[cfg(feature = "tensor")]
use crate::tensor::sparse::SparseTensor;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// A graph node identified by a [`NodeIndex`] that carries a tensor payload of type `T`.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TensorNode<T: TensorBase> {
    /// Index of this node; `PartialEq` and `Hash` for this type look only at this field.
    index: NodeIndex,
    /// Tensor payload attached to the node.
    data: T,
    // NOTE(review): `data: T` already ties this struct to `T`, so this marker looks
    // redundant. It is, however, part of the derived serde representation when the
    // `serde` feature is on — confirm the serialized format before removing it.
    _marker: PhantomData<T>,
}
impl<T: TensorBase> TensorNode<T> {
    /// Wraps `data` as the payload of the node identified by `index`.
    pub fn new(index: NodeIndex, data: T) -> Self {
        Self {
            _marker: PhantomData,
            data,
            index,
        }
    }

    /// Returns this node's index (copied out).
    pub fn index(&self) -> NodeIndex {
        self.index
    }

    /// Borrows the tensor payload.
    pub fn data(&self) -> &T {
        &self.data
    }

    /// Mutably borrows the tensor payload.
    pub fn data_mut(&mut self) -> &mut T {
        &mut self.data
    }

    /// Shape of the payload, as reported by [`TensorBase::shape`].
    pub fn shape(&self) -> &[usize] {
        self.data().shape()
    }

    /// Replaces the payload; the previous tensor is dropped.
    pub fn set_data(&mut self, data: T) {
        *self.data_mut() = data;
    }

    /// Consumes the node and hands back its payload, discarding the index.
    pub fn into_data(self) -> T {
        let Self { data, .. } = self;
        data
    }
}
impl<T: TensorBase> fmt::Debug for TensorNode<T> {
    /// Prints the node's index plus its tensor's shape and dtype; tensor values are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let payload = &self.data;
        let mut builder = f.debug_struct("TensorNode");
        builder.field("index", &self.index);
        builder.field("shape", &payload.shape());
        builder.field("dtype", &payload.dtype());
        builder.finish()
    }
}
/// Nodes compare equal exactly when their [`NodeIndex`] values are equal; the
/// tensor payload is ignored.
impl<T: TensorBase> PartialEq for TensorNode<T> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.index, &other.index);
        lhs == rhs
    }
}

/// Total equality marker; sound as long as `NodeIndex` equality is itself an
/// equivalence relation (presumably `NodeIndex: Eq` — not visible here).
impl<T: TensorBase> Eq for TensorNode<T> {}

/// Hashes only the index, which keeps `Hash` consistent with the index-based `PartialEq`.
impl<T: TensorBase> Hash for TensorNode<T> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.index, state);
    }
}
/// A directed edge `source → target`, identified by an [`EdgeIndex`] and carrying
/// a tensor payload of type `E`.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TensorEdge<E: TensorBase> {
    /// Index of this edge; `PartialEq` and `Hash` for this type look only at this field.
    index: EdgeIndex,
    /// Tensor payload attached to the edge.
    data: E,
    /// Node the edge starts from.
    source: NodeIndex,
    /// Node the edge points to.
    target: NodeIndex,
}
impl<E: TensorBase> TensorEdge<E> {
    /// Creates edge `index` running from `source` to `target` with payload `data`.
    pub fn new(index: EdgeIndex, data: E, source: NodeIndex, target: NodeIndex) -> Self {
        Self {
            index,
            data,
            source,
            target,
        }
    }

    /// Returns this edge's index (copied out).
    pub fn index(&self) -> EdgeIndex {
        self.index
    }

    /// Borrows the tensor payload.
    pub fn data(&self) -> &E {
        &self.data
    }

    /// Mutably borrows the tensor payload.
    pub fn data_mut(&mut self) -> &mut E {
        &mut self.data
    }

    /// Node the edge starts from.
    pub fn source(&self) -> NodeIndex {
        self.source
    }

    /// Node the edge points to.
    pub fn target(&self) -> NodeIndex {
        self.target
    }

    /// Both endpoints as `(source, target)`.
    pub fn endpoints(&self) -> (NodeIndex, NodeIndex) {
        (self.source, self.target)
    }

    /// Shape of the payload, as reported by [`TensorBase::shape`].
    pub fn shape(&self) -> &[usize] {
        self.data.shape()
    }

    /// Replaces the payload; the previous tensor is dropped.
    pub fn set_data(&mut self, data: E) {
        self.data = data;
    }

    /// Consumes the edge and returns its payload, discarding the index and endpoints.
    ///
    /// Counterpart of `TensorNode::into_data`, so edges and nodes expose the same
    /// payload-extraction API.
    pub fn into_data(self) -> E {
        self.data
    }
}
impl<E: TensorBase> fmt::Debug for TensorEdge<E> {
    /// Prints the edge's index, endpoints, and its tensor's shape and dtype; tensor
    /// values are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TensorEdge")
            .field("index", &self.index)
            // Passing the tuple directly prints `(source, target)` without allocating.
            // It also avoids the quotes/escaping that Debug-formatting a `String` adds,
            // and it honours the alternate `{:#?}` flag.
            .field("endpoints", &(self.source, self.target))
            .field("shape", &self.data.shape())
            .field("dtype", &self.data.dtype())
            .finish()
    }
}
/// Edges compare equal exactly when their [`EdgeIndex`] values are equal; the
/// payload and endpoints are ignored.
impl<E: TensorBase> PartialEq for TensorEdge<E> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.index, &other.index);
        lhs == rhs
    }
}

/// Total equality marker; sound as long as `EdgeIndex` equality is itself an
/// equivalence relation (presumably `EdgeIndex: Eq` — not visible here).
impl<E: TensorBase> Eq for TensorEdge<E> {}

/// Hashes only the index, which keeps `Hash` consistent with the index-based `PartialEq`.
impl<E: TensorBase> Hash for TensorEdge<E> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.index, state);
    }
}
/// A node whose payload is a dense feature tensor.
pub type NodeFeatures = TensorNode<DenseTensor>;
/// An edge whose payload is a dense feature tensor.
pub type EdgeFeatures = TensorEdge<DenseTensor>;
/// The same type as `NodeFeatures`; the separate name only documents intent
/// (presumably a learned embedding rather than raw input features).
pub type NodeEmbedding = TensorNode<DenseTensor>;
/// Hidden-state tensor, represented as a plain [`DenseTensor`] (granularity —
/// per node, per layer, … — is not established here; TODO confirm).
pub type HiddenState = DenseTensor;
/// Node features from a batch of graphs, held in a single tensor.
///
/// NOTE(review): the vectors and the tensor are not validated against each other;
/// presumably `graph_indices` and `node_indices` are parallel (one entry per node).
/// TODO confirm.
pub struct BatchedNodeFeatures<T: TensorBase> {
    /// Graph id for each entry (presumably one per node; TODO confirm).
    pub graph_indices: Vec<usize>,
    /// Node index for each entry.
    pub node_indices: Vec<NodeIndex>,
    /// Feature tensor covering the whole batch (row layout not established here).
    pub features: T,
}
impl<T: TensorBase> BatchedNodeFeatures<T> {
    /// Bundles the index vectors with their feature tensor. No validation is performed.
    pub fn new(graph_indices: Vec<usize>, node_indices: Vec<NodeIndex>, features: T) -> Self {
        Self {
            features,
            node_indices,
            graph_indices,
        }
    }

    /// Number of entries in `graph_indices`.
    ///
    /// NOTE(review): this counts entries, not distinct graphs. If `graph_indices` holds
    /// one graph id per node, this is the node count — confirm which is intended.
    pub fn batch_size(&self) -> usize {
        self.graph_indices.len()
    }

    /// Borrows the batch-wide feature tensor.
    pub fn features(&self) -> &T {
        &self.features
    }

    /// Returns `Some(&features)` when `sample_idx < batch_size()`, otherwise `None`.
    ///
    /// NOTE(review): the returned tensor is the *entire* batch, not the slice that
    /// belongs to `sample_idx`. Callers must index into it themselves. Real slicing
    /// needs a `TensorBase` API that is not visible here — TODO.
    pub fn get_sample(&self, sample_idx: usize) -> Option<&T> {
        if sample_idx >= self.batch_size() {
            return None;
        }
        Some(self.features())
    }
}
/// Tensors involved in one message along an edge: source-node features, optional
/// edge features, and target-node features (presumably for a message-passing
/// layer; the consumer is not visible here).
pub struct GNMessage<T: TensorBase> {
    /// Features of the edge's source node.
    pub source_features: T,
    /// Features of the edge itself; `None` when none were supplied.
    pub edge_features: Option<T>,
    /// Features of the edge's target node.
    pub target_features: T,
}
impl<T: TensorBase> GNMessage<T> {
    /// Packs the three inputs of a message. Pass `None` for `edge_features` when
    /// the edge has no features.
    pub fn new(source_features: T, edge_features: Option<T>, target_features: T) -> Self {
        Self {
            target_features,
            edge_features,
            source_features,
        }
    }

    /// Borrows the source-node features.
    pub fn source(&self) -> &T {
        &self.source_features
    }

    /// Borrows the edge features, if any.
    pub fn edge(&self) -> Option<&T> {
        Option::as_ref(&self.edge_features)
    }

    /// Borrows the target-node features.
    pub fn target(&self) -> &T {
        &self.target_features
    }
}
/// Square sparse adjacency matrix with side length `num_nodes`.
#[cfg(feature = "tensor")]
pub struct AdjacencyMatrix {
    /// Sparse storage. When built with `from_edges`, entry `[row, col]` holds the
    /// weight given for edge `(row, col, weight)`.
    pub tensor: SparseTensor,
    /// Side length of the matrix (number of nodes).
    pub num_nodes: usize,
}
#[cfg(feature = "tensor")]
impl AdjacencyMatrix {
    /// Builds a `num_nodes × num_nodes` sparse matrix from `(row, col, weight)` triples.
    ///
    /// Edges are stored as given (not mirrored). Bounds checking and duplicate handling
    /// are whatever `SparseTensor::from_edges` does — not visible here.
    pub fn from_edges(edges: &[(usize, usize, f64)], num_nodes: usize) -> Self {
        let dims = [num_nodes; 2];
        Self {
            tensor: SparseTensor::from_edges(edges, dims),
            num_nodes,
        }
    }

    /// Number of stored (non-zero) entries.
    pub fn nnz(&self) -> usize {
        self.tensor.nnz()
    }

    /// Returns an owned copy of the sparse storage.
    pub fn to_sparse(&self) -> SparseTensor {
        SparseTensor::clone(&self.tensor)
    }

    /// Materialises the matrix as a dense `[num_nodes, num_nodes]` tensor.
    pub fn to_dense(&self) -> DenseTensor {
        self.tensor.to_dense()
    }
}
/// Per-node degrees of an adjacency matrix.
///
/// NOTE(review): unlike its `impl` block, this struct is not gated on the `tensor`
/// feature. Without that feature it exists but has no constructor or methods.
/// Confirm whether the missing `#[cfg(feature = "tensor")]` is intentional.
pub struct DegreeMatrix {
    /// Degree of each node; shape `[num_nodes]` when built via `from_adjacency`.
    pub degrees: DenseTensor,
    /// Number of nodes.
    pub num_nodes: usize,
}
#[cfg(feature = "tensor")]
impl DegreeMatrix {
    /// Computes each node's degree as the number of stored entries in its row of `adj`.
    ///
    /// Edge weights are ignored: every entry reported by the COO view adds `1.0`.
    /// For the edges `0→1, 0→2, 1→2` this gives `[2.0, 1.0, 0.0]` (row / out-degree).
    ///
    /// # Panics
    /// Panics if the COO view reports a row index `>= adj.num_nodes`.
    pub fn from_adjacency(adj: &AdjacencyMatrix) -> Self {
        let num_nodes = adj.num_nodes;
        // Count into a plain Vec and build the tensor once. The previous version did a
        // fallible tensor `get` + `set` (and two unwraps) for every stored entry.
        let mut counts = vec![0.0_f64; num_nodes];
        let coo = adj.tensor.to_coo();
        for &row in coo.row_indices() {
            counts[row] += 1.0;
        }
        Self {
            degrees: DenseTensor::new(counts, vec![num_nodes]),
            num_nodes,
        }
    }

    /// Borrows the degree tensor.
    pub fn degrees(&self) -> &DenseTensor {
        &self.degrees
    }

    /// Element-wise `1 / sqrt(d)` of the degrees, with the same shape.
    ///
    /// Entries not strictly greater than `epsilon` map to `0.0` instead of
    /// infinity. This includes zero-degree (isolated) nodes, and also NaN, because
    /// the comparison fails for NaN.
    pub fn inverse_sqrt(&self, epsilon: f64) -> DenseTensor {
        let shape = self.degrees.shape().to_vec();
        let inv_sqrt: Vec<f64> = self
            .degrees
            .data()
            .iter()
            .map(|&d| if d > epsilon { 1.0 / d.sqrt() } else { 0.0 })
            .collect();
        DenseTensor::new(inv_sqrt, shape)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_tensor_node_creation() {
        let index = NodeIndex::new(0, 1);
        let data = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![3]);
        let node = TensorNode::new(index, data.clone());
        assert_eq!(node.index(), index);
        assert_eq!(node.data(), &data);
        assert_eq!(node.shape(), &[3]);
    }

    #[test]
    fn test_tensor_node_set_and_into_data() {
        let mut node = TensorNode::new(
            NodeIndex::new(0, 1),
            DenseTensor::new(vec![1.0, 2.0, 3.0], vec![3]),
        );
        node.set_data(DenseTensor::new(vec![4.0, 5.0], vec![2]));
        assert_eq!(node.shape(), &[2]);
        let inner = node.into_data();
        assert_eq!(inner.data(), &[4.0, 5.0]);
    }

    /// Equality and hashing consider only the index, never the payload.
    #[test]
    fn test_tensor_node_identity_is_index_based() {
        let index = NodeIndex::new(0, 1);
        let a = TensorNode::new(index, DenseTensor::scalar(1.0));
        let b = TensorNode::new(index, DenseTensor::scalar(2.0));
        let c = TensorNode::new(NodeIndex::new(1, 1), DenseTensor::scalar(1.0));
        assert_eq!(a, b);
        assert_ne!(a, c);
        let set: HashSet<_> = vec![a, b, c].into_iter().collect();
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_tensor_edge_creation() {
        let index = EdgeIndex::new(0, 1);
        let source = NodeIndex::new(0, 1);
        let target = NodeIndex::new(1, 1);
        let data = DenseTensor::scalar(0.5);
        let edge = TensorEdge::new(index, data.clone(), source, target);
        assert_eq!(edge.index(), index);
        assert_eq!(edge.data(), &data);
        assert_eq!(edge.source(), source);
        assert_eq!(edge.target(), target);
        assert_eq!(edge.endpoints(), (source, target));
    }

    /// Edge equality and hashing consider only the edge index.
    #[test]
    fn test_tensor_edge_identity_is_index_based() {
        let (u, v) = (NodeIndex::new(0, 1), NodeIndex::new(1, 1));
        let a = TensorEdge::new(EdgeIndex::new(0, 1), DenseTensor::scalar(1.0), u, v);
        let b = TensorEdge::new(EdgeIndex::new(0, 1), DenseTensor::scalar(2.0), v, u);
        let c = TensorEdge::new(EdgeIndex::new(1, 1), DenseTensor::scalar(1.0), u, v);
        assert_eq!(a, b);
        assert_ne!(a, c);
        let set: HashSet<_> = vec![a, b, c].into_iter().collect();
        assert_eq!(set.len(), 2);
    }

    #[test]
    #[cfg(feature = "tensor")]
    fn test_adjacency_matrix() {
        let edges = vec![(0, 1, 1.0), (0, 2, 1.0), (1, 2, 1.0)];
        let adj = AdjacencyMatrix::from_edges(&edges, 3);
        assert_eq!(adj.num_nodes, 3);
        assert_eq!(adj.nnz(), 3);
        let dense = adj.to_dense();
        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense.get(&[0, 1]).unwrap(), 1.0);
        assert_eq!(dense.get(&[0, 2]).unwrap(), 1.0);
    }

    #[test]
    #[cfg(feature = "tensor")]
    fn test_degree_matrix() {
        let edges = vec![(0, 1, 1.0), (0, 2, 1.0), (1, 2, 1.0)];
        let adj = AdjacencyMatrix::from_edges(&edges, 3);
        let degree = DegreeMatrix::from_adjacency(&adj);
        assert_eq!(degree.num_nodes, 3);
        assert!((degree.degrees().get(&[0]).unwrap() - 2.0).abs() < 1e-10);
        assert!((degree.degrees().get(&[1]).unwrap() - 1.0).abs() < 1e-10);
        assert!((degree.degrees().get(&[2]).unwrap() - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_gnn_message() {
        let src = DenseTensor::new(vec![1.0, 2.0], vec![2]);
        let edge = DenseTensor::scalar(0.5);
        let dst = DenseTensor::new(vec![3.0, 4.0], vec![2]);
        let msg = GNMessage::new(src.clone(), Some(edge.clone()), dst.clone());
        assert_eq!(msg.source().data(), &[1.0, 2.0]);
        assert_eq!(msg.edge().unwrap().data(), &[0.5]);
        assert_eq!(msg.target().data(), &[3.0, 4.0]);
    }
}