Skip to main content

god_gragh/tensor/
graph_tensor.rs

1//! Graph-Tensor Integration: Specialized implementations for seamless graph-tensor conversion
2//!
//! This module provides:
//! - `GraphAdjacencyMatrix`: sparse (CSR) adjacency matrix with degree helpers
//! - `GraphFeatureExtractor`: extraction of node/edge feature tensors and adjacency from a `Graph`
//! - `GraphReconstructor`: reconstruction of a `Graph` from an adjacency matrix or COO tensor
//! - `GraphTensorExt`: tensor-conversion convenience methods on `Graph`
//! - `GraphBatch`: batching of several graphs into padded feature and combined adjacency tensors
9
10use crate::graph::traits::{GraphBase, GraphQuery, GraphOps};
11use crate::graph::Graph;
12use crate::tensor::DenseTensor;
13use crate::tensor::error::TensorError;
14use crate::tensor::traits::TensorBase;
15
16#[cfg(feature = "tensor-sparse")]
17use crate::tensor::{COOTensor, CSRTensor};
18
19/// Adjacency matrix representation for graph neural networks
20#[derive(Debug, Clone)]
21pub struct GraphAdjacencyMatrix {
22    /// Sparse adjacency matrix in CSR format
23    csr: CSRTensor,
24    /// Number of nodes
25    pub num_nodes: usize,
26    /// Number of edges
27    pub num_edges: usize,
28    /// Whether the graph is directed
29    pub is_directed: bool,
30}
31
32impl GraphAdjacencyMatrix {
33    /// Create adjacency matrix from edge list
34    pub fn from_edge_list(
35        edges: &[(usize, usize)],
36        num_nodes: usize,
37        is_directed: bool,
38    ) -> Result<Self, TensorError> {
39        if edges.is_empty() {
40            return Ok(Self {
41                csr: CSRTensor::new(
42                    vec![0; num_nodes + 1],
43                    vec![],
44                    DenseTensor::zeros(vec![0]),
45                    [num_nodes, num_nodes],
46                ),
47                num_nodes,
48                num_edges: 0,
49                is_directed,
50            });
51        }
52
53        // Build CSR format
54        let mut row_offsets = vec![0usize; num_nodes + 1];
55        let mut col_indices = Vec::with_capacity(edges.len());
56        let mut values_data = Vec::with_capacity(edges.len());
57
58        // Count edges per row
59        for &(src, _) in edges {
60            if src < num_nodes {
61                row_offsets[src + 1] += 1;
62            }
63        }
64
65        // Cumulative sum
66        for i in 1..=num_nodes {
67            row_offsets[i] += row_offsets[i - 1];
68        }
69
70        // Fill column indices and values
71        let mut row_pos = row_offsets[..num_nodes].to_vec();
72        for &(src, dst) in edges {
73            if src < num_nodes && dst < num_nodes {
74                let _pos = row_pos[src];
75                col_indices.push(dst);
76                values_data.push(1.0);
77                row_pos[src] += 1;
78            }
79        }
80
81        let values = DenseTensor::new(values_data, vec![col_indices.len()]);
82        let csr = CSRTensor::new(row_offsets, col_indices, values, [num_nodes, num_nodes]);
83
84        Ok(Self {
85            csr,
86            num_nodes,
87            num_edges: edges.len(),
88            is_directed,
89        })
90    }
91
92    /// Convert to COO format
93    #[cfg(feature = "tensor-sparse")]
94    pub fn to_coo(&self) -> COOTensor {
95        use crate::tensor::SparseTensor;
96        let sparse = SparseTensor::CSR(self.csr.clone());
97        sparse.to_coo()
98    }
99
100    /// Get sparse tensor representation
101    pub fn as_sparse_tensor(&self) -> &CSRTensor {
102        &self.csr
103    }
104
105    /// Compute normalized adjacency matrix (for GCN)
106    /// 
107    /// Returns: D^(-1/2) * (A + I) * D^(-1/2)
108    /// where D is the degree matrix and I is the identity matrix
109    pub fn normalized_with_self_loops(&self) -> Result<Self, TensorError> {
110        let n = self.num_nodes;
111        
112        // Add self-loops
113        let mut edges = Vec::new();
114        
115        // Extract existing edges
116        for i in 0..n {
117            let start = self.csr.row_offsets()[i];
118            let end = self.csr.row_offsets()[i + 1];
119            for j in start..end {
120                let col = self.csr.col_indices()[j];
121                edges.push((i, col));
122            }
123            // Add self-loop
124            edges.push((i, i));
125        }
126
127        Self::from_edge_list(&edges, n, self.is_directed)
128    }
129
130    /// Compute degree matrix
131    pub fn degree_matrix(&self) -> DenseTensor {
132        let n = self.num_nodes;
133        let mut degrees = vec![0.0; n];
134
135        for (i, degree) in degrees.iter_mut().enumerate() {
136            let start = self.csr.row_offsets()[i];
137            let end = self.csr.row_offsets()[i + 1];
138            *degree = (end - start) as f64;
139        }
140
141        DenseTensor::from_vec(degrees, vec![n])
142    }
143
144    /// Compute inverse degree matrix (for normalization)
145    pub fn inverse_degree_matrix(&self) -> DenseTensor {
146        let n = self.num_nodes;
147        let mut inv_degrees = vec![0.0; n];
148
149        for (i, inv_degree) in inv_degrees.iter_mut().enumerate() {
150            let start = self.csr.row_offsets()[i];
151            let end = self.csr.row_offsets()[i + 1];
152            let degree = (end - start) as f64;
153            *inv_degree = if degree > 0.0 { 1.0 / degree } else { 0.0 };
154        }
155
156        DenseTensor::from_vec(inv_degrees, vec![n])
157    }
158}
159
160/// Feature extractor for converting graphs to tensor representations
161pub struct GraphFeatureExtractor<'a, T, E> {
162    graph: &'a Graph<T, E>,
163}
164
165impl<'a, T, E> GraphFeatureExtractor<'a, T, E>
166where
167    T: Clone,
168    E: Clone,
169{
170    /// Create new extractor from graph
171    pub fn new(graph: &'a Graph<T, E>) -> Self {
172        Self { graph }
173    }
174
175    /// Extract node features as dense tensor
176    /// 
177    /// Each node's data is treated as a scalar feature
178    pub fn extract_node_features_scalar<F>(&self, map_fn: F) -> Result<DenseTensor, TensorError>
179    where
180        F: Fn(&T) -> f64,
181    {
182        let n = self.graph.node_count();
183        let mut features = Vec::with_capacity(n);
184
185        for node_idx in self.graph.nodes() {
186            let data = node_idx.data();
187            features.push(map_fn(data));
188        }
189
190        Ok(DenseTensor::from_vec(features, vec![n, 1]))
191    }
192
193    /// Extract node features as 2D tensor (nodes x features)
194    /// 
195    /// Requires node data to be convertible to feature vectors
196    pub fn extract_node_features<F>(&self, map_fn: F, num_features: usize) -> Result<DenseTensor, TensorError>
197    where
198        F: for<'b> Fn(&'b T) -> &'b [f64],
199    {
200        let n = self.graph.node_count();
201        let mut features = Vec::with_capacity(n * num_features);
202
203        for node_idx in self.graph.nodes() {
204            let data = node_idx.data();
205            let feat = map_fn(data);
206            features.extend_from_slice(feat);
207        }
208
209        Ok(DenseTensor::from_vec(features, vec![n, num_features]))
210    }
211
212    /// Extract edge features as tensor
213    pub fn extract_edge_features<F>(&self, map_fn: F) -> Result<DenseTensor, TensorError>
214    where
215        F: Fn(&E) -> f64,
216    {
217        let m = self.graph.edge_count();
218        let mut features = Vec::with_capacity(m);
219
220        for edge_idx in self.graph.edges() {
221            let data = edge_idx.data();
222            features.push(map_fn(data));
223        }
224
225        Ok(DenseTensor::from_vec(features, vec![m, 1]))
226    }
227
228    /// Extract adjacency matrix as sparse tensor
229    pub fn extract_adjacency(&self) -> Result<GraphAdjacencyMatrix, TensorError> {
230        let mut edges: Vec<(usize, usize)> = Vec::new();
231        
232        for node_idx in self.graph.nodes() {
233            let src = node_idx.index().index();
234            for neighbor in self.graph.neighbors(node_idx.index()) {
235                let dst = neighbor.index();
236                edges.push((src, dst));
237            }
238        }
239
240        GraphAdjacencyMatrix::from_edge_list(
241            &edges,
242            self.graph.node_count(),
243            true, // Assume directed for adjacency extraction
244        )
245    }
246
247    /// Extract complete graph as tensor representation
248    pub fn extract_all(&self, num_node_features: usize) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
249    where
250        T: AsRef<[f64]> + Clone,
251        E: Clone,
252    {
253        let node_features = self.extract_node_features(|data: &T| data.as_ref(), num_node_features)?;
254        let adjacency = self.extract_adjacency()?;
255
256        Ok((node_features, adjacency))
257    }
258}
259
/// Rebuilds a [`Graph`] from tensor representations (CSR adjacency or COO tensors).
pub struct GraphReconstructor {
    /// Build directed graphs when `true`, undirected graphs when `false`.
    directed: bool,
}
264
265impl GraphReconstructor {
266    /// Create new reconstructor
267    pub fn new(directed: bool) -> Self {
268        Self { directed }
269    }
270
271    /// Reconstruct graph from adjacency matrix
272    pub fn from_adjacency<T, E>(
273        &self,
274        adjacency: &GraphAdjacencyMatrix,
275        mut node_factory: impl FnMut(usize) -> T,
276        mut edge_factory: impl FnMut(usize, usize, f64) -> E,
277    ) -> Result<Graph<T, E>, TensorError>
278    where
279        T: Clone,
280        E: Clone,
281    {
282        let mut graph = if self.directed {
283            Graph::<T, E>::directed()
284        } else {
285            Graph::<T, E>::undirected()
286        };
287
288        let n = adjacency.num_nodes;
289        let mut node_indices = Vec::with_capacity(n);
290
291        // Create nodes
292        for i in 0..n {
293            let node = node_factory(i);
294            let idx = graph.add_node(node)
295                .map_err(|e| TensorError::SliceError { 
296                    description: format!("Failed to add node: {:?}", e)
297                })?;
298            node_indices.push(idx);
299        }
300
301        // Create edges from CSR
302        let csr = adjacency.as_sparse_tensor();
303        
304        for src in 0..n {
305            let start = csr.row_offsets()[src];
306            let end = csr.row_offsets()[src + 1];
307            
308            for j in start..end {
309                let dst = csr.col_indices()[j];
310                let weight = csr.values().data()[j];
311                
312                if let (Some(src_idx), Some(dst_idx)) = (
313                    node_indices.get(src).copied(),
314                    node_indices.get(dst).copied(),
315                ) {
316                    let edge_data = edge_factory(src, dst, weight);
317                    let _ = graph.add_edge(src_idx, dst_idx, edge_data);
318                }
319            }
320        }
321
322        Ok(graph)
323    }
324
325    /// Reconstruct graph from COO tensor
326    pub fn from_coo<T, E>(
327        &self,
328        coo: &COOTensor,
329        node_factory: impl FnMut(usize) -> T,
330        edge_factory: impl FnMut(usize, usize, f64) -> E,
331    ) -> Result<Graph<T, E>, TensorError>
332    where
333        T: Clone,
334        E: Clone,
335    {
336        // Convert COO to edge list
337        let row_indices = coo.row_indices();
338        let col_indices = coo.col_indices();
339        let edges: Vec<(usize, usize)> = row_indices.iter()
340            .zip(col_indices.iter())
341            .map(|(&r, &c)| (r, c))
342            .collect();
343
344        let shape = coo.shape_array();
345        let adjacency = GraphAdjacencyMatrix::from_edge_list(
346            &edges,
347            shape[0],
348            self.directed,
349        )?;
350
351        self.from_adjacency(&adjacency, node_factory, edge_factory)
352    }
353}
354
355/// Extension trait for Graph to add tensor conversion methods
356pub trait GraphTensorExt<T, E> {
357    /// Convert graph to tensor representation
358    fn to_tensor_representation(&self) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
359    where
360        T: AsRef<[f64]> + Clone,
361        E: Clone;
362
363    /// Get adjacency matrix as sparse tensor
364    fn adjacency_matrix(&self) -> Result<GraphAdjacencyMatrix, TensorError>;
365
366    /// Extract node features as tensor
367    fn node_features(&self, num_features: usize) -> Result<DenseTensor, TensorError>
368    where
369        T: AsRef<[f64]> + Clone;
370
371    /// Create feature extractor
372    fn feature_extractor(&self) -> GraphFeatureExtractor<'_, T, E>;
373}
374
375impl<T, E> GraphTensorExt<T, E> for Graph<T, E>
376where
377    T: Clone,
378    E: Clone,
379{
380    fn to_tensor_representation(&self) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
381    where
382        T: AsRef<[f64]> + Clone,
383        E: Clone,
384    {
385        let extractor = GraphFeatureExtractor::new(self);
386        let num_features = if let Some(first_node) = self.nodes().next() {
387            first_node.data().as_ref().len()
388        } else {
389            1
390        };
391        
392        extractor.extract_all(num_features)
393    }
394
395    fn adjacency_matrix(&self) -> Result<GraphAdjacencyMatrix, TensorError> {
396        let extractor = GraphFeatureExtractor::new(self);
397        extractor.extract_adjacency()
398    }
399
400    fn node_features(&self, num_features: usize) -> Result<DenseTensor, TensorError>
401    where
402        T: AsRef<[f64]> + Clone,
403    {
404        let extractor = GraphFeatureExtractor::new(self);
405        extractor.extract_node_features(|data: &T| data.as_ref(), num_features)
406    }
407
408    fn feature_extractor(&self) -> GraphFeatureExtractor<'_, T, E> {
409        GraphFeatureExtractor::new(self)
410    }
411}
412
413/// Batch multiple graphs into a single tensor representation
414/// 
415/// Creates a batched tensor with shape [batch_size * max_nodes, num_features]
416/// and a batch adjacency matrix with appropriate offsets
417pub struct GraphBatch {
418    graphs: Vec<(DenseTensor, GraphAdjacencyMatrix)>,
419}
420
421impl GraphBatch {
422    /// Create new batch from graphs
423    pub fn new<T, E>(graphs: &[Graph<T, E>]) -> Result<Self, TensorError>
424    where
425        T: AsRef<[f64]> + Clone,
426        E: Clone,
427    {
428        let mut batch = Self { graphs: Vec::with_capacity(graphs.len()) };
429        
430        for graph in graphs {
431            let (features, adjacency) = graph.to_tensor_representation()?;
432            batch.graphs.push((features, adjacency));
433        }
434        
435        Ok(batch)
436    }
437
438    /// Get batched feature matrix
439    pub fn batch_features(&self) -> DenseTensor {
440        if self.graphs.is_empty() {
441            return DenseTensor::zeros(vec![0, 0]);
442        }
443
444        // Find max nodes and features
445        let max_nodes = self.graphs.iter()
446            .map(|(_, adj)| adj.num_nodes)
447            .max()
448            .unwrap_or(0);
449        
450        let num_features = self.graphs.iter()
451            .map(|(feat, _)| feat.shape().get(1).copied().unwrap_or(1))
452            .max()
453            .unwrap_or(1);
454
455        // Concatenate with padding
456        let mut all_features = Vec::new();
457        for (features, adjacency) in &self.graphs {
458            let feat_data = features.data();
459            all_features.extend_from_slice(feat_data);
460
461            // Pad to max_nodes if needed
462            let current_nodes = adjacency.num_nodes;
463            if current_nodes < max_nodes {
464                let padding_size = (max_nodes - current_nodes) * num_features;
465                all_features.extend(std::iter::repeat_n(0.0, padding_size));
466            }
467        }
468
469        DenseTensor::from_vec(
470            all_features,
471            vec![self.graphs.len() * max_nodes, num_features],
472        )
473    }
474
475    /// Get batched adjacency matrix (block diagonal)
476    pub fn batch_adjacency(&self) -> GraphAdjacencyMatrix {
477        if self.graphs.is_empty() {
478            return GraphAdjacencyMatrix::from_edge_list(&[], 0, false).unwrap();
479        }
480
481        // For batch processing, we keep graphs separate and use offset indexing
482        // This is a simplified implementation - full block diagonal would be more complex
483        let total_nodes: usize = self.graphs.iter().map(|(_, adj)| adj.num_nodes).sum();
484        let total_edges: usize = self.graphs.iter().map(|(_, adj)| adj.num_edges).sum();
485
486        // Collect all edges with offsets
487        let mut all_edges = Vec::with_capacity(total_edges);
488        let mut offset = 0;
489
490        for (_, adjacency) in &self.graphs {
491            let csr = adjacency.as_sparse_tensor();
492            for src in 0..adjacency.num_nodes {
493                let start = csr.row_offsets()[src];
494                let end = csr.row_offsets()[src + 1];
495                for j in start..end {
496                    let dst = csr.col_indices()[j];
497                    all_edges.push((src + offset, dst + offset));
498                }
499            }
500            offset += adjacency.num_nodes;
501        }
502
503        GraphAdjacencyMatrix::from_edge_list(
504            &all_edges,
505            total_nodes,
506            self.graphs.first().map(|(_, adj)| adj.is_directed).unwrap_or(false),
507        ).unwrap()
508    }
509
510    /// Get number of graphs in batch
511    pub fn len(&self) -> usize {
512        self.graphs.len()
513    }
514
515    /// Check if batch is empty
516    pub fn is_empty(&self) -> bool {
517        self.graphs.is_empty()
518    }
519
520    /// Get individual graph by index
521    pub fn get(&self, idx: usize) -> Option<&(DenseTensor, GraphAdjacencyMatrix)> {
522        self.graphs.get(idx)
523    }
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::Graph;

    /// Build a directed graph whose nodes carry `features` (in order) and whose edges,
    /// given as pairs of node positions, all have weight `1.0`.
    fn feature_graph(features: &[Vec<f64>], edges: &[(usize, usize)]) -> Graph<Vec<f64>, f64> {
        let mut graph = Graph::<Vec<f64>, f64>::directed();
        let ids: Vec<_> = features
            .iter()
            .map(|f| graph.add_node(f.clone()).unwrap())
            .collect();
        for &(a, b) in edges {
            let _ = graph.add_edge(ids[a], ids[b], 1.0);
        }
        graph
    }

    #[test]
    fn test_adjacency_matrix_creation() {
        let adj = GraphAdjacencyMatrix::from_edge_list(&[(0, 1), (1, 2), (2, 0)], 3, true)
            .expect("a valid edge list");

        assert_eq!((adj.num_nodes, adj.num_edges), (3, 3));
        assert!(adj.is_directed);
    }

    #[test]
    fn test_graph_to_tensor_conversion() {
        let graph = feature_graph(
            &[vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]],
            &[(0, 1), (1, 2), (2, 0)],
        );

        let (features, adjacency) = graph.to_tensor_representation().unwrap();

        assert_eq!(features.shape(), &[3, 2]);
        assert_eq!(adjacency.num_nodes, 3);
        assert_eq!(adjacency.num_edges, 3);
    }

    #[test]
    fn test_feature_extractor() {
        let mut graph = Graph::<String, f64>::directed();
        let a = graph.add_node(String::from("node0")).unwrap();
        let b = graph.add_node(String::from("node1")).unwrap();
        let _ = graph.add_edge(a, b, 1.0);

        // One scalar per node: the length of its label.
        let features = graph
            .feature_extractor()
            .extract_node_features_scalar(|label| label.len() as f64)
            .unwrap();

        assert_eq!(features.shape(), &[2, 1]);
    }

    #[test]
    fn test_graph_reconstruction() {
        let adj = GraphAdjacencyMatrix::from_edge_list(&[(0, 1), (1, 2), (2, 0)], 3, true).unwrap();

        let graph: Graph<usize, f64> = GraphReconstructor::new(true)
            .from_adjacency(&adj, |i| i, |_, _, weight| weight)
            .unwrap();

        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.edge_count(), 3);
    }

    #[test]
    fn test_normalized_adjacency() {
        let adj = GraphAdjacencyMatrix::from_edge_list(&[(0, 1), (1, 0), (1, 2), (2, 1)], 3, true)
            .unwrap();

        let normalized = adj.normalized_with_self_loops().unwrap();

        // The added self-loops must show up as extra entries.
        assert!(normalized.num_edges > adj.num_edges);
    }

    #[test]
    fn test_batch_creation() {
        let first = feature_graph(&[vec![1.0, 0.0], vec![0.0, 1.0]], &[(0, 1)]);
        let second = feature_graph(&[vec![1.0, 1.0], vec![0.0, 0.0]], &[(0, 1)]);

        let batch = GraphBatch::new(&[first, second]).unwrap();

        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }
}