1use core::fmt;
7use core::hash::{Hash, Hasher};
8use core::marker::PhantomData;
9
10use crate::edge::EdgeIndex;
11use crate::node::NodeIndex;
12use crate::tensor::dense::DenseTensor;
13use crate::tensor::traits::TensorBase;
14
15#[cfg(feature = "tensor")]
16use crate::tensor::sparse::SparseTensor;
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20
21#[derive(Clone)]
26#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
27pub struct TensorNode<T: TensorBase> {
28 index: NodeIndex,
30 data: T,
32 _marker: PhantomData<T>,
34}
35
36impl<T: TensorBase> TensorNode<T> {
37 pub fn new(index: NodeIndex, data: T) -> Self {
39 Self {
40 index,
41 data,
42 _marker: PhantomData,
43 }
44 }
45
46 pub fn index(&self) -> NodeIndex {
48 self.index
49 }
50
51 pub fn data(&self) -> &T {
53 &self.data
54 }
55
56 pub fn data_mut(&mut self) -> &mut T {
58 &mut self.data
59 }
60
61 pub fn shape(&self) -> &[usize] {
63 self.data.shape()
64 }
65
66 pub fn set_data(&mut self, data: T) {
68 self.data = data;
69 }
70
71 pub fn into_data(self) -> T {
73 self.data
74 }
75}
76
77impl<T: TensorBase> fmt::Debug for TensorNode<T> {
78 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79 f.debug_struct("TensorNode")
80 .field("index", &self.index)
81 .field("shape", &self.data.shape())
82 .field("dtype", &self.data.dtype())
83 .finish()
84 }
85}
86
87impl<T: TensorBase> PartialEq for TensorNode<T> {
88 fn eq(&self, other: &Self) -> bool {
89 self.index == other.index
90 }
91}
92
93impl<T: TensorBase> Eq for TensorNode<T> {}
94
95impl<T: TensorBase> Hash for TensorNode<T> {
96 fn hash<H: Hasher>(&self, state: &mut H) {
97 self.index.hash(state);
98 }
99}
100
101#[derive(Clone)]
105#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
106pub struct TensorEdge<E: TensorBase> {
107 index: EdgeIndex,
109 data: E,
111 source: NodeIndex,
113 target: NodeIndex,
115}
116
117impl<E: TensorBase> TensorEdge<E> {
118 pub fn new(index: EdgeIndex, data: E, source: NodeIndex, target: NodeIndex) -> Self {
120 Self {
121 index,
122 data,
123 source,
124 target,
125 }
126 }
127
128 pub fn index(&self) -> EdgeIndex {
130 self.index
131 }
132
133 pub fn data(&self) -> &E {
135 &self.data
136 }
137
138 pub fn data_mut(&mut self) -> &mut E {
140 &mut self.data
141 }
142
143 pub fn source(&self) -> NodeIndex {
145 self.source
146 }
147
148 pub fn target(&self) -> NodeIndex {
150 self.target
151 }
152
153 pub fn endpoints(&self) -> (NodeIndex, NodeIndex) {
155 (self.source, self.target)
156 }
157
158 pub fn shape(&self) -> &[usize] {
160 self.data.shape()
161 }
162
163 pub fn set_data(&mut self, data: E) {
165 self.data = data;
166 }
167}
168
169impl<E: TensorBase> fmt::Debug for TensorEdge<E> {
170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171 f.debug_struct("TensorEdge")
172 .field("index", &self.index)
173 .field(
174 "endpoints",
175 &format!("({:?}, {:?})", self.source, self.target),
176 )
177 .field("shape", &self.data.shape())
178 .field("dtype", &self.data.dtype())
179 .finish()
180 }
181}
182
183impl<E: TensorBase> PartialEq for TensorEdge<E> {
184 fn eq(&self, other: &Self) -> bool {
185 self.index == other.index
186 }
187}
188
189impl<E: TensorBase> Eq for TensorEdge<E> {}
190
191impl<E: TensorBase> Hash for TensorEdge<E> {
192 fn hash<H: Hasher>(&self, state: &mut H) {
193 self.index.hash(state);
194 }
195}
196
197pub type NodeFeatures = TensorNode<DenseTensor>;
201
202pub type EdgeFeatures = TensorEdge<DenseTensor>;
206
207pub type NodeEmbedding = TensorNode<DenseTensor>;
211
212pub type HiddenState = DenseTensor;
216
217pub struct BatchedNodeFeatures<T: TensorBase> {
221 pub graph_indices: Vec<usize>,
223 pub node_indices: Vec<NodeIndex>,
225 pub features: T,
227}
228
229impl<T: TensorBase> BatchedNodeFeatures<T> {
230 pub fn new(graph_indices: Vec<usize>, node_indices: Vec<NodeIndex>, features: T) -> Self {
232 Self {
233 graph_indices,
234 node_indices,
235 features,
236 }
237 }
238
239 pub fn batch_size(&self) -> usize {
241 self.graph_indices.len()
242 }
243
244 pub fn features(&self) -> &T {
246 &self.features
247 }
248
249 pub fn get_sample(&self, sample_idx: usize) -> Option<&T> {
251 if sample_idx < self.graph_indices.len() {
252 Some(&self.features)
253 } else {
254 None
255 }
256 }
257}
258
259pub struct GNMessage<T: TensorBase> {
263 pub source_features: T,
265 pub edge_features: Option<T>,
267 pub target_features: T,
269}
270
271impl<T: TensorBase> GNMessage<T> {
272 pub fn new(source_features: T, edge_features: Option<T>, target_features: T) -> Self {
274 Self {
275 source_features,
276 edge_features,
277 target_features,
278 }
279 }
280
281 pub fn source(&self) -> &T {
283 &self.source_features
284 }
285
286 pub fn edge(&self) -> Option<&T> {
288 self.edge_features.as_ref()
289 }
290
291 pub fn target(&self) -> &T {
293 &self.target_features
294 }
295}
296
/// Adjacency matrix of a graph backed by a [`SparseTensor`].
#[cfg(feature = "tensor")]
pub struct AdjacencyMatrix {
    /// Sparse storage; `from_edges` builds it with shape
    /// `[num_nodes, num_nodes]`.
    pub tensor: SparseTensor,
    /// Number of nodes, i.e. the side length used by `from_edges`.
    pub num_nodes: usize,
}

#[cfg(feature = "tensor")]
impl AdjacencyMatrix {
    /// Builds a square `num_nodes x num_nodes` adjacency matrix by passing
    /// `edges` to [`SparseTensor::from_edges`].
    ///
    /// Each entry of `edges` is presumably `(row, column, weight)` — confirm
    /// against `SparseTensor::from_edges`.
    pub fn from_edges(edges: &[(usize, usize, f64)], num_nodes: usize) -> Self {
        let shape = [num_nodes, num_nodes];
        Self {
            tensor: SparseTensor::from_edges(edges, shape),
            num_nodes,
        }
    }

    /// Returns `nnz()` of the underlying sparse tensor.
    pub fn nnz(&self) -> usize {
        let Self { tensor, .. } = self;
        tensor.nnz()
    }

    /// Returns a clone of the underlying sparse tensor.
    pub fn to_sparse(&self) -> SparseTensor {
        let Self { tensor, .. } = self;
        tensor.clone()
    }

    /// Converts the underlying sparse tensor to a dense tensor.
    pub fn to_dense(&self) -> DenseTensor {
        let Self { tensor, .. } = self;
        tensor.to_dense()
    }
}
331
332pub struct DegreeMatrix {
334 pub degrees: DenseTensor,
336 pub num_nodes: usize,
338}
339
340#[cfg(feature = "tensor")]
341impl DegreeMatrix {
342 pub fn from_adjacency(adj: &AdjacencyMatrix) -> Self {
344 let degrees = vec![0.0; adj.num_nodes];
345 let mut degrees_tensor = DenseTensor::new(degrees, vec![adj.num_nodes]);
346
347 let coo = adj.tensor.to_coo();
349 for &row in coo.row_indices() {
350 let current = degrees_tensor.get(&[row]).unwrap();
351 degrees_tensor.set(&[row], current + 1.0).unwrap();
352 }
353
354 Self {
355 degrees: degrees_tensor,
356 num_nodes: adj.num_nodes,
357 }
358 }
359
360 pub fn degrees(&self) -> &DenseTensor {
362 &self.degrees
363 }
364
365 pub fn inverse_sqrt(&self, epsilon: f64) -> DenseTensor {
367 let shape = self.degrees.shape().to_vec();
368 let inv_sqrt: Vec<f64> = self.degrees.data()
369 .iter()
370 .map(|&d| if d > epsilon { 1.0 / d.sqrt() } else { 0.0 })
371 .collect();
372 DenseTensor::new(inv_sqrt, shape)
373 }
374}
375
// Unit tests for the tensor-backed node/edge types and the matrix helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // A node exposes the index, payload, and shape it was built with.
    #[test]
    fn test_tensor_node_creation() {
        let index = NodeIndex::new(0, 1);
        let data = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![3]);
        // `data` is compared against below, so the node gets its own copy.
        let node = TensorNode::new(index, data.clone());

        assert_eq!(node.index(), index);
        assert_eq!(node.data(), &data);
        assert_eq!(node.shape(), &[3]);
    }

    // Node equality looks at the index only, never at the payload.
    #[test]
    fn test_tensor_node_equality_ignores_data() {
        let index = NodeIndex::new(0, 1);
        let a = TensorNode::new(index, DenseTensor::new(vec![1.0], vec![1]));
        let b = TensorNode::new(index, DenseTensor::new(vec![2.0], vec![1]));
        let c = TensorNode::new(NodeIndex::new(1, 1), DenseTensor::new(vec![1.0], vec![1]));

        assert!(a == b);
        assert!(a != c);
    }

    // An edge exposes the index and endpoints it was built with.
    #[test]
    fn test_tensor_edge_creation() {
        let index = EdgeIndex::new(0, 1);
        let source = NodeIndex::new(0, 1);
        let target = NodeIndex::new(1, 1);
        let data = DenseTensor::scalar(0.5);

        let edge = TensorEdge::new(index, data, source, target);

        assert_eq!(edge.index(), index);
        assert_eq!(edge.source(), source);
        assert_eq!(edge.target(), target);
        assert_eq!(edge.endpoints(), (source, target));
    }

    // `get_sample` answers `Some` only for indices inside `graph_indices`.
    #[test]
    fn test_batched_node_features_get_sample_bounds() {
        let features = DenseTensor::new(vec![1.0, 2.0], vec![2]);
        let batch = BatchedNodeFeatures::new(
            vec![0, 0],
            vec![NodeIndex::new(0, 1), NodeIndex::new(1, 1)],
            features,
        );

        assert_eq!(batch.batch_size(), 2);
        assert!(batch.get_sample(1).is_some());
        assert!(batch.get_sample(2).is_none());
    }

    #[test]
    #[cfg(feature = "tensor")]
    fn test_adjacency_matrix() {
        let edges = vec![(0, 1, 1.0), (0, 2, 1.0), (1, 2, 1.0)];
        let adj = AdjacencyMatrix::from_edges(&edges, 3);

        assert_eq!(adj.num_nodes, 3);
        assert_eq!(adj.nnz(), 3);

        let dense = adj.to_dense();
        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense.get(&[0, 1]).unwrap(), 1.0);
        assert_eq!(dense.get(&[0, 2]).unwrap(), 1.0);
    }

    // Degrees count the stored entries per row: node 0 has two, node 1 has
    // one, node 2 has none.
    #[test]
    #[cfg(feature = "tensor")]
    fn test_degree_matrix() {
        let edges = vec![(0, 1, 1.0), (0, 2, 1.0), (1, 2, 1.0)];
        let adj = AdjacencyMatrix::from_edges(&edges, 3);
        let degree = DegreeMatrix::from_adjacency(&adj);

        assert_eq!(degree.num_nodes, 3);
        assert!((degree.degrees().get(&[0]).unwrap() - 2.0).abs() < 1e-10);
        assert!((degree.degrees().get(&[1]).unwrap() - 1.0).abs() < 1e-10);
        assert!((degree.degrees().get(&[2]).unwrap() - 0.0).abs() < 1e-10);
    }

    // `inverse_sqrt` maps degree d to 1/sqrt(d) and zero degrees to 0.0.
    #[test]
    #[cfg(feature = "tensor")]
    fn test_degree_matrix_inverse_sqrt() {
        let edges = vec![(0, 1, 1.0), (0, 2, 1.0), (1, 2, 1.0)];
        let adj = AdjacencyMatrix::from_edges(&edges, 3);
        let degree = DegreeMatrix::from_adjacency(&adj);

        let inv = degree.inverse_sqrt(1e-12);

        assert_eq!(inv.shape(), &[3]);
        assert!((inv.get(&[0]).unwrap() - 1.0 / 2.0_f64.sqrt()).abs() < 1e-10);
        assert!((inv.get(&[1]).unwrap() - 1.0).abs() < 1e-10);
        assert_eq!(inv.get(&[2]).unwrap(), 0.0);
    }

    #[test]
    fn test_gnn_message() {
        let src = DenseTensor::new(vec![1.0, 2.0], vec![2]);
        let edge = DenseTensor::scalar(0.5);
        let dst = DenseTensor::new(vec![3.0, 4.0], vec![2]);

        // The tensors are not used again, so they are moved, not cloned.
        let msg = GNMessage::new(src, Some(edge), dst);

        assert_eq!(msg.source().data(), &[1.0, 2.0]);
        assert_eq!(msg.edge().unwrap().data(), &[0.5]);
        assert_eq!(msg.target().data(), &[3.0, 4.0]);
    }

    // A message built without edge features reports `None` for the edge.
    #[test]
    fn test_gnn_message_without_edge() {
        let src = DenseTensor::new(vec![1.0], vec![1]);
        let dst = DenseTensor::new(vec![2.0], vec![1]);

        let msg = GNMessage::new(src, None, dst);

        assert!(msg.edge().is_none());
    }
}