Skip to main content

god_graph/tensor/
graph_tensor.rs

1//! Graph-Tensor Integration: Specialized implementations for seamless graph-tensor conversion
2//!
3//! This module provides:
4//! - `TensorGraph`: Graph with native tensor node/edge features
5//! - `GraphToTensor`: Convert traditional graphs to tensor representations
6//! - `TensorToGraph`: Convert tensors back to graph structures
7//! - Adjacency matrix extraction and reconstruction
8//! - Feature matrix extraction for GNN workflows
9
10use crate::graph::traits::{GraphBase, GraphOps, GraphQuery};
11use crate::graph::Graph;
12use crate::tensor::error::TensorError;
13use crate::tensor::traits::TensorBase;
14use crate::tensor::DenseTensor;
15
16#[cfg(feature = "tensor")]
17use crate::tensor::{COOTensor, CSRTensor};
18
/// Sparse adjacency structure of a graph, intended as input for graph neural networks.
///
/// Connectivity is held in a private CSR tensor of shape `[num_nodes, num_nodes]`
/// (row = source node, column = target node). The public fields carry the metadata
/// recorded when the matrix was built.
#[cfg(feature = "tensor")]
#[derive(Debug, Clone)]
pub struct GraphAdjacencyMatrix {
    /// CSR-encoded adjacency entries
    csr: CSRTensor,
    /// Node count (the matrix is square with this side length)
    pub num_nodes: usize,
    /// Edge count recorded at construction time
    pub num_edges: usize,
    /// `true` when the source graph is directed
    pub is_directed: bool,
}
32
33#[cfg(feature = "tensor")]
34impl GraphAdjacencyMatrix {
35    /// Create adjacency matrix from edge list
36    pub fn from_edge_list(
37        edges: &[(usize, usize)],
38        num_nodes: usize,
39        is_directed: bool,
40    ) -> Result<Self, TensorError> {
41        if edges.is_empty() {
42            return Ok(Self {
43                csr: CSRTensor::new(
44                    vec![0; num_nodes + 1],
45                    vec![],
46                    DenseTensor::zeros(vec![0]),
47                    [num_nodes, num_nodes],
48                ),
49                num_nodes,
50                num_edges: 0,
51                is_directed,
52            });
53        }
54
55        // Build CSR format
56        let mut row_offsets = vec![0usize; num_nodes + 1];
57        let mut col_indices = Vec::with_capacity(edges.len());
58        let mut values_data = Vec::with_capacity(edges.len());
59
60        // Count edges per row
61        for &(src, _) in edges {
62            if src < num_nodes {
63                row_offsets[src + 1] += 1;
64            }
65        }
66
67        // Cumulative sum
68        for i in 1..=num_nodes {
69            row_offsets[i] += row_offsets[i - 1];
70        }
71
72        // Fill column indices and values
73        let mut row_pos = row_offsets[..num_nodes].to_vec();
74        for &(src, dst) in edges {
75            if src < num_nodes && dst < num_nodes {
76                let _pos = row_pos[src];
77                col_indices.push(dst);
78                values_data.push(1.0);
79                row_pos[src] += 1;
80            }
81        }
82
83        let values = DenseTensor::new(values_data, vec![col_indices.len()]);
84        let csr = CSRTensor::new(row_offsets, col_indices, values, [num_nodes, num_nodes]);
85
86        Ok(Self {
87            csr,
88            num_nodes,
89            num_edges: edges.len(),
90            is_directed,
91        })
92    }
93
94    /// Convert to COO format
95    #[cfg(feature = "tensor")]
96    pub fn to_coo(&self) -> COOTensor {
97        use crate::tensor::SparseTensor;
98        let sparse = SparseTensor::CSR(self.csr.clone());
99        sparse.to_coo()
100    }
101
102    /// Get sparse tensor representation
103    #[cfg(feature = "tensor")]
104    pub fn as_sparse_tensor(&self) -> &CSRTensor {
105        &self.csr
106    }
107
108    /// Compute normalized adjacency matrix (for GCN)
109    ///
110    /// Returns: D^(-1/2) * (A + I) * D^(-1/2)
111    /// where D is the degree matrix and I is the identity matrix
112    #[cfg(feature = "tensor")]
113    pub fn normalized_with_self_loops(&self) -> Result<Self, TensorError> {
114        let n = self.num_nodes;
115
116        // Add self-loops
117        let mut edges = Vec::new();
118
119        // Extract existing edges
120        for i in 0..n {
121            let start = self.csr.row_offsets()[i];
122            let end = self.csr.row_offsets()[i + 1];
123            for j in start..end {
124                let col = self.csr.col_indices()[j];
125                edges.push((i, col));
126            }
127            // Add self-loop
128            edges.push((i, i));
129        }
130
131        Self::from_edge_list(&edges, n, self.is_directed)
132    }
133
134    /// Compute degree matrix
135    #[cfg(feature = "tensor")]
136    pub fn degree_matrix(&self) -> DenseTensor {
137        let n = self.num_nodes;
138        let mut degrees = vec![0.0; n];
139
140        for (i, degree) in degrees.iter_mut().enumerate() {
141            let start = self.csr.row_offsets()[i];
142            let end = self.csr.row_offsets()[i + 1];
143            *degree = (end - start) as f64;
144        }
145
146        DenseTensor::from_vec(degrees, vec![n])
147    }
148
149    /// Compute inverse degree matrix (for normalization)
150    #[cfg(feature = "tensor")]
151    pub fn inverse_degree_matrix(&self) -> DenseTensor {
152        let n = self.num_nodes;
153        let mut inv_degrees = vec![0.0; n];
154
155        for (i, inv_degree) in inv_degrees.iter_mut().enumerate() {
156            let start = self.csr.row_offsets()[i];
157            let end = self.csr.row_offsets()[i + 1];
158            let degree = (end - start) as f64;
159            *inv_degree = if degree > 0.0 { 1.0 / degree } else { 0.0 };
160        }
161
162        DenseTensor::from_vec(inv_degrees, vec![n])
163    }
164}
165
166/// Feature extractor for converting graphs to tensor representations
167pub struct GraphFeatureExtractor<'a, T, E> {
168    graph: &'a Graph<T, E>,
169}
170
171impl<'a, T, E> GraphFeatureExtractor<'a, T, E>
172where
173    T: Clone,
174    E: Clone,
175{
176    /// Create new extractor from graph
177    pub fn new(graph: &'a Graph<T, E>) -> Self {
178        Self { graph }
179    }
180
181    /// Extract node features as dense tensor
182    ///
183    /// Each node's data is treated as a scalar feature
184    pub fn extract_node_features_scalar<F>(&self, map_fn: F) -> Result<DenseTensor, TensorError>
185    where
186        F: Fn(&T) -> f64,
187    {
188        let n = self.graph.node_count();
189        let mut features = Vec::with_capacity(n);
190
191        for node_idx in self.graph.nodes() {
192            let data = node_idx.data();
193            features.push(map_fn(data));
194        }
195
196        Ok(DenseTensor::from_vec(features, vec![n, 1]))
197    }
198
199    /// Extract node features as 2D tensor (nodes x features)
200    ///
201    /// Requires node data to be convertible to feature vectors
202    pub fn extract_node_features<F>(
203        &self,
204        map_fn: F,
205        num_features: usize,
206    ) -> Result<DenseTensor, TensorError>
207    where
208        F: for<'b> Fn(&'b T) -> &'b [f64],
209    {
210        let n = self.graph.node_count();
211        let mut features = Vec::with_capacity(n * num_features);
212
213        for node_idx in self.graph.nodes() {
214            let data = node_idx.data();
215            let feat = map_fn(data);
216            features.extend_from_slice(feat);
217        }
218
219        Ok(DenseTensor::from_vec(features, vec![n, num_features]))
220    }
221
222    /// Extract edge features as tensor
223    pub fn extract_edge_features<F>(&self, map_fn: F) -> Result<DenseTensor, TensorError>
224    where
225        F: Fn(&E) -> f64,
226    {
227        let m = self.graph.edge_count();
228        let mut features = Vec::with_capacity(m);
229
230        for edge_idx in self.graph.edges() {
231            let data = edge_idx.data();
232            features.push(map_fn(data));
233        }
234
235        Ok(DenseTensor::from_vec(features, vec![m, 1]))
236    }
237
238    /// Extract adjacency matrix as sparse tensor
239    #[cfg(feature = "tensor")]
240    pub fn extract_adjacency(&self) -> Result<GraphAdjacencyMatrix, TensorError> {
241        let mut edges: Vec<(usize, usize)> = Vec::new();
242
243        for node_idx in self.graph.nodes() {
244            let src = node_idx.index().index();
245            for neighbor in self.graph.neighbors(node_idx.index()) {
246                let dst = neighbor.index();
247                edges.push((src, dst));
248            }
249        }
250
251        GraphAdjacencyMatrix::from_edge_list(
252            &edges,
253            self.graph.node_count(),
254            true, // Assume directed for adjacency extraction
255        )
256    }
257
258    /// Extract complete graph as tensor representation
259    #[cfg(feature = "tensor")]
260    pub fn extract_all(
261        &self,
262        num_node_features: usize,
263    ) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
264    where
265        T: AsRef<[f64]> + Clone,
266        E: Clone,
267    {
268        let node_features =
269            self.extract_node_features(|data: &T| data.as_ref(), num_node_features)?;
270        let adjacency = self.extract_adjacency()?;
271
272        Ok((node_features, adjacency))
273    }
274}
275
/// Builder that turns tensor representations back into [`Graph`] values.
// `directed` is only read by methods gated behind the `tensor` feature, so the
// field would otherwise be reported as unused when that feature is off.
#[allow(dead_code)]
pub struct GraphReconstructor {
    /// Whether reconstructed graphs are created as directed graphs
    directed: bool,
}
281
282impl GraphReconstructor {
283    /// Create new reconstructor
284    pub fn new(directed: bool) -> Self {
285        Self { directed }
286    }
287
288    /// Reconstruct graph from adjacency matrix
289    #[cfg(feature = "tensor")]
290    pub fn from_adjacency<T, E>(
291        &self,
292        adjacency: &GraphAdjacencyMatrix,
293        mut node_factory: impl FnMut(usize) -> T,
294        mut edge_factory: impl FnMut(usize, usize, f64) -> E,
295    ) -> Result<Graph<T, E>, TensorError>
296    where
297        T: Clone,
298        E: Clone,
299    {
300        let mut graph = if self.directed {
301            Graph::<T, E>::directed()
302        } else {
303            Graph::<T, E>::undirected()
304        };
305
306        let n = adjacency.num_nodes;
307        let mut node_indices = Vec::with_capacity(n);
308
309        // Create nodes
310        for i in 0..n {
311            let node = node_factory(i);
312            let idx = graph.add_node(node).map_err(|e| TensorError::SliceError {
313                description: format!("Failed to add node: {:?}", e),
314            })?;
315            node_indices.push(idx);
316        }
317
318        // Create edges from CSR
319        let csr = adjacency.as_sparse_tensor();
320
321        for src in 0..n {
322            let start = csr.row_offsets()[src];
323            let end = csr.row_offsets()[src + 1];
324
325            for j in start..end {
326                let dst = csr.col_indices()[j];
327                let weight = csr.values().data()[j];
328
329                if let (Some(src_idx), Some(dst_idx)) = (
330                    node_indices.get(src).copied(),
331                    node_indices.get(dst).copied(),
332                ) {
333                    let edge_data = edge_factory(src, dst, weight);
334                    let _ = graph.add_edge(src_idx, dst_idx, edge_data);
335                }
336            }
337        }
338
339        Ok(graph)
340    }
341
342    /// Reconstruct graph from COO tensor
343    #[cfg(feature = "tensor")]
344    pub fn from_coo<T, E>(
345        &self,
346        coo: &COOTensor,
347        node_factory: impl FnMut(usize) -> T,
348        edge_factory: impl FnMut(usize, usize, f64) -> E,
349    ) -> Result<Graph<T, E>, TensorError>
350    where
351        T: Clone,
352        E: Clone,
353    {
354        // Convert COO to edge list
355        let row_indices = coo.row_indices();
356        let col_indices = coo.col_indices();
357        let edges: Vec<(usize, usize)> = row_indices
358            .iter()
359            .zip(col_indices.iter())
360            .map(|(&r, &c)| (r, c))
361            .collect();
362
363        let shape = coo.shape_array();
364        let adjacency = GraphAdjacencyMatrix::from_edge_list(&edges, shape[0], self.directed)?;
365
366        self.from_adjacency(&adjacency, node_factory, edge_factory)
367    }
368}
369
/// Tensor conversion helpers made available directly on graph values.
#[cfg(feature = "tensor")]
pub trait GraphTensorExt<T, E> {
    /// Produce the node feature matrix together with the adjacency matrix.
    fn to_tensor_representation(&self) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
    where
        T: AsRef<[f64]> + Clone,
        E: Clone;

    /// Produce the sparse adjacency matrix of the graph.
    fn adjacency_matrix(&self) -> Result<GraphAdjacencyMatrix, TensorError>;

    /// Produce the node feature matrix with `num_features` columns.
    fn node_features(&self, num_features: usize) -> Result<DenseTensor, TensorError>
    where
        T: AsRef<[f64]> + Clone;

    /// Obtain a [`GraphFeatureExtractor`] borrowing this graph.
    fn feature_extractor(&self) -> GraphFeatureExtractor<'_, T, E>;
}
390
#[cfg(feature = "tensor")]
impl<T, E> GraphTensorExt<T, E> for Graph<T, E>
where
    T: Clone,
    E: Clone,
{
    /// Converts the graph into `(node features, adjacency)`.
    ///
    /// The feature width is taken from the first node's payload; a graph without
    /// nodes uses a width of `1`.
    fn to_tensor_representation(&self) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
    where
        T: AsRef<[f64]> + Clone,
        E: Clone,
    {
        let feature_width = match self.nodes().next() {
            Some(first_node) => first_node.data().as_ref().len(),
            None => 1,
        };

        GraphFeatureExtractor::new(self).extract_all(feature_width)
    }

    /// Delegates to [`GraphFeatureExtractor::extract_adjacency`].
    fn adjacency_matrix(&self) -> Result<GraphAdjacencyMatrix, TensorError> {
        GraphFeatureExtractor::new(self).extract_adjacency()
    }

    /// Delegates to [`GraphFeatureExtractor::extract_node_features`], viewing each
    /// node payload as `&[f64]` through `AsRef`.
    fn node_features(&self, num_features: usize) -> Result<DenseTensor, TensorError>
    where
        T: AsRef<[f64]> + Clone,
    {
        GraphFeatureExtractor::new(self)
            .extract_node_features(|data: &T| data.as_ref(), num_features)
    }

    /// Creates an extractor that borrows this graph.
    fn feature_extractor(&self) -> GraphFeatureExtractor<'_, T, E> {
        GraphFeatureExtractor::new(self)
    }
}
429
/// A collection of graphs, each already converted to `(features, adjacency)` tensors.
///
/// The batch can be flattened into a single feature tensor of shape
/// `[batch_size * max_nodes, num_features]` and a single adjacency matrix in
/// which every graph's node indices are shifted by the node count of the graphs
/// before it.
#[cfg(feature = "tensor")]
pub struct GraphBatch {
    /// Per-graph tensor representations, in insertion order
    graphs: Vec<(DenseTensor, GraphAdjacencyMatrix)>,
}
438
439#[cfg(feature = "tensor")]
440impl GraphBatch {
441    /// Create new batch from graphs
442    pub fn new<T, E>(graphs: &[Graph<T, E>]) -> Result<Self, TensorError>
443    where
444        T: AsRef<[f64]> + Clone,
445        E: Clone,
446    {
447        let mut batch = Self {
448            graphs: Vec::with_capacity(graphs.len()),
449        };
450
451        for graph in graphs {
452            let (features, adjacency) = graph.to_tensor_representation()?;
453            batch.graphs.push((features, adjacency));
454        }
455
456        Ok(batch)
457    }
458
459    /// Get batched feature matrix
460    pub fn batch_features(&self) -> DenseTensor {
461        if self.graphs.is_empty() {
462            return DenseTensor::zeros(vec![0, 0]);
463        }
464
465        // Find max nodes and features
466        let max_nodes = self
467            .graphs
468            .iter()
469            .map(|(_, adj)| adj.num_nodes)
470            .max()
471            .unwrap_or(0);
472
473        let num_features = self
474            .graphs
475            .iter()
476            .map(|(feat, _)| feat.shape().get(1).copied().unwrap_or(1))
477            .max()
478            .unwrap_or(1);
479
480        // Concatenate with padding
481        let mut all_features = Vec::new();
482        for (features, adjacency) in &self.graphs {
483            let feat_data = features.data();
484            all_features.extend_from_slice(feat_data);
485
486            // Pad to max_nodes if needed
487            let current_nodes = adjacency.num_nodes;
488            if current_nodes < max_nodes {
489                let padding_size = (max_nodes - current_nodes) * num_features;
490                all_features.extend(std::iter::repeat_n(0.0, padding_size));
491            }
492        }
493
494        DenseTensor::from_vec(
495            all_features,
496            vec![self.graphs.len() * max_nodes, num_features],
497        )
498    }
499
500    /// Get batched adjacency matrix (block diagonal)
501    #[cfg(feature = "tensor")]
502    pub fn batch_adjacency(&self) -> GraphAdjacencyMatrix {
503        if self.graphs.is_empty() {
504            return GraphAdjacencyMatrix::from_edge_list(&[], 0, false).unwrap();
505        }
506
507        // For batch processing, we keep graphs separate and use offset indexing
508        // This is a simplified implementation - full block diagonal would be more complex
509        let total_nodes: usize = self.graphs.iter().map(|(_, adj)| adj.num_nodes).sum();
510        let total_edges: usize = self.graphs.iter().map(|(_, adj)| adj.num_edges).sum();
511
512        // Collect all edges with offsets
513        let mut all_edges = Vec::with_capacity(total_edges);
514        let mut offset = 0;
515
516        for (_, adjacency) in &self.graphs {
517            let csr = adjacency.as_sparse_tensor();
518            for src in 0..adjacency.num_nodes {
519                let start = csr.row_offsets()[src];
520                let end = csr.row_offsets()[src + 1];
521                for j in start..end {
522                    let dst = csr.col_indices()[j];
523                    all_edges.push((src + offset, dst + offset));
524                }
525            }
526            offset += adjacency.num_nodes;
527        }
528
529        GraphAdjacencyMatrix::from_edge_list(
530            &all_edges,
531            total_nodes,
532            self.graphs
533                .first()
534                .map(|(_, adj)| adj.is_directed)
535                .unwrap_or(false),
536        )
537        .unwrap()
538    }
539
540    /// Get number of graphs in batch
541    pub fn len(&self) -> usize {
542        self.graphs.len()
543    }
544
545    /// Check if batch is empty
546    pub fn is_empty(&self) -> bool {
547        self.graphs.is_empty()
548    }
549
550    /// Get individual graph by index
551    #[cfg(feature = "tensor")]
552    pub fn get(&self, idx: usize) -> Option<&(DenseTensor, GraphAdjacencyMatrix)> {
553        self.graphs.get(idx)
554    }
555}
556
#[cfg(all(test, feature = "tensor"))]
mod tests {
    use super::*;
    use crate::graph::Graph;

    /// Directed 3-cycle used by several tests.
    fn triangle_edges() -> Vec<(usize, usize)> {
        vec![(0, 1), (1, 2), (2, 0)]
    }

    /// Two-node directed graph with a single edge and the given node features.
    fn two_node_graph(first: Vec<f64>, second: Vec<f64>) -> Graph<Vec<f64>, f64> {
        let mut graph = Graph::<Vec<f64>, f64>::directed();
        let a = graph.add_node(first).unwrap();
        let b = graph.add_node(second).unwrap();
        let _ = graph.add_edge(a, b, 1.0);
        graph
    }

    // Metadata recorded by `from_edge_list` matches the input.
    #[test]
    fn test_adjacency_matrix_creation() {
        let matrix = GraphAdjacencyMatrix::from_edge_list(&triangle_edges(), 3, true).unwrap();

        assert_eq!(matrix.num_nodes, 3);
        assert_eq!(matrix.num_edges, 3);
        assert!(matrix.is_directed);
    }

    // A graph with vector payloads converts to a [nodes, features] matrix.
    #[test]
    fn test_graph_to_tensor_conversion() {
        let mut graph = Graph::<Vec<f64>, f64>::directed();

        let a = graph.add_node(vec![1.0, 0.0]).unwrap();
        let b = graph.add_node(vec![0.0, 1.0]).unwrap();
        let c = graph.add_node(vec![1.0, 1.0]).unwrap();

        for (from, to) in [(a, b), (b, c), (c, a)] {
            let _ = graph.add_edge(from, to, 1.0);
        }

        let (features, adjacency) = graph.to_tensor_representation().unwrap();

        assert_eq!(features.shape(), &[3, 2]);
        assert_eq!(adjacency.num_nodes, 3);
        assert_eq!(adjacency.num_edges, 3);
    }

    // Scalar extraction yields one column with one row per node.
    #[test]
    fn test_feature_extractor() {
        let mut graph = Graph::<String, f64>::directed();

        let first = graph.add_node("node0".to_string()).unwrap();
        let second = graph.add_node("node1".to_string()).unwrap();
        let _ = graph.add_edge(first, second, 1.0);

        // Use the string length as the scalar feature.
        let lengths = graph
            .feature_extractor()
            .extract_node_features_scalar(|s| s.len() as f64)
            .unwrap();

        assert_eq!(lengths.shape(), &[2, 1]);
    }

    // Reconstruction preserves node and edge counts.
    #[test]
    fn test_graph_reconstruction() {
        let matrix = GraphAdjacencyMatrix::from_edge_list(&triangle_edges(), 3, true).unwrap();

        let rebuilt: Graph<usize, f64> = GraphReconstructor::new(true)
            .from_adjacency(&matrix, |i| i, |_src, _dst, w| w)
            .unwrap();

        assert_eq!(rebuilt.node_count(), 3);
        assert_eq!(rebuilt.edge_count(), 3);
    }

    // Adding self-loops must increase the stored edge count.
    #[test]
    fn test_normalized_adjacency() {
        let path_edges = vec![(0, 1), (1, 0), (1, 2), (2, 1)];
        let matrix = GraphAdjacencyMatrix::from_edge_list(&path_edges, 3, true).unwrap();

        let with_loops = matrix.normalized_with_self_loops().unwrap();

        assert!(with_loops.num_edges > matrix.num_edges);
    }

    // A batch of two graphs reports the right size.
    #[test]
    fn test_batch_creation() {
        let left = two_node_graph(vec![1.0, 0.0], vec![0.0, 1.0]);
        let right = two_node_graph(vec![1.0, 1.0], vec![0.0, 0.0]);

        let batch = GraphBatch::new(&[left, right]).unwrap();

        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }
}