Skip to main content

god_gragh/tensor/
graph_tensor.rs

//! Graph-Tensor Integration: Specialized implementations for seamless graph-tensor conversion
//!
//! This module provides:
//! - `GraphAdjacencyMatrix`: CSR-backed sparse adjacency matrix (requires the `tensor-sparse` feature)
//! - `GraphFeatureExtractor`: Convert traditional graphs to tensor representations
//! - `GraphReconstructor`: Convert tensors back to graph structures
//! - `GraphTensorExt`: Extension trait adding tensor conversion methods to `Graph`
//! - `GraphBatch`: Batching of several graphs into one tensor representation
//! - Adjacency matrix extraction and reconstruction
//! - Feature matrix extraction for GNN workflows
9
10use crate::graph::traits::{GraphBase, GraphOps, GraphQuery};
11use crate::graph::Graph;
12use crate::tensor::error::TensorError;
13use crate::tensor::traits::TensorBase;
14use crate::tensor::DenseTensor;
15
16#[cfg(feature = "tensor-sparse")]
17use crate::tensor::{COOTensor, CSRTensor};
18
/// Adjacency matrix representation for graph neural networks
///
/// Bundles a square `[num_nodes, num_nodes]` CSR tensor with the bookkeeping
/// (node count, edge count, directedness flag) recorded when it was built.
#[derive(Debug, Clone)]
#[cfg(feature = "tensor-sparse")]
pub struct GraphAdjacencyMatrix {
    /// Sparse adjacency matrix in CSR format (row = source, column = target)
    csr: CSRTensor,
    /// Number of nodes, i.e. the side length of the square matrix
    pub num_nodes: usize,
    /// Number of edges recorded at construction time
    pub num_edges: usize,
    /// Whether the graph is directed (stored as a flag only; it does not
    /// change how the CSR arrays are laid out)
    pub is_directed: bool,
}
32
#[cfg(feature = "tensor-sparse")]
impl GraphAdjacencyMatrix {
    /// Assemble a matrix from CSR arrays that are already mutually consistent.
    ///
    /// `values_data` must hold exactly one weight per entry of `col_indices`.
    /// A matrix without entries gets a zero-length value tensor created with
    /// `DenseTensor::zeros`, matching the empty-edge-list representation.
    fn from_csr_parts(
        row_offsets: Vec<usize>,
        col_indices: Vec<usize>,
        values_data: Vec<f64>,
        num_nodes: usize,
        is_directed: bool,
    ) -> Self {
        let num_edges = col_indices.len();
        let values = if values_data.is_empty() {
            DenseTensor::zeros(vec![0])
        } else {
            DenseTensor::new(values_data, vec![num_edges])
        };

        Self {
            csr: CSRTensor::new(row_offsets, col_indices, values, [num_nodes, num_nodes]),
            num_nodes,
            num_edges,
            is_directed,
        }
    }

    /// Create adjacency matrix from edge list
    ///
    /// Each `(src, dst)` pair becomes one CSR entry with weight `1.0` in row
    /// `src`, column `dst`. Edges may be given in any order; within a row the
    /// input order is preserved.
    ///
    /// Pairs whose source or target is `>= num_nodes` are skipped, and
    /// `num_edges` counts only the entries that were actually stored.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the public signature.
    pub fn from_edge_list(
        edges: &[(usize, usize)],
        num_nodes: usize,
        is_directed: bool,
    ) -> Result<Self, TensorError> {
        // Pass 1: count the entries of every row. The validity test must be
        // the same one used when filling, otherwise the offsets and the
        // column array disagree.
        let mut row_offsets = vec![0usize; num_nodes + 1];
        for &(src, dst) in edges {
            if src >= num_nodes || dst >= num_nodes {
                continue;
            }
            row_offsets[src + 1] += 1;
        }

        // Prefix sum turns per-row counts into row start offsets.
        for i in 1..=num_nodes {
            row_offsets[i] += row_offsets[i - 1];
        }

        // Pass 2: scatter every edge into the next free slot of its row, so
        // the result is correct even when `edges` is not sorted by source.
        let stored = row_offsets[num_nodes];
        let mut col_indices = vec![0usize; stored];
        let mut next_slot = row_offsets[..num_nodes].to_vec();
        for &(src, dst) in edges {
            if src >= num_nodes || dst >= num_nodes {
                continue;
            }
            col_indices[next_slot[src]] = dst;
            next_slot[src] += 1;
        }

        Ok(Self::from_csr_parts(
            row_offsets,
            col_indices,
            vec![1.0; stored],
            num_nodes,
            is_directed,
        ))
    }

    /// Convert to COO format
    ///
    /// Clones the CSR tensor and delegates to `SparseTensor::to_coo`.
    pub fn to_coo(&self) -> COOTensor {
        use crate::tensor::SparseTensor;
        SparseTensor::CSR(self.csr.clone()).to_coo()
    }

    /// Get sparse tensor representation
    pub fn as_sparse_tensor(&self) -> &CSRTensor {
        &self.csr
    }

    /// Compute normalized adjacency matrix (for GCN)
    ///
    /// Returns: D^(-1/2) * (A + I) * D^(-1/2)
    /// where I is the identity matrix and D is the diagonal degree matrix of
    /// `A + I` (row sums of the stored weights).
    ///
    /// A row that already has a diagonal entry gets `1.0` added to the first
    /// such entry; otherwise a new diagonal entry is appended to the row.
    /// Rows whose degree is not positive are scaled by zero instead of
    /// dividing by zero.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the public signature.
    pub fn normalized_with_self_loops(&self) -> Result<Self, TensorError> {
        let n = self.num_nodes;
        let offsets = self.csr.row_offsets();
        let cols = self.csr.col_indices();
        let weights = self.csr.values().data();

        let mut row_offsets = Vec::with_capacity(n + 1);
        let mut col_indices = Vec::with_capacity(cols.len() + n);
        let mut values_data: Vec<f64> = Vec::with_capacity(cols.len() + n);
        let mut degrees: Vec<f64> = Vec::with_capacity(n);

        // Build A + I row by row and record each row sum as the degree.
        row_offsets.push(0);
        for i in 0..n {
            let row_start = col_indices.len();
            let mut has_diagonal = false;

            for j in offsets[i]..offsets[i + 1] {
                let col = cols[j];
                let mut weight = weights[j];
                if col == i && !has_diagonal {
                    weight += 1.0;
                    has_diagonal = true;
                }
                col_indices.push(col);
                values_data.push(weight);
            }

            if !has_diagonal {
                col_indices.push(i);
                values_data.push(1.0);
            }

            degrees.push(values_data[row_start..].iter().sum());
            row_offsets.push(col_indices.len());
        }

        // d^(-1/2) per node, with 0 for non-positive degrees.
        let inv_sqrt: Vec<f64> = degrees
            .iter()
            .map(|&d| if d > 0.0 { 1.0 / d.sqrt() } else { 0.0 })
            .collect();

        // Scale entry (i, j) by d_i^(-1/2) * d_j^(-1/2).
        for i in 0..n {
            for j in row_offsets[i]..row_offsets[i + 1] {
                let col_scale = inv_sqrt.get(col_indices[j]).copied().unwrap_or(0.0);
                values_data[j] *= inv_sqrt[i] * col_scale;
            }
        }

        Ok(Self::from_csr_parts(
            row_offsets,
            col_indices,
            values_data,
            n,
            self.is_directed,
        ))
    }

    /// Compute degree matrix
    ///
    /// Returns a 1-D tensor of length `num_nodes` whose `i`-th element is the
    /// number of entries stored in row `i` (stored weights are not used).
    pub fn degree_matrix(&self) -> DenseTensor {
        let n = self.num_nodes;
        let offsets = self.csr.row_offsets();
        let degrees: Vec<f64> = (0..n)
            .map(|i| (offsets[i + 1] - offsets[i]) as f64)
            .collect();

        DenseTensor::from_vec(degrees, vec![n])
    }

    /// Compute inverse degree matrix (for normalization)
    ///
    /// Returns a 1-D tensor of length `num_nodes` holding `1 / degree` per
    /// row, with `0.0` for rows that have no entries.
    pub fn inverse_degree_matrix(&self) -> DenseTensor {
        let n = self.num_nodes;
        let offsets = self.csr.row_offsets();
        let inv_degrees: Vec<f64> = (0..n)
            .map(|i| {
                let degree = (offsets[i + 1] - offsets[i]) as f64;
                if degree > 0.0 {
                    1.0 / degree
                } else {
                    0.0
                }
            })
            .collect();

        DenseTensor::from_vec(inv_degrees, vec![n])
    }
}
165
/// Feature extractor for converting graphs to tensor representations
///
/// Borrows the graph for `'a`; it never takes ownership of or mutates it.
pub struct GraphFeatureExtractor<'a, T, E> {
    /// Graph whose node/edge data and connectivity are read
    graph: &'a Graph<T, E>,
}
170
171impl<'a, T, E> GraphFeatureExtractor<'a, T, E>
172where
173    T: Clone,
174    E: Clone,
175{
176    /// Create new extractor from graph
177    pub fn new(graph: &'a Graph<T, E>) -> Self {
178        Self { graph }
179    }
180
181    /// Extract node features as dense tensor
182    ///
183    /// Each node's data is treated as a scalar feature
184    pub fn extract_node_features_scalar<F>(&self, map_fn: F) -> Result<DenseTensor, TensorError>
185    where
186        F: Fn(&T) -> f64,
187    {
188        let n = self.graph.node_count();
189        let mut features = Vec::with_capacity(n);
190
191        for node_idx in self.graph.nodes() {
192            let data = node_idx.data();
193            features.push(map_fn(data));
194        }
195
196        Ok(DenseTensor::from_vec(features, vec![n, 1]))
197    }
198
199    /// Extract node features as 2D tensor (nodes x features)
200    ///
201    /// Requires node data to be convertible to feature vectors
202    pub fn extract_node_features<F>(
203        &self,
204        map_fn: F,
205        num_features: usize,
206    ) -> Result<DenseTensor, TensorError>
207    where
208        F: for<'b> Fn(&'b T) -> &'b [f64],
209    {
210        let n = self.graph.node_count();
211        let mut features = Vec::with_capacity(n * num_features);
212
213        for node_idx in self.graph.nodes() {
214            let data = node_idx.data();
215            let feat = map_fn(data);
216            features.extend_from_slice(feat);
217        }
218
219        Ok(DenseTensor::from_vec(features, vec![n, num_features]))
220    }
221
222    /// Extract edge features as tensor
223    pub fn extract_edge_features<F>(&self, map_fn: F) -> Result<DenseTensor, TensorError>
224    where
225        F: Fn(&E) -> f64,
226    {
227        let m = self.graph.edge_count();
228        let mut features = Vec::with_capacity(m);
229
230        for edge_idx in self.graph.edges() {
231            let data = edge_idx.data();
232            features.push(map_fn(data));
233        }
234
235        Ok(DenseTensor::from_vec(features, vec![m, 1]))
236    }
237
238    /// Extract adjacency matrix as sparse tensor
239    #[cfg(feature = "tensor-sparse")]
240    pub fn extract_adjacency(&self) -> Result<GraphAdjacencyMatrix, TensorError> {
241        let mut edges: Vec<(usize, usize)> = Vec::new();
242
243        for node_idx in self.graph.nodes() {
244            let src = node_idx.index().index();
245            for neighbor in self.graph.neighbors(node_idx.index()) {
246                let dst = neighbor.index();
247                edges.push((src, dst));
248            }
249        }
250
251        GraphAdjacencyMatrix::from_edge_list(
252            &edges,
253            self.graph.node_count(),
254            true, // Assume directed for adjacency extraction
255        )
256    }
257
258    /// Extract complete graph as tensor representation
259    #[cfg(feature = "tensor-sparse")]
260    pub fn extract_all(
261        &self,
262        num_node_features: usize,
263    ) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
264    where
265        T: AsRef<[f64]> + Clone,
266        E: Clone,
267    {
268        let node_features =
269            self.extract_node_features(|data: &T| data.as_ref(), num_node_features)?;
270        let adjacency = self.extract_adjacency()?;
271
272        Ok((node_features, adjacency))
273    }
274}
275
/// Reconstruct graph from tensor representations
///
/// Holds only the choice between building a directed or an undirected graph.
#[derive(Debug, Clone)]
pub struct GraphReconstructor {
    /// `true` to build directed graphs, `false` for undirected ones
    directed: bool,
}
280
281impl GraphReconstructor {
282    /// Create new reconstructor
283    pub fn new(directed: bool) -> Self {
284        Self { directed }
285    }
286
287    /// Reconstruct graph from adjacency matrix
288    #[cfg(feature = "tensor-sparse")]
289    pub fn from_adjacency<T, E>(
290        &self,
291        adjacency: &GraphAdjacencyMatrix,
292        mut node_factory: impl FnMut(usize) -> T,
293        mut edge_factory: impl FnMut(usize, usize, f64) -> E,
294    ) -> Result<Graph<T, E>, TensorError>
295    where
296        T: Clone,
297        E: Clone,
298    {
299        let mut graph = if self.directed {
300            Graph::<T, E>::directed()
301        } else {
302            Graph::<T, E>::undirected()
303        };
304
305        let n = adjacency.num_nodes;
306        let mut node_indices = Vec::with_capacity(n);
307
308        // Create nodes
309        for i in 0..n {
310            let node = node_factory(i);
311            let idx = graph.add_node(node).map_err(|e| TensorError::SliceError {
312                description: format!("Failed to add node: {:?}", e),
313            })?;
314            node_indices.push(idx);
315        }
316
317        // Create edges from CSR
318        let csr = adjacency.as_sparse_tensor();
319
320        for src in 0..n {
321            let start = csr.row_offsets()[src];
322            let end = csr.row_offsets()[src + 1];
323
324            for j in start..end {
325                let dst = csr.col_indices()[j];
326                let weight = csr.values().data()[j];
327
328                if let (Some(src_idx), Some(dst_idx)) = (
329                    node_indices.get(src).copied(),
330                    node_indices.get(dst).copied(),
331                ) {
332                    let edge_data = edge_factory(src, dst, weight);
333                    let _ = graph.add_edge(src_idx, dst_idx, edge_data);
334                }
335            }
336        }
337
338        Ok(graph)
339    }
340
341    /// Reconstruct graph from COO tensor
342    #[cfg(feature = "tensor-sparse")]
343    pub fn from_coo<T, E>(
344        &self,
345        coo: &COOTensor,
346        node_factory: impl FnMut(usize) -> T,
347        edge_factory: impl FnMut(usize, usize, f64) -> E,
348    ) -> Result<Graph<T, E>, TensorError>
349    where
350        T: Clone,
351        E: Clone,
352    {
353        // Convert COO to edge list
354        let row_indices = coo.row_indices();
355        let col_indices = coo.col_indices();
356        let edges: Vec<(usize, usize)> = row_indices
357            .iter()
358            .zip(col_indices.iter())
359            .map(|(&r, &c)| (r, c))
360            .collect();
361
362        let shape = coo.shape_array();
363        let adjacency = GraphAdjacencyMatrix::from_edge_list(&edges, shape[0], self.directed)?;
364
365        self.from_adjacency(&adjacency, node_factory, edge_factory)
366    }
367}
368
/// Extension trait for Graph to add tensor conversion methods
///
/// Implemented in this file for `Graph<T, E>`; every method delegates to a
/// `GraphFeatureExtractor` over the graph.
#[cfg(feature = "tensor-sparse")]
pub trait GraphTensorExt<T, E> {
    /// Convert graph to tensor representation
    ///
    /// Returns the node feature matrix together with the adjacency matrix.
    /// Node data must expose its features through `AsRef<[f64]>`.
    fn to_tensor_representation(&self) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
    where
        T: AsRef<[f64]> + Clone,
        E: Clone;

    /// Get adjacency matrix as sparse tensor
    fn adjacency_matrix(&self) -> Result<GraphAdjacencyMatrix, TensorError>;

    /// Extract node features as tensor
    ///
    /// `num_features` is the number of columns of the resulting
    /// `[node_count, num_features]` tensor.
    fn node_features(&self, num_features: usize) -> Result<DenseTensor, TensorError>
    where
        T: AsRef<[f64]> + Clone;

    /// Create feature extractor
    ///
    /// The extractor borrows the graph for as long as it is alive.
    fn feature_extractor(&self) -> GraphFeatureExtractor<'_, T, E>;
}
389
#[cfg(feature = "tensor-sparse")]
impl<T, E> GraphTensorExt<T, E> for Graph<T, E>
where
    T: Clone,
    E: Clone,
{
    /// Produces `(node_features, adjacency)` for this graph.
    ///
    /// The feature width is the slice length of the first node yielded by
    /// `nodes()`; a graph without nodes uses a width of 1.
    fn to_tensor_representation(&self) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
    where
        T: AsRef<[f64]> + Clone,
        E: Clone,
    {
        let feature_width = match self.nodes().next() {
            Some(node) => node.data().as_ref().len(),
            None => 1,
        };

        GraphFeatureExtractor::new(self).extract_all(feature_width)
    }

    /// Adjacency matrix of this graph, built by a temporary extractor.
    fn adjacency_matrix(&self) -> Result<GraphAdjacencyMatrix, TensorError> {
        GraphFeatureExtractor::new(self).extract_adjacency()
    }

    /// Node feature matrix with `num_features` columns, read from each node's
    /// `AsRef<[f64]>` view.
    fn node_features(&self, num_features: usize) -> Result<DenseTensor, TensorError>
    where
        T: AsRef<[f64]> + Clone,
    {
        let as_slice = |data: &T| data.as_ref();
        GraphFeatureExtractor::new(self).extract_node_features(as_slice, num_features)
    }

    /// Extractor borrowing this graph.
    fn feature_extractor(&self) -> GraphFeatureExtractor<'_, T, E> {
        GraphFeatureExtractor::new(self)
    }
}
428
/// Batch multiple graphs into a single tensor representation
///
/// Creates a batched tensor with shape [batch_size * max_nodes, num_features]
/// and a batch adjacency matrix with appropriate offsets
#[cfg(feature = "tensor-sparse")]
pub struct GraphBatch {
    /// One `(node features, adjacency)` pair per graph, in insertion order
    graphs: Vec<(DenseTensor, GraphAdjacencyMatrix)>,
}
437
#[cfg(feature = "tensor-sparse")]
impl GraphBatch {
    /// Create new batch from graphs
    ///
    /// Converts every graph with `to_tensor_representation`, keeping the
    /// input order.
    ///
    /// # Errors
    ///
    /// Returns the first conversion error encountered.
    pub fn new<T, E>(graphs: &[Graph<T, E>]) -> Result<Self, TensorError>
    where
        T: AsRef<[f64]> + Clone,
        E: Clone,
    {
        let graphs = graphs
            .iter()
            .map(|graph| graph.to_tensor_representation())
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Self { graphs })
    }

    /// Get batched feature matrix
    ///
    /// Returns a tensor of shape `[len() * max_nodes, num_features]`, where
    /// `max_nodes` and `num_features` are the maxima over the batch. Each
    /// graph occupies `max_nodes` consecutive rows. Rows narrower than
    /// `num_features` are zero-padded on the right, and graphs with fewer
    /// than `max_nodes` nodes are followed by all-zero rows.
    ///
    /// An empty batch yields a `[0, 0]` tensor.
    pub fn batch_features(&self) -> DenseTensor {
        if self.graphs.is_empty() {
            return DenseTensor::zeros(vec![0, 0]);
        }

        // Column count of one graph's feature matrix (1 if it has no second axis).
        let feature_width = |feat: &DenseTensor| feat.shape().get(1).copied().unwrap_or(1);

        let max_nodes = self
            .graphs
            .iter()
            .map(|(_, adj)| adj.num_nodes)
            .max()
            .unwrap_or(0);

        let num_features = self
            .graphs
            .iter()
            .map(|(feat, _)| feature_width(feat))
            .max()
            .unwrap_or(1);

        let mut all_features = Vec::with_capacity(self.graphs.len() * max_nodes * num_features);
        for (features, _) in &self.graphs {
            let width = feature_width(features);
            let mut rows_written = 0usize;

            // Copy row by row so narrower feature matrices stay column-aligned.
            // `width == 0` means there is no data to copy (and `chunks_exact`
            // rejects a zero chunk size).
            if width > 0 {
                for row in features.data().chunks_exact(width).take(max_nodes) {
                    all_features.extend_from_slice(row);
                    all_features.resize(all_features.len() + (num_features - width), 0.0);
                    rows_written += 1;
                }
            }

            // Zero rows up to `max_nodes`, so every graph contributes exactly
            // `max_nodes * num_features` values.
            let missing = (max_nodes - rows_written) * num_features;
            all_features.resize(all_features.len() + missing, 0.0);
        }

        DenseTensor::from_vec(
            all_features,
            vec![self.graphs.len() * max_nodes, num_features],
        )
    }

    /// Get batched adjacency matrix (block diagonal)
    ///
    /// Node indices of graph `k` are shifted by the total node count of the
    /// graphs before it, so the blocks do not overlap. The matrix has as many
    /// rows as the batch has nodes in total (no padding), and takes its
    /// directedness flag from the first graph.
    pub fn batch_adjacency(&self) -> GraphAdjacencyMatrix {
        let total_nodes: usize = self.graphs.iter().map(|(_, adj)| adj.num_nodes).sum();
        let total_edges: usize = self.graphs.iter().map(|(_, adj)| adj.num_edges).sum();

        // Collect all edges with offsets
        let mut all_edges = Vec::with_capacity(total_edges);
        let mut offset = 0;

        for (_, adjacency) in &self.graphs {
            let csr = adjacency.as_sparse_tensor();
            for src in 0..adjacency.num_nodes {
                let start = csr.row_offsets()[src];
                let end = csr.row_offsets()[src + 1];
                for j in start..end {
                    let dst = csr.col_indices()[j];
                    all_edges.push((src + offset, dst + offset));
                }
            }
            offset += adjacency.num_nodes;
        }

        // An empty batch falls through with no edges, zero nodes and an
        // undirected flag.
        let is_directed = self
            .graphs
            .first()
            .map(|(_, adj)| adj.is_directed)
            .unwrap_or(false);

        GraphAdjacencyMatrix::from_edge_list(&all_edges, total_nodes, is_directed)
            .expect("from_edge_list does not return Err")
    }

    /// Get number of graphs in batch
    pub fn len(&self) -> usize {
        self.graphs.len()
    }

    /// Check if batch is empty
    pub fn is_empty(&self) -> bool {
        self.graphs.is_empty()
    }

    /// Get individual graph by index
    ///
    /// Returns `None` when `idx` is out of range.
    pub fn get(&self, idx: usize) -> Option<&(DenseTensor, GraphAdjacencyMatrix)> {
        self.graphs.get(idx)
    }
}
555
#[cfg(all(test, feature = "tensor-sparse"))]
mod tests {
    use super::*;
    use crate::graph::Graph;

    /// Directed 3-cycle 0 -> 1 -> 2 -> 0 shared by several tests.
    fn triangle() -> GraphAdjacencyMatrix {
        GraphAdjacencyMatrix::from_edge_list(&[(0, 1), (1, 2), (2, 0)], 3, true).unwrap()
    }

    #[test]
    fn test_adjacency_matrix_creation() {
        let matrix = triangle();

        assert_eq!(matrix.num_nodes, 3);
        assert_eq!(matrix.num_edges, 3);
        assert!(matrix.is_directed);
    }

    #[test]
    fn test_graph_to_tensor_conversion() {
        let mut graph = Graph::<Vec<f64>, f64>::directed();

        let a = graph.add_node(vec![1.0, 0.0]).unwrap();
        let b = graph.add_node(vec![0.0, 1.0]).unwrap();
        let c = graph.add_node(vec![1.0, 1.0]).unwrap();

        // Close the cycle a -> b -> c -> a.
        for (from, to) in [(a, b), (b, c), (c, a)] {
            let _ = graph.add_edge(from, to, 1.0);
        }

        let (features, adjacency) = graph.to_tensor_representation().unwrap();

        assert_eq!(features.shape(), &[3, 2]);
        assert_eq!(adjacency.num_nodes, 3);
        assert_eq!(adjacency.num_edges, 3);
    }

    #[test]
    fn test_feature_extractor() {
        let mut graph = Graph::<String, f64>::directed();

        let head = graph.add_node("node0".to_string()).unwrap();
        let tail = graph.add_node("node1".to_string()).unwrap();
        let _ = graph.add_edge(head, tail, 1.0);

        // Extract scalar features (string length)
        let features = graph
            .feature_extractor()
            .extract_node_features_scalar(|label| label.len() as f64)
            .unwrap();

        assert_eq!(features.shape(), &[2, 1]);
    }

    #[test]
    fn test_graph_reconstruction() {
        let matrix = triangle();

        let rebuilt: Graph<usize, f64> = GraphReconstructor::new(true)
            .from_adjacency(&matrix, |i| i, |_src, _dst, w| w)
            .unwrap();

        assert_eq!(rebuilt.node_count(), 3);
        assert_eq!(rebuilt.edge_count(), 3);
    }

    #[test]
    fn test_normalized_adjacency() {
        let path = [(0, 1), (1, 0), (1, 2), (2, 1)];
        let matrix = GraphAdjacencyMatrix::from_edge_list(&path, 3, true).unwrap();

        let normalized = matrix.normalized_with_self_loops().unwrap();

        // Should have self-loops added
        assert!(normalized.num_edges > matrix.num_edges);
    }

    #[test]
    fn test_batch_creation() {
        // Two-node directed graph with a single edge first -> second.
        let two_node_graph = |first: Vec<f64>, second: Vec<f64>| {
            let mut graph = Graph::<Vec<f64>, f64>::directed();
            let head = graph.add_node(first).unwrap();
            let tail = graph.add_node(second).unwrap();
            let _ = graph.add_edge(head, tail, 1.0);
            graph
        };

        let batch = GraphBatch::new(&[
            two_node_graph(vec![1.0, 0.0], vec![0.0, 1.0]),
            two_node_graph(vec![1.0, 1.0], vec![0.0, 0.0]),
        ])
        .unwrap();

        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }
}