1use crate::graph::traits::{GraphBase, GraphOps, GraphQuery};
11use crate::graph::Graph;
12use crate::tensor::error::TensorError;
13use crate::tensor::traits::TensorBase;
14use crate::tensor::DenseTensor;
15
16#[cfg(feature = "tensor-sparse")]
17use crate::tensor::{COOTensor, CSRTensor};
18
/// Adjacency matrix of a graph stored in compressed sparse row (CSR) form,
/// together with a few pieces of metadata about the graph it came from.
#[cfg(feature = "tensor-sparse")]
#[derive(Debug, Clone)]
pub struct GraphAdjacencyMatrix {
    /// Row-compressed storage of the adjacency entries (`num_nodes x num_nodes`).
    csr: CSRTensor,
    /// Number of nodes; both dimensions of the square matrix.
    pub num_nodes: usize,
    /// Edge count reported by the constructor that built this matrix.
    pub num_edges: usize,
    /// Directedness flag recorded at construction; the CSR contents built in
    /// this module do not depend on it.
    pub is_directed: bool,
}
32
#[cfg(feature = "tensor-sparse")]
impl GraphAdjacencyMatrix {
    /// Builds an adjacency matrix in CSR form from a list of `(src, dst)` edges.
    ///
    /// Every stored entry gets the value `1.0`. Edges with either endpoint
    /// `>= num_nodes` are skipped. Duplicate edges are kept as separate
    /// entries. Within a row, entries keep the order in which they appear in
    /// `edges`, so the input does not need to be sorted by source.
    ///
    /// The returned `num_edges` is the number of entries actually stored.
    /// `is_directed` is only recorded: no reverse edges are added for
    /// undirected graphs.
    ///
    /// # Errors
    ///
    /// This implementation never returns `Err`; the `Result` is part of the
    /// public signature.
    pub fn from_edge_list(
        edges: &[(usize, usize)],
        num_nodes: usize,
        is_directed: bool,
    ) -> Result<Self, TensorError> {
        if edges.is_empty() {
            return Ok(Self {
                csr: CSRTensor::new(
                    vec![0; num_nodes + 1],
                    vec![],
                    DenseTensor::zeros(vec![0]),
                    [num_nodes, num_nodes],
                ),
                num_nodes,
                num_edges: 0,
                is_directed,
            });
        }

        let in_bounds = |src: usize, dst: usize| src < num_nodes && dst < num_nodes;

        // Pass 1: count the entries of each row (shifted by one). Only edges
        // that pass 2 will actually store are counted, so the final offset
        // always equals the number of column indices written.
        let mut row_offsets = vec![0usize; num_nodes + 1];
        for &(src, dst) in edges {
            if in_bounds(src, dst) {
                row_offsets[src + 1] += 1;
            }
        }

        // A running sum turns the shifted counts into row start offsets:
        // row `i` occupies `row_offsets[i]..row_offsets[i + 1]`.
        let mut running = 0usize;
        for offset in row_offsets.iter_mut() {
            running += *offset;
            *offset = running;
        }
        let nnz = row_offsets[num_nodes];

        // Pass 2: scatter each column index into the next free slot of its row.
        // This groups entries by row, as CSR requires, whatever the input order.
        let mut col_indices = vec![0usize; nnz];
        let mut next_slot = row_offsets[..num_nodes].to_vec();
        for &(src, dst) in edges {
            if in_bounds(src, dst) {
                col_indices[next_slot[src]] = dst;
                next_slot[src] += 1;
            }
        }

        let values = DenseTensor::new(vec![1.0; nnz], vec![nnz]);
        let csr = CSRTensor::new(row_offsets, col_indices, values, [num_nodes, num_nodes]);

        Ok(Self {
            csr,
            num_nodes,
            num_edges: nnz,
            is_directed,
        })
    }

    /// Converts the adjacency to COO form.
    ///
    /// The CSR storage is cloned and converted through `SparseTensor::to_coo`.
    pub fn to_coo(&self) -> COOTensor {
        use crate::tensor::SparseTensor;
        let sparse = SparseTensor::CSR(self.csr.clone());
        sparse.to_coo()
    }

    /// Borrows the underlying CSR tensor.
    pub fn as_sparse_tensor(&self) -> &CSRTensor {
        &self.csr
    }

    /// Returns a new matrix equal to `A + I`, where every entry has value `1.0`.
    ///
    /// Despite the name, no degree normalization is applied. A self-loop is
    /// appended to every row even if one already exists, so existing
    /// self-loops end up stored twice.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::from_edge_list`].
    pub fn normalized_with_self_loops(&self) -> Result<Self, TensorError> {
        let n = self.num_nodes;
        let offsets = self.csr.row_offsets();
        let cols = self.csr.col_indices();

        let mut edges = Vec::with_capacity(cols.len() + n);
        for row in 0..n {
            edges.extend(cols[offsets[row]..offsets[row + 1]].iter().map(|&col| (row, col)));
            edges.push((row, row));
        }

        Self::from_edge_list(&edges, n, self.is_directed)
    }

    /// Returns the number of stored entries in each row (out-degree), in row order.
    fn row_degrees(&self) -> Vec<f64> {
        let offsets = self.csr.row_offsets();
        (0..self.num_nodes)
            .map(|row| (offsets[row + 1] - offsets[row]) as f64)
            .collect()
    }

    /// Returns the out-degree of every node as a 1-D tensor of length `num_nodes`.
    ///
    /// Only the diagonal is returned, not a full matrix. Degrees count stored
    /// entries, so duplicate edges count more than once.
    pub fn degree_matrix(&self) -> DenseTensor {
        DenseTensor::from_vec(self.row_degrees(), vec![self.num_nodes])
    }

    /// Returns `1 / out_degree` for every node as a 1-D tensor of length
    /// `num_nodes`. Nodes with no outgoing entries get `0.0` instead of infinity.
    pub fn inverse_degree_matrix(&self) -> DenseTensor {
        let inv_degrees: Vec<f64> = self
            .row_degrees()
            .into_iter()
            .map(|degree| if degree > 0.0 { 1.0 / degree } else { 0.0 })
            .collect();
        DenseTensor::from_vec(inv_degrees, vec![self.num_nodes])
    }
}
165
166pub struct GraphFeatureExtractor<'a, T, E> {
168 graph: &'a Graph<T, E>,
169}
170
171impl<'a, T, E> GraphFeatureExtractor<'a, T, E>
172where
173 T: Clone,
174 E: Clone,
175{
176 pub fn new(graph: &'a Graph<T, E>) -> Self {
178 Self { graph }
179 }
180
181 pub fn extract_node_features_scalar<F>(&self, map_fn: F) -> Result<DenseTensor, TensorError>
185 where
186 F: Fn(&T) -> f64,
187 {
188 let n = self.graph.node_count();
189 let mut features = Vec::with_capacity(n);
190
191 for node_idx in self.graph.nodes() {
192 let data = node_idx.data();
193 features.push(map_fn(data));
194 }
195
196 Ok(DenseTensor::from_vec(features, vec![n, 1]))
197 }
198
199 pub fn extract_node_features<F>(
203 &self,
204 map_fn: F,
205 num_features: usize,
206 ) -> Result<DenseTensor, TensorError>
207 where
208 F: for<'b> Fn(&'b T) -> &'b [f64],
209 {
210 let n = self.graph.node_count();
211 let mut features = Vec::with_capacity(n * num_features);
212
213 for node_idx in self.graph.nodes() {
214 let data = node_idx.data();
215 let feat = map_fn(data);
216 features.extend_from_slice(feat);
217 }
218
219 Ok(DenseTensor::from_vec(features, vec![n, num_features]))
220 }
221
222 pub fn extract_edge_features<F>(&self, map_fn: F) -> Result<DenseTensor, TensorError>
224 where
225 F: Fn(&E) -> f64,
226 {
227 let m = self.graph.edge_count();
228 let mut features = Vec::with_capacity(m);
229
230 for edge_idx in self.graph.edges() {
231 let data = edge_idx.data();
232 features.push(map_fn(data));
233 }
234
235 Ok(DenseTensor::from_vec(features, vec![m, 1]))
236 }
237
238 #[cfg(feature = "tensor-sparse")]
240 pub fn extract_adjacency(&self) -> Result<GraphAdjacencyMatrix, TensorError> {
241 let mut edges: Vec<(usize, usize)> = Vec::new();
242
243 for node_idx in self.graph.nodes() {
244 let src = node_idx.index().index();
245 for neighbor in self.graph.neighbors(node_idx.index()) {
246 let dst = neighbor.index();
247 edges.push((src, dst));
248 }
249 }
250
251 GraphAdjacencyMatrix::from_edge_list(
252 &edges,
253 self.graph.node_count(),
254 true, )
256 }
257
258 #[cfg(feature = "tensor-sparse")]
260 pub fn extract_all(
261 &self,
262 num_node_features: usize,
263 ) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
264 where
265 T: AsRef<[f64]> + Clone,
266 E: Clone,
267 {
268 let node_features =
269 self.extract_node_features(|data: &T| data.as_ref(), num_node_features)?;
270 let adjacency = self.extract_adjacency()?;
271
272 Ok((node_features, adjacency))
273 }
274}
275
/// Rebuilds a [`Graph`] from tensor representations of its structure.
pub struct GraphReconstructor {
    /// `true` to produce `Graph::directed()`, `false` for `Graph::undirected()`.
    directed: bool,
}
280
281impl GraphReconstructor {
282 pub fn new(directed: bool) -> Self {
284 Self { directed }
285 }
286
287 #[cfg(feature = "tensor-sparse")]
289 pub fn from_adjacency<T, E>(
290 &self,
291 adjacency: &GraphAdjacencyMatrix,
292 mut node_factory: impl FnMut(usize) -> T,
293 mut edge_factory: impl FnMut(usize, usize, f64) -> E,
294 ) -> Result<Graph<T, E>, TensorError>
295 where
296 T: Clone,
297 E: Clone,
298 {
299 let mut graph = if self.directed {
300 Graph::<T, E>::directed()
301 } else {
302 Graph::<T, E>::undirected()
303 };
304
305 let n = adjacency.num_nodes;
306 let mut node_indices = Vec::with_capacity(n);
307
308 for i in 0..n {
310 let node = node_factory(i);
311 let idx = graph.add_node(node).map_err(|e| TensorError::SliceError {
312 description: format!("Failed to add node: {:?}", e),
313 })?;
314 node_indices.push(idx);
315 }
316
317 let csr = adjacency.as_sparse_tensor();
319
320 for src in 0..n {
321 let start = csr.row_offsets()[src];
322 let end = csr.row_offsets()[src + 1];
323
324 for j in start..end {
325 let dst = csr.col_indices()[j];
326 let weight = csr.values().data()[j];
327
328 if let (Some(src_idx), Some(dst_idx)) = (
329 node_indices.get(src).copied(),
330 node_indices.get(dst).copied(),
331 ) {
332 let edge_data = edge_factory(src, dst, weight);
333 let _ = graph.add_edge(src_idx, dst_idx, edge_data);
334 }
335 }
336 }
337
338 Ok(graph)
339 }
340
341 #[cfg(feature = "tensor-sparse")]
343 pub fn from_coo<T, E>(
344 &self,
345 coo: &COOTensor,
346 node_factory: impl FnMut(usize) -> T,
347 edge_factory: impl FnMut(usize, usize, f64) -> E,
348 ) -> Result<Graph<T, E>, TensorError>
349 where
350 T: Clone,
351 E: Clone,
352 {
353 let row_indices = coo.row_indices();
355 let col_indices = coo.col_indices();
356 let edges: Vec<(usize, usize)> = row_indices
357 .iter()
358 .zip(col_indices.iter())
359 .map(|(&r, &c)| (r, c))
360 .collect();
361
362 let shape = coo.shape_array();
363 let adjacency = GraphAdjacencyMatrix::from_edge_list(&edges, shape[0], self.directed)?;
364
365 self.from_adjacency(&adjacency, node_factory, edge_factory)
366 }
367}
368
/// Convenience conversions from a [`Graph`] to tensor form.
#[cfg(feature = "tensor-sparse")]
pub trait GraphTensorExt<T, E> {
    /// Returns `(node_features, adjacency)` for the whole graph.
    fn to_tensor_representation(&self) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
    where
        T: AsRef<[f64]> + Clone,
        E: Clone;

    /// Returns the graph's adjacency matrix in CSR form.
    fn adjacency_matrix(&self) -> Result<GraphAdjacencyMatrix, TensorError>;

    /// Returns an `[n, num_features]` tensor built from each node's `AsRef<[f64]>` data.
    fn node_features(&self, num_features: usize) -> Result<DenseTensor, TensorError>
    where
        T: AsRef<[f64]> + Clone;

    /// Returns a [`GraphFeatureExtractor`] borrowing this graph.
    fn feature_extractor(&self) -> GraphFeatureExtractor<'_, T, E>;
}
389
#[cfg(feature = "tensor-sparse")]
impl<T, E> GraphTensorExt<T, E> for Graph<T, E>
where
    T: Clone,
    E: Clone,
{
    /// Extracts node features and adjacency in one call.
    ///
    /// The feature width is taken from the first node yielded by `nodes()`;
    /// an empty graph uses a width of 1.
    fn to_tensor_representation(&self) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
    where
        T: AsRef<[f64]> + Clone,
        E: Clone,
    {
        let width = match self.nodes().next() {
            Some(first) => first.data().as_ref().len(),
            None => 1,
        };

        GraphFeatureExtractor::new(self).extract_all(width)
    }

    /// Delegates to [`GraphFeatureExtractor::extract_adjacency`].
    fn adjacency_matrix(&self) -> Result<GraphAdjacencyMatrix, TensorError> {
        GraphFeatureExtractor::new(self).extract_adjacency()
    }

    /// Delegates to [`GraphFeatureExtractor::extract_node_features`] using
    /// each node's `AsRef<[f64]>` view.
    fn node_features(&self, num_features: usize) -> Result<DenseTensor, TensorError>
    where
        T: AsRef<[f64]> + Clone,
    {
        GraphFeatureExtractor::new(self)
            .extract_node_features(|data: &T| data.as_ref(), num_features)
    }

    /// Wraps `self` in a new [`GraphFeatureExtractor`].
    fn feature_extractor(&self) -> GraphFeatureExtractor<'_, T, E> {
        GraphFeatureExtractor::new(self)
    }
}
428
/// A batch of graphs, each held as its `(node_features, adjacency)` pair,
/// in the order the graphs were supplied.
#[cfg(feature = "tensor-sparse")]
pub struct GraphBatch {
    /// One `(features, adjacency)` entry per graph.
    graphs: Vec<(DenseTensor, GraphAdjacencyMatrix)>,
}
437
#[cfg(feature = "tensor-sparse")]
impl GraphBatch {
    /// Converts every graph with [`GraphTensorExt::to_tensor_representation`]
    /// and stores the results in input order.
    ///
    /// # Errors
    ///
    /// Returns the first conversion error encountered.
    pub fn new<T, E>(graphs: &[Graph<T, E>]) -> Result<Self, TensorError>
    where
        T: AsRef<[f64]> + Clone,
        E: Clone,
    {
        let mut converted = Vec::with_capacity(graphs.len());
        for graph in graphs {
            converted.push(graph.to_tensor_representation()?);
        }

        Ok(Self { graphs: converted })
    }

    /// Stacks all node features into one `[len() * max_nodes, num_features]` tensor.
    ///
    /// Each graph contributes exactly `max_nodes` rows: its own nodes first,
    /// then all-zero rows. Every row is zero-padded on the right to
    /// `num_features`, the widest feature width in the batch, so graphs with
    /// narrower features stay row-aligned. An empty batch yields `[0, 0]`.
    ///
    /// NOTE(review): [`Self::batch_adjacency`] concatenates graphs without
    /// padding. Node `i` of graph `g` is row `g * max_nodes + i` here, but
    /// index `(sum of earlier graphs' num_nodes) + i` there. Confirm which
    /// layout callers expect.
    pub fn batch_features(&self) -> DenseTensor {
        if self.graphs.is_empty() {
            return DenseTensor::zeros(vec![0, 0]);
        }

        let max_nodes = self
            .graphs
            .iter()
            .map(|(_, adj)| adj.num_nodes)
            .max()
            .unwrap_or(0);

        let num_features = self
            .graphs
            .iter()
            .map(|(feat, _)| feat.shape().get(1).copied().unwrap_or(1))
            .max()
            .unwrap_or(1);

        let mut all_features = Vec::with_capacity(self.graphs.len() * max_nodes * num_features);

        for (features, adjacency) in &self.graphs {
            let data = features.data();
            let width = features.shape().get(1).copied().unwrap_or(1);
            let rows = adjacency.num_nodes;

            // Copy row by row so each row can be right-padded to `num_features`.
            for row in 0..rows {
                let values = data.get(row * width..(row + 1) * width).unwrap_or_default();
                all_features.extend_from_slice(values);
                all_features.extend(std::iter::repeat_n(0.0, num_features - values.len()));
            }

            // Pad the graph up to `max_nodes` rows.
            all_features.extend(std::iter::repeat_n(0.0, (max_nodes - rows) * num_features));
        }

        DenseTensor::from_vec(
            all_features,
            vec![self.graphs.len() * max_nodes, num_features],
        )
    }

    /// Builds a block-diagonal adjacency over all graphs in the batch.
    ///
    /// Graph `g`'s nodes are shifted by the total `num_nodes` of the graphs
    /// before it; no padding is inserted. The directedness flag is taken from
    /// the first graph. An empty batch yields an empty `0 x 0` matrix.
    ///
    /// # Panics
    ///
    /// Does not panic in practice: `GraphAdjacencyMatrix::from_edge_list`
    /// never returns `Err` in this module.
    pub fn batch_adjacency(&self) -> GraphAdjacencyMatrix {
        let total_nodes: usize = self.graphs.iter().map(|(_, adj)| adj.num_nodes).sum();
        let total_edges: usize = self.graphs.iter().map(|(_, adj)| adj.num_edges).sum();
        let is_directed = self
            .graphs
            .first()
            .map(|(_, adj)| adj.is_directed)
            .unwrap_or(false);

        let mut all_edges = Vec::with_capacity(total_edges);
        let mut offset = 0;

        for (_, adjacency) in &self.graphs {
            let csr = adjacency.as_sparse_tensor();
            let row_offsets = csr.row_offsets();
            let col_indices = csr.col_indices();

            for src in 0..adjacency.num_nodes {
                for &dst in &col_indices[row_offsets[src]..row_offsets[src + 1]] {
                    all_edges.push((src + offset, dst + offset));
                }
            }

            offset += adjacency.num_nodes;
        }

        GraphAdjacencyMatrix::from_edge_list(&all_edges, total_nodes, is_directed)
            .expect("GraphAdjacencyMatrix::from_edge_list never returns Err")
    }

    /// Number of graphs in the batch.
    pub fn len(&self) -> usize {
        self.graphs.len()
    }

    /// `true` when the batch holds no graphs.
    pub fn is_empty(&self) -> bool {
        self.graphs.is_empty()
    }

    /// Returns the `(features, adjacency)` pair of graph `idx`, if present.
    pub fn get(&self, idx: usize) -> Option<&(DenseTensor, GraphAdjacencyMatrix)> {
        self.graphs.get(idx)
    }
}
555
#[cfg(all(test, feature = "tensor-sparse"))]
mod tests {
    use super::*;
    use crate::graph::Graph;

    /// Directed 3-cycle 0 -> 1 -> 2 -> 0.
    const TRIANGLE: [(usize, usize); 3] = [(0, 1), (1, 2), (2, 0)];

    /// Builds a directed graph whose nodes carry `features`; `edges` refer to
    /// positions in `features` and every edge has weight `1.0`.
    fn directed_feature_graph(
        features: &[[f64; 2]],
        edges: &[(usize, usize)],
    ) -> Graph<Vec<f64>, f64> {
        let mut graph = Graph::<Vec<f64>, f64>::directed();
        let ids: Vec<_> = features
            .iter()
            .map(|feature| graph.add_node(feature.to_vec()).unwrap())
            .collect();
        for &(from, to) in edges {
            let _ = graph.add_edge(ids[from], ids[to], 1.0);
        }
        graph
    }

    #[test]
    fn test_adjacency_matrix_creation() {
        let adj = GraphAdjacencyMatrix::from_edge_list(&TRIANGLE, 3, true).unwrap();

        assert_eq!(adj.num_nodes, 3);
        assert_eq!(adj.num_edges, 3);
        assert!(adj.is_directed);
    }

    #[test]
    fn test_graph_to_tensor_conversion() {
        let graph = directed_feature_graph(&[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], &TRIANGLE);

        let (features, adjacency) = graph.to_tensor_representation().unwrap();

        assert_eq!(features.shape(), &[3, 2]);
        assert_eq!(adjacency.num_nodes, 3);
        assert_eq!(adjacency.num_edges, 3);
    }

    #[test]
    fn test_feature_extractor() {
        let mut graph = Graph::<String, f64>::directed();
        let first = graph.add_node("node0".to_string()).unwrap();
        let second = graph.add_node("node1".to_string()).unwrap();
        let _ = graph.add_edge(first, second, 1.0);

        let features = graph
            .feature_extractor()
            .extract_node_features_scalar(|label| label.len() as f64)
            .unwrap();

        assert_eq!(features.shape(), &[2, 1]);
    }

    #[test]
    fn test_graph_reconstruction() {
        let adj = GraphAdjacencyMatrix::from_edge_list(&TRIANGLE, 3, true).unwrap();

        let rebuilt: Graph<usize, f64> = GraphReconstructor::new(true)
            .from_adjacency(&adj, |row| row, |_src, _dst, weight| weight)
            .unwrap();

        assert_eq!(rebuilt.node_count(), 3);
        assert_eq!(rebuilt.edge_count(), 3);
    }

    #[test]
    fn test_normalized_adjacency() {
        let path_both_ways = [(0, 1), (1, 0), (1, 2), (2, 1)];
        let adj = GraphAdjacencyMatrix::from_edge_list(&path_both_ways, 3, true).unwrap();

        let with_loops = adj.normalized_with_self_loops().unwrap();

        assert!(with_loops.num_edges > adj.num_edges);
    }

    #[test]
    fn test_batch_creation() {
        let first = directed_feature_graph(&[[1.0, 0.0], [0.0, 1.0]], &[(0, 1)]);
        let second = directed_feature_graph(&[[1.0, 1.0], [0.0, 0.0]], &[(0, 1)]);

        let batch = GraphBatch::new(&[first, second]).unwrap();

        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }
}