1use crate::graph::traits::{GraphBase, GraphOps, GraphQuery};
11use crate::graph::Graph;
12use crate::tensor::error::TensorError;
13use crate::tensor::traits::TensorBase;
14use crate::tensor::DenseTensor;
15
16#[cfg(feature = "tensor")]
17use crate::tensor::{COOTensor, CSRTensor};
18
/// Adjacency matrix of a graph stored in compressed sparse row (CSR) form.
///
/// Row `i` holds the destinations of the edges that leave node `i`; the
/// matrix is square with side `num_nodes`.
#[cfg(feature = "tensor")]
#[derive(Debug, Clone)]
pub struct GraphAdjacencyMatrix {
    /// CSR storage; exposed read-only through `as_sparse_tensor`.
    csr: CSRTensor,
    /// Number of nodes, i.e. the number of rows and columns.
    pub num_nodes: usize,
    /// Number of edges recorded when the matrix was built.
    pub num_edges: usize,
    /// Directedness flag carried along as metadata; the constructors in this
    /// file do not change the stored entries based on it.
    pub is_directed: bool,
}
32
#[cfg(feature = "tensor")]
impl GraphAdjacencyMatrix {
    /// Builds a CSR adjacency matrix from a list of `(src, dst)` edges.
    ///
    /// Every edge whose endpoints are both `< num_nodes` is stored as a `1.0`
    /// entry at row `src`, column `dst`. Edges with an out-of-range endpoint
    /// are skipped. The input does not need to be sorted: entries are grouped
    /// by row, and within a row they keep the order in which they appear in
    /// `edges`. Duplicate edges are stored as separate entries.
    ///
    /// `is_directed` is recorded on the result but does not change the stored
    /// entries. For an undirected graph, pass both directions of each edge.
    ///
    /// The returned `num_edges` counts the stored (in-range) edges.
    ///
    /// # Errors
    ///
    /// This function currently always succeeds; the `Result` is kept as part
    /// of the public signature.
    pub fn from_edge_list(
        edges: &[(usize, usize)],
        num_nodes: usize,
        is_directed: bool,
    ) -> Result<Self, TensorError> {
        if edges.is_empty() {
            return Ok(Self {
                csr: CSRTensor::new(
                    vec![0; num_nodes + 1],
                    vec![],
                    DenseTensor::zeros(vec![0]),
                    [num_nodes, num_nodes],
                ),
                num_nodes,
                num_edges: 0,
                is_directed,
            });
        }

        // Counting (pass 1) and placement (pass 2) must apply the same filter;
        // otherwise `row_offsets` and `col_indices` disagree and the CSR is
        // corrupt.
        let in_range = |src: usize, dst: usize| src < num_nodes && dst < num_nodes;

        // Pass 1: count entries per row, then prefix-sum into row offsets.
        let mut row_offsets = vec![0usize; num_nodes + 1];
        for &(src, dst) in edges {
            if in_range(src, dst) {
                row_offsets[src + 1] += 1;
            }
        }
        for i in 1..=num_nodes {
            row_offsets[i] += row_offsets[i - 1];
        }
        let nnz = row_offsets[num_nodes];

        // Pass 2: scatter each destination into the next free slot of its
        // row. Appending in input order would be wrong when `edges` is not
        // sorted by source.
        let mut col_indices = vec![0usize; nnz];
        let mut next_slot = row_offsets[..num_nodes].to_vec();
        for &(src, dst) in edges {
            if in_range(src, dst) {
                col_indices[next_slot[src]] = dst;
                next_slot[src] += 1;
            }
        }

        let values = DenseTensor::new(vec![1.0; nnz], vec![nnz]);
        let csr = CSRTensor::new(row_offsets, col_indices, values, [num_nodes, num_nodes]);

        Ok(Self {
            csr,
            num_nodes,
            num_edges: nnz,
            is_directed,
        })
    }

    /// Converts the adjacency matrix to COO format.
    ///
    /// The CSR storage is cloned, wrapped in `SparseTensor::CSR`, and
    /// converted with `SparseTensor::to_coo`.
    #[cfg(feature = "tensor")]
    pub fn to_coo(&self) -> COOTensor {
        use crate::tensor::SparseTensor;
        let sparse = SparseTensor::CSR(self.csr.clone());
        sparse.to_coo()
    }

    /// Borrows the underlying CSR tensor.
    #[cfg(feature = "tensor")]
    pub fn as_sparse_tensor(&self) -> &CSRTensor {
        &self.csr
    }

    /// Returns a new matrix with a self-loop `(i, i)` added for every node,
    /// i.e. the structure of `A + I`, rebuilt through `from_edge_list`.
    ///
    /// Despite the name, no degree normalisation is applied to the values.
    /// A node that already has a self-loop ends up with two `(i, i)` entries.
    ///
    /// # Errors
    ///
    /// Propagates any error from `from_edge_list`.
    #[cfg(feature = "tensor")]
    pub fn normalized_with_self_loops(&self) -> Result<Self, TensorError> {
        let n = self.num_nodes;

        // Capacity is only a hint: existing edges plus one self-loop per node.
        let mut edges = Vec::with_capacity(self.num_edges + n);

        for i in 0..n {
            let start = self.csr.row_offsets()[i];
            let end = self.csr.row_offsets()[i + 1];
            for j in start..end {
                let col = self.csr.col_indices()[j];
                edges.push((i, col));
            }
            edges.push((i, i));
        }

        Self::from_edge_list(&edges, n, self.is_directed)
    }

    /// Returns the out-degree of every node as a length-`num_nodes` vector.
    ///
    /// This is the diagonal of the degree matrix D. The degree of node `i` is
    /// the number of stored entries in row `i`; values are not summed.
    #[cfg(feature = "tensor")]
    pub fn degree_matrix(&self) -> DenseTensor {
        let n = self.num_nodes;
        let mut degrees = vec![0.0; n];

        for (i, degree) in degrees.iter_mut().enumerate() {
            let start = self.csr.row_offsets()[i];
            let end = self.csr.row_offsets()[i + 1];
            *degree = (end - start) as f64;
        }

        DenseTensor::from_vec(degrees, vec![n])
    }

    /// Returns `1 / out-degree` for every node as a length-`num_nodes`
    /// vector, using `0.0` for nodes with no outgoing entries.
    #[cfg(feature = "tensor")]
    pub fn inverse_degree_matrix(&self) -> DenseTensor {
        let n = self.num_nodes;
        let mut inv_degrees = vec![0.0; n];

        for (i, inv_degree) in inv_degrees.iter_mut().enumerate() {
            let start = self.csr.row_offsets()[i];
            let end = self.csr.row_offsets()[i + 1];
            let degree = (end - start) as f64;
            // Isolated nodes get 0 instead of infinity.
            *inv_degree = if degree > 0.0 { 1.0 / degree } else { 0.0 };
        }

        DenseTensor::from_vec(inv_degrees, vec![n])
    }
}
165
166pub struct GraphFeatureExtractor<'a, T, E> {
168 graph: &'a Graph<T, E>,
169}
170
171impl<'a, T, E> GraphFeatureExtractor<'a, T, E>
172where
173 T: Clone,
174 E: Clone,
175{
176 pub fn new(graph: &'a Graph<T, E>) -> Self {
178 Self { graph }
179 }
180
181 pub fn extract_node_features_scalar<F>(&self, map_fn: F) -> Result<DenseTensor, TensorError>
185 where
186 F: Fn(&T) -> f64,
187 {
188 let n = self.graph.node_count();
189 let mut features = Vec::with_capacity(n);
190
191 for node_idx in self.graph.nodes() {
192 let data = node_idx.data();
193 features.push(map_fn(data));
194 }
195
196 Ok(DenseTensor::from_vec(features, vec![n, 1]))
197 }
198
199 pub fn extract_node_features<F>(
203 &self,
204 map_fn: F,
205 num_features: usize,
206 ) -> Result<DenseTensor, TensorError>
207 where
208 F: for<'b> Fn(&'b T) -> &'b [f64],
209 {
210 let n = self.graph.node_count();
211 let mut features = Vec::with_capacity(n * num_features);
212
213 for node_idx in self.graph.nodes() {
214 let data = node_idx.data();
215 let feat = map_fn(data);
216 features.extend_from_slice(feat);
217 }
218
219 Ok(DenseTensor::from_vec(features, vec![n, num_features]))
220 }
221
222 pub fn extract_edge_features<F>(&self, map_fn: F) -> Result<DenseTensor, TensorError>
224 where
225 F: Fn(&E) -> f64,
226 {
227 let m = self.graph.edge_count();
228 let mut features = Vec::with_capacity(m);
229
230 for edge_idx in self.graph.edges() {
231 let data = edge_idx.data();
232 features.push(map_fn(data));
233 }
234
235 Ok(DenseTensor::from_vec(features, vec![m, 1]))
236 }
237
238 #[cfg(feature = "tensor")]
240 pub fn extract_adjacency(&self) -> Result<GraphAdjacencyMatrix, TensorError> {
241 let mut edges: Vec<(usize, usize)> = Vec::new();
242
243 for node_idx in self.graph.nodes() {
244 let src = node_idx.index().index();
245 for neighbor in self.graph.neighbors(node_idx.index()) {
246 let dst = neighbor.index();
247 edges.push((src, dst));
248 }
249 }
250
251 GraphAdjacencyMatrix::from_edge_list(
252 &edges,
253 self.graph.node_count(),
254 true, )
256 }
257
258 #[cfg(feature = "tensor")]
260 pub fn extract_all(
261 &self,
262 num_node_features: usize,
263 ) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
264 where
265 T: AsRef<[f64]> + Clone,
266 E: Clone,
267 {
268 let node_features =
269 self.extract_node_features(|data: &T| data.as_ref(), num_node_features)?;
270 let adjacency = self.extract_adjacency()?;
271
272 Ok((node_features, adjacency))
273 }
274}
275
/// Rebuilds a `Graph` from tensor representations of its structure.
// `directed` is only read by the `tensor`-gated constructors, so it is dead
// code when that feature is disabled.
#[allow(dead_code)]
pub struct GraphReconstructor {
    /// Whether reconstructed graphs are created as directed.
    directed: bool,
}
281
282impl GraphReconstructor {
283 pub fn new(directed: bool) -> Self {
285 Self { directed }
286 }
287
288 #[cfg(feature = "tensor")]
290 pub fn from_adjacency<T, E>(
291 &self,
292 adjacency: &GraphAdjacencyMatrix,
293 mut node_factory: impl FnMut(usize) -> T,
294 mut edge_factory: impl FnMut(usize, usize, f64) -> E,
295 ) -> Result<Graph<T, E>, TensorError>
296 where
297 T: Clone,
298 E: Clone,
299 {
300 let mut graph = if self.directed {
301 Graph::<T, E>::directed()
302 } else {
303 Graph::<T, E>::undirected()
304 };
305
306 let n = adjacency.num_nodes;
307 let mut node_indices = Vec::with_capacity(n);
308
309 for i in 0..n {
311 let node = node_factory(i);
312 let idx = graph.add_node(node).map_err(|e| TensorError::SliceError {
313 description: format!("Failed to add node: {:?}", e),
314 })?;
315 node_indices.push(idx);
316 }
317
318 let csr = adjacency.as_sparse_tensor();
320
321 for src in 0..n {
322 let start = csr.row_offsets()[src];
323 let end = csr.row_offsets()[src + 1];
324
325 for j in start..end {
326 let dst = csr.col_indices()[j];
327 let weight = csr.values().data()[j];
328
329 if let (Some(src_idx), Some(dst_idx)) = (
330 node_indices.get(src).copied(),
331 node_indices.get(dst).copied(),
332 ) {
333 let edge_data = edge_factory(src, dst, weight);
334 let _ = graph.add_edge(src_idx, dst_idx, edge_data);
335 }
336 }
337 }
338
339 Ok(graph)
340 }
341
342 #[cfg(feature = "tensor")]
344 pub fn from_coo<T, E>(
345 &self,
346 coo: &COOTensor,
347 node_factory: impl FnMut(usize) -> T,
348 edge_factory: impl FnMut(usize, usize, f64) -> E,
349 ) -> Result<Graph<T, E>, TensorError>
350 where
351 T: Clone,
352 E: Clone,
353 {
354 let row_indices = coo.row_indices();
356 let col_indices = coo.col_indices();
357 let edges: Vec<(usize, usize)> = row_indices
358 .iter()
359 .zip(col_indices.iter())
360 .map(|(&r, &c)| (r, c))
361 .collect();
362
363 let shape = coo.shape_array();
364 let adjacency = GraphAdjacencyMatrix::from_edge_list(&edges, shape[0], self.directed)?;
365
366 self.from_adjacency(&adjacency, node_factory, edge_factory)
367 }
368}
369
/// Extension methods that expose a graph as tensors.
#[cfg(feature = "tensor")]
pub trait GraphTensorExt<T, E> {
    /// Returns the node feature tensor paired with the adjacency matrix.
    fn to_tensor_representation(&self) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
    where
        T: AsRef<[f64]> + Clone,
        E: Clone;

    /// Returns the graph's adjacency matrix.
    fn adjacency_matrix(&self) -> Result<GraphAdjacencyMatrix, TensorError>;

    /// Returns a `[node_count, num_features]` tensor built from node payloads.
    fn node_features(&self, num_features: usize) -> Result<DenseTensor, TensorError>
    where
        T: AsRef<[f64]> + Clone;

    /// Returns a `GraphFeatureExtractor` that borrows this graph.
    fn feature_extractor(&self) -> GraphFeatureExtractor<'_, T, E>;
}
390
#[cfg(feature = "tensor")]
impl<T, E> GraphTensorExt<T, E> for Graph<T, E>
where
    T: Clone,
    E: Clone,
{
    /// Uses the payload length of the first node as the feature width (1 for
    /// an empty graph) and delegates to `GraphFeatureExtractor::extract_all`.
    fn to_tensor_representation(&self) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
    where
        T: AsRef<[f64]> + Clone,
        E: Clone,
    {
        let width = self
            .nodes()
            .next()
            .map(|first_node| first_node.data().as_ref().len())
            .unwrap_or(1);

        self.feature_extractor().extract_all(width)
    }

    /// Delegates to `GraphFeatureExtractor::extract_adjacency`.
    fn adjacency_matrix(&self) -> Result<GraphAdjacencyMatrix, TensorError> {
        self.feature_extractor().extract_adjacency()
    }

    /// Delegates to `GraphFeatureExtractor::extract_node_features`, reading
    /// each payload through `AsRef<[f64]>`.
    fn node_features(&self, num_features: usize) -> Result<DenseTensor, TensorError>
    where
        T: AsRef<[f64]> + Clone,
    {
        self.feature_extractor()
            .extract_node_features(|data: &T| data.as_ref(), num_features)
    }

    /// Wraps `self` in a new `GraphFeatureExtractor`.
    fn feature_extractor(&self) -> GraphFeatureExtractor<'_, T, E> {
        GraphFeatureExtractor::new(self)
    }
}
429
/// A set of graphs converted to `(node features, adjacency)` pairs, kept in
/// the order they were supplied.
#[cfg(feature = "tensor")]
pub struct GraphBatch {
    /// One `(features, adjacency)` entry per input graph.
    graphs: Vec<(DenseTensor, GraphAdjacencyMatrix)>,
}
438
#[cfg(feature = "tensor")]
impl GraphBatch {
    /// Converts every graph with `GraphTensorExt::to_tensor_representation`
    /// and stores the results in input order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while converting a graph.
    pub fn new<T, E>(graphs: &[Graph<T, E>]) -> Result<Self, TensorError>
    where
        T: AsRef<[f64]> + Clone,
        E: Clone,
    {
        let graphs = graphs
            .iter()
            .map(|graph| graph.to_tensor_representation())
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Self { graphs })
    }

    /// Feature width (second dimension) of a features tensor, defaulting to 1.
    fn feature_width(features: &DenseTensor) -> usize {
        features.shape().get(1).copied().unwrap_or(1)
    }

    /// Stacks all node features into a `[len * max_nodes, num_features]`
    /// tensor. `max_nodes` is the largest node count in the batch and
    /// `num_features` is the widest feature dimension.
    ///
    /// Each node row is right-padded with zeros to `num_features`. Each graph
    /// is then padded with all-zero rows up to `max_nodes`, so graph `g`
    /// occupies rows `g * max_nodes .. (g + 1) * max_nodes`. An empty batch
    /// yields a `[0, 0]` tensor.
    ///
    /// Note: `batch_adjacency` concatenates graphs *without* padding. Node
    /// indices of the two outputs only coincide when every graph has the same
    /// node count.
    pub fn batch_features(&self) -> DenseTensor {
        if self.graphs.is_empty() {
            return DenseTensor::zeros(vec![0, 0]);
        }

        let max_nodes = self
            .graphs
            .iter()
            .map(|(_, adj)| adj.num_nodes)
            .max()
            .unwrap_or(0);

        let num_features = self
            .graphs
            .iter()
            .map(|(feat, _)| Self::feature_width(feat))
            .max()
            .unwrap_or(1);

        let mut all_features = Vec::with_capacity(self.graphs.len() * max_nodes * num_features);
        for (features, adjacency) in &self.graphs {
            let width = Self::feature_width(features);
            let data = features.data();
            let current_nodes = adjacency.num_nodes;

            // Copy row by row so graphs with narrower features stay aligned to
            // the common `num_features` stride.
            for row in 0..current_nodes {
                let start = row * width;
                let values: &[f64] = data.get(start..start + width).unwrap_or(&[]);
                all_features.extend_from_slice(values);
                all_features.extend(std::iter::repeat_n(0.0, num_features - values.len()));
            }

            // Fill the remaining node slots of this graph with zero rows.
            let padding_rows = max_nodes - current_nodes;
            all_features.extend(std::iter::repeat_n(0.0, padding_rows * num_features));
        }

        DenseTensor::from_vec(
            all_features,
            vec![self.graphs.len() * max_nodes, num_features],
        )
    }

    /// Combines all adjacency matrices into one block-diagonal matrix.
    ///
    /// Node ids of graph `g` are shifted by the total node count of the
    /// graphs before it, without padding. The directedness flag is taken from
    /// the first graph. An empty batch yields a 0-node matrix.
    #[cfg(feature = "tensor")]
    pub fn batch_adjacency(&self) -> GraphAdjacencyMatrix {
        if self.graphs.is_empty() {
            return GraphAdjacencyMatrix::from_edge_list(&[], 0, false)
                .expect("building an adjacency matrix from an empty edge list does not fail");
        }

        let total_nodes: usize = self.graphs.iter().map(|(_, adj)| adj.num_nodes).sum();
        let total_edges: usize = self.graphs.iter().map(|(_, adj)| adj.num_edges).sum();

        let mut all_edges = Vec::with_capacity(total_edges);
        let mut offset = 0;

        for (_, adjacency) in &self.graphs {
            let csr = adjacency.as_sparse_tensor();
            for src in 0..adjacency.num_nodes {
                let start = csr.row_offsets()[src];
                let end = csr.row_offsets()[src + 1];
                for j in start..end {
                    let dst = csr.col_indices()[j];
                    all_edges.push((src + offset, dst + offset));
                }
            }
            offset += adjacency.num_nodes;
        }

        GraphAdjacencyMatrix::from_edge_list(
            &all_edges,
            total_nodes,
            self.graphs
                .first()
                .map(|(_, adj)| adj.is_directed)
                .unwrap_or(false),
        )
        .expect("from_edge_list does not return an error for an edge list")
    }

    /// Number of graphs in the batch.
    pub fn len(&self) -> usize {
        self.graphs.len()
    }

    /// Returns `true` when the batch holds no graphs.
    pub fn is_empty(&self) -> bool {
        self.graphs.is_empty()
    }

    /// Returns the `(features, adjacency)` pair of graph `idx`, if present.
    #[cfg(feature = "tensor")]
    pub fn get(&self, idx: usize) -> Option<&(DenseTensor, GraphAdjacencyMatrix)> {
        self.graphs.get(idx)
    }
}
556
#[cfg(all(test, feature = "tensor"))]
mod tests {
    use super::*;
    use crate::graph::Graph;

    /// Edges of the directed 3-cycle 0 -> 1 -> 2 -> 0.
    fn cycle_edges() -> Vec<(usize, usize)> {
        vec![(0, 1), (1, 2), (2, 0)]
    }

    /// Directed two-node graph `first -> second` with the given payloads.
    fn two_node_graph(first: Vec<f64>, second: Vec<f64>) -> Graph<Vec<f64>, f64> {
        let mut g = Graph::<Vec<f64>, f64>::directed();
        let a = g.add_node(first).unwrap();
        let b = g.add_node(second).unwrap();
        let _ = g.add_edge(a, b, 1.0);
        g
    }

    #[test]
    fn test_adjacency_matrix_creation() {
        let adj = GraphAdjacencyMatrix::from_edge_list(&cycle_edges(), 3, true).unwrap();

        assert_eq!(adj.num_nodes, 3);
        assert_eq!(adj.num_edges, 3);
        assert!(adj.is_directed);
    }

    #[test]
    fn test_graph_to_tensor_conversion() {
        let mut graph = Graph::<Vec<f64>, f64>::directed();
        let ids: Vec<_> = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]]
            .into_iter()
            .map(|payload| graph.add_node(payload).unwrap())
            .collect();

        // Close the cycle: 0 -> 1, 1 -> 2, 2 -> 0.
        for k in 0..ids.len() {
            let _ = graph.add_edge(ids[k], ids[(k + 1) % ids.len()], 1.0);
        }

        let (features, adjacency) = graph.to_tensor_representation().unwrap();

        assert_eq!(features.shape(), &[3, 2]);
        assert_eq!(adjacency.num_nodes, 3);
        assert_eq!(adjacency.num_edges, 3);
    }

    #[test]
    fn test_feature_extractor() {
        let mut graph = Graph::<String, f64>::directed();
        let a = graph.add_node("node0".to_string()).unwrap();
        let b = graph.add_node("node1".to_string()).unwrap();
        let _ = graph.add_edge(a, b, 1.0);

        let features = graph
            .feature_extractor()
            .extract_node_features_scalar(|s| s.len() as f64)
            .unwrap();

        assert_eq!(features.shape(), &[2, 1]);
    }

    #[test]
    fn test_graph_reconstruction() {
        let adj = GraphAdjacencyMatrix::from_edge_list(&cycle_edges(), 3, true).unwrap();

        let graph: Graph<usize, f64> = GraphReconstructor::new(true)
            .from_adjacency(&adj, |i| i, |_src, _dst, w| w)
            .unwrap();

        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.edge_count(), 3);
    }

    #[test]
    fn test_normalized_adjacency() {
        let symmetric_path = vec![(0, 1), (1, 0), (1, 2), (2, 1)];
        let adj = GraphAdjacencyMatrix::from_edge_list(&symmetric_path, 3, true).unwrap();

        let with_loops = adj.normalized_with_self_loops().unwrap();

        assert!(with_loops.num_edges > adj.num_edges);
    }

    #[test]
    fn test_batch_creation() {
        let graph1 = two_node_graph(vec![1.0, 0.0], vec![0.0, 1.0]);
        let graph2 = two_node_graph(vec![1.0, 1.0], vec![0.0, 0.0]);

        let batch = GraphBatch::new(&[graph1, graph2]).unwrap();

        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }
}