1use crate::graph::traits::{GraphBase, GraphQuery, GraphOps};
11use crate::graph::Graph;
12use crate::tensor::DenseTensor;
13use crate::tensor::error::TensorError;
14use crate::tensor::traits::TensorBase;
15
16#[cfg(feature = "tensor-sparse")]
17use crate::tensor::{COOTensor, CSRTensor};
18
19#[derive(Debug, Clone)]
21pub struct GraphAdjacencyMatrix {
22 csr: CSRTensor,
24 pub num_nodes: usize,
26 pub num_edges: usize,
28 pub is_directed: bool,
30}
31
32impl GraphAdjacencyMatrix {
33 pub fn from_edge_list(
35 edges: &[(usize, usize)],
36 num_nodes: usize,
37 is_directed: bool,
38 ) -> Result<Self, TensorError> {
39 if edges.is_empty() {
40 return Ok(Self {
41 csr: CSRTensor::new(
42 vec![0; num_nodes + 1],
43 vec![],
44 DenseTensor::zeros(vec![0]),
45 [num_nodes, num_nodes],
46 ),
47 num_nodes,
48 num_edges: 0,
49 is_directed,
50 });
51 }
52
53 let mut row_offsets = vec![0usize; num_nodes + 1];
55 let mut col_indices = Vec::with_capacity(edges.len());
56 let mut values_data = Vec::with_capacity(edges.len());
57
58 for &(src, _) in edges {
60 if src < num_nodes {
61 row_offsets[src + 1] += 1;
62 }
63 }
64
65 for i in 1..=num_nodes {
67 row_offsets[i] += row_offsets[i - 1];
68 }
69
70 let mut row_pos = row_offsets[..num_nodes].to_vec();
72 for &(src, dst) in edges {
73 if src < num_nodes && dst < num_nodes {
74 let _pos = row_pos[src];
75 col_indices.push(dst);
76 values_data.push(1.0);
77 row_pos[src] += 1;
78 }
79 }
80
81 let values = DenseTensor::new(values_data, vec![col_indices.len()]);
82 let csr = CSRTensor::new(row_offsets, col_indices, values, [num_nodes, num_nodes]);
83
84 Ok(Self {
85 csr,
86 num_nodes,
87 num_edges: edges.len(),
88 is_directed,
89 })
90 }
91
92 #[cfg(feature = "tensor-sparse")]
94 pub fn to_coo(&self) -> COOTensor {
95 use crate::tensor::SparseTensor;
96 let sparse = SparseTensor::CSR(self.csr.clone());
97 sparse.to_coo()
98 }
99
100 pub fn as_sparse_tensor(&self) -> &CSRTensor {
102 &self.csr
103 }
104
105 pub fn normalized_with_self_loops(&self) -> Result<Self, TensorError> {
110 let n = self.num_nodes;
111
112 let mut edges = Vec::new();
114
115 for i in 0..n {
117 let start = self.csr.row_offsets()[i];
118 let end = self.csr.row_offsets()[i + 1];
119 for j in start..end {
120 let col = self.csr.col_indices()[j];
121 edges.push((i, col));
122 }
123 edges.push((i, i));
125 }
126
127 Self::from_edge_list(&edges, n, self.is_directed)
128 }
129
130 pub fn degree_matrix(&self) -> DenseTensor {
132 let n = self.num_nodes;
133 let mut degrees = vec![0.0; n];
134
135 for (i, degree) in degrees.iter_mut().enumerate() {
136 let start = self.csr.row_offsets()[i];
137 let end = self.csr.row_offsets()[i + 1];
138 *degree = (end - start) as f64;
139 }
140
141 DenseTensor::from_vec(degrees, vec![n])
142 }
143
144 pub fn inverse_degree_matrix(&self) -> DenseTensor {
146 let n = self.num_nodes;
147 let mut inv_degrees = vec![0.0; n];
148
149 for (i, inv_degree) in inv_degrees.iter_mut().enumerate() {
150 let start = self.csr.row_offsets()[i];
151 let end = self.csr.row_offsets()[i + 1];
152 let degree = (end - start) as f64;
153 *inv_degree = if degree > 0.0 { 1.0 / degree } else { 0.0 };
154 }
155
156 DenseTensor::from_vec(inv_degrees, vec![n])
157 }
158}
159
160pub struct GraphFeatureExtractor<'a, T, E> {
162 graph: &'a Graph<T, E>,
163}
164
165impl<'a, T, E> GraphFeatureExtractor<'a, T, E>
166where
167 T: Clone,
168 E: Clone,
169{
170 pub fn new(graph: &'a Graph<T, E>) -> Self {
172 Self { graph }
173 }
174
175 pub fn extract_node_features_scalar<F>(&self, map_fn: F) -> Result<DenseTensor, TensorError>
179 where
180 F: Fn(&T) -> f64,
181 {
182 let n = self.graph.node_count();
183 let mut features = Vec::with_capacity(n);
184
185 for node_idx in self.graph.nodes() {
186 let data = node_idx.data();
187 features.push(map_fn(data));
188 }
189
190 Ok(DenseTensor::from_vec(features, vec![n, 1]))
191 }
192
193 pub fn extract_node_features<F>(&self, map_fn: F, num_features: usize) -> Result<DenseTensor, TensorError>
197 where
198 F: for<'b> Fn(&'b T) -> &'b [f64],
199 {
200 let n = self.graph.node_count();
201 let mut features = Vec::with_capacity(n * num_features);
202
203 for node_idx in self.graph.nodes() {
204 let data = node_idx.data();
205 let feat = map_fn(data);
206 features.extend_from_slice(feat);
207 }
208
209 Ok(DenseTensor::from_vec(features, vec![n, num_features]))
210 }
211
212 pub fn extract_edge_features<F>(&self, map_fn: F) -> Result<DenseTensor, TensorError>
214 where
215 F: Fn(&E) -> f64,
216 {
217 let m = self.graph.edge_count();
218 let mut features = Vec::with_capacity(m);
219
220 for edge_idx in self.graph.edges() {
221 let data = edge_idx.data();
222 features.push(map_fn(data));
223 }
224
225 Ok(DenseTensor::from_vec(features, vec![m, 1]))
226 }
227
228 pub fn extract_adjacency(&self) -> Result<GraphAdjacencyMatrix, TensorError> {
230 let mut edges: Vec<(usize, usize)> = Vec::new();
231
232 for node_idx in self.graph.nodes() {
233 let src = node_idx.index().index();
234 for neighbor in self.graph.neighbors(node_idx.index()) {
235 let dst = neighbor.index();
236 edges.push((src, dst));
237 }
238 }
239
240 GraphAdjacencyMatrix::from_edge_list(
241 &edges,
242 self.graph.node_count(),
243 true, )
245 }
246
247 pub fn extract_all(&self, num_node_features: usize) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
249 where
250 T: AsRef<[f64]> + Clone,
251 E: Clone,
252 {
253 let node_features = self.extract_node_features(|data: &T| data.as_ref(), num_node_features)?;
254 let adjacency = self.extract_adjacency()?;
255
256 Ok((node_features, adjacency))
257 }
258}
259
/// Rebuilds [`Graph`] values from adjacency data.
///
/// The `directed` flag picks the kind of graph that is created. It is a plain
/// `bool`, so the reconstructor is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct GraphReconstructor {
    /// `true` to build directed graphs, `false` for undirected graphs.
    directed: bool,
}
264
265impl GraphReconstructor {
266 pub fn new(directed: bool) -> Self {
268 Self { directed }
269 }
270
271 pub fn from_adjacency<T, E>(
273 &self,
274 adjacency: &GraphAdjacencyMatrix,
275 mut node_factory: impl FnMut(usize) -> T,
276 mut edge_factory: impl FnMut(usize, usize, f64) -> E,
277 ) -> Result<Graph<T, E>, TensorError>
278 where
279 T: Clone,
280 E: Clone,
281 {
282 let mut graph = if self.directed {
283 Graph::<T, E>::directed()
284 } else {
285 Graph::<T, E>::undirected()
286 };
287
288 let n = adjacency.num_nodes;
289 let mut node_indices = Vec::with_capacity(n);
290
291 for i in 0..n {
293 let node = node_factory(i);
294 let idx = graph.add_node(node)
295 .map_err(|e| TensorError::SliceError {
296 description: format!("Failed to add node: {:?}", e)
297 })?;
298 node_indices.push(idx);
299 }
300
301 let csr = adjacency.as_sparse_tensor();
303
304 for src in 0..n {
305 let start = csr.row_offsets()[src];
306 let end = csr.row_offsets()[src + 1];
307
308 for j in start..end {
309 let dst = csr.col_indices()[j];
310 let weight = csr.values().data()[j];
311
312 if let (Some(src_idx), Some(dst_idx)) = (
313 node_indices.get(src).copied(),
314 node_indices.get(dst).copied(),
315 ) {
316 let edge_data = edge_factory(src, dst, weight);
317 let _ = graph.add_edge(src_idx, dst_idx, edge_data);
318 }
319 }
320 }
321
322 Ok(graph)
323 }
324
325 pub fn from_coo<T, E>(
327 &self,
328 coo: &COOTensor,
329 node_factory: impl FnMut(usize) -> T,
330 edge_factory: impl FnMut(usize, usize, f64) -> E,
331 ) -> Result<Graph<T, E>, TensorError>
332 where
333 T: Clone,
334 E: Clone,
335 {
336 let row_indices = coo.row_indices();
338 let col_indices = coo.col_indices();
339 let edges: Vec<(usize, usize)> = row_indices.iter()
340 .zip(col_indices.iter())
341 .map(|(&r, &c)| (r, c))
342 .collect();
343
344 let shape = coo.shape_array();
345 let adjacency = GraphAdjacencyMatrix::from_edge_list(
346 &edges,
347 shape[0],
348 self.directed,
349 )?;
350
351 self.from_adjacency(&adjacency, node_factory, edge_factory)
352 }
353}
354
355pub trait GraphTensorExt<T, E> {
357 fn to_tensor_representation(&self) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
359 where
360 T: AsRef<[f64]> + Clone,
361 E: Clone;
362
363 fn adjacency_matrix(&self) -> Result<GraphAdjacencyMatrix, TensorError>;
365
366 fn node_features(&self, num_features: usize) -> Result<DenseTensor, TensorError>
368 where
369 T: AsRef<[f64]> + Clone;
370
371 fn feature_extractor(&self) -> GraphFeatureExtractor<'_, T, E>;
373}
374
375impl<T, E> GraphTensorExt<T, E> for Graph<T, E>
376where
377 T: Clone,
378 E: Clone,
379{
380 fn to_tensor_representation(&self) -> Result<(DenseTensor, GraphAdjacencyMatrix), TensorError>
381 where
382 T: AsRef<[f64]> + Clone,
383 E: Clone,
384 {
385 let extractor = GraphFeatureExtractor::new(self);
386 let num_features = if let Some(first_node) = self.nodes().next() {
387 first_node.data().as_ref().len()
388 } else {
389 1
390 };
391
392 extractor.extract_all(num_features)
393 }
394
395 fn adjacency_matrix(&self) -> Result<GraphAdjacencyMatrix, TensorError> {
396 let extractor = GraphFeatureExtractor::new(self);
397 extractor.extract_adjacency()
398 }
399
400 fn node_features(&self, num_features: usize) -> Result<DenseTensor, TensorError>
401 where
402 T: AsRef<[f64]> + Clone,
403 {
404 let extractor = GraphFeatureExtractor::new(self);
405 extractor.extract_node_features(|data: &T| data.as_ref(), num_features)
406 }
407
408 fn feature_extractor(&self) -> GraphFeatureExtractor<'_, T, E> {
409 GraphFeatureExtractor::new(self)
410 }
411}
412
413pub struct GraphBatch {
418 graphs: Vec<(DenseTensor, GraphAdjacencyMatrix)>,
419}
420
421impl GraphBatch {
422 pub fn new<T, E>(graphs: &[Graph<T, E>]) -> Result<Self, TensorError>
424 where
425 T: AsRef<[f64]> + Clone,
426 E: Clone,
427 {
428 let mut batch = Self { graphs: Vec::with_capacity(graphs.len()) };
429
430 for graph in graphs {
431 let (features, adjacency) = graph.to_tensor_representation()?;
432 batch.graphs.push((features, adjacency));
433 }
434
435 Ok(batch)
436 }
437
438 pub fn batch_features(&self) -> DenseTensor {
440 if self.graphs.is_empty() {
441 return DenseTensor::zeros(vec![0, 0]);
442 }
443
444 let max_nodes = self.graphs.iter()
446 .map(|(_, adj)| adj.num_nodes)
447 .max()
448 .unwrap_or(0);
449
450 let num_features = self.graphs.iter()
451 .map(|(feat, _)| feat.shape().get(1).copied().unwrap_or(1))
452 .max()
453 .unwrap_or(1);
454
455 let mut all_features = Vec::new();
457 for (features, adjacency) in &self.graphs {
458 let feat_data = features.data();
459 all_features.extend_from_slice(feat_data);
460
461 let current_nodes = adjacency.num_nodes;
463 if current_nodes < max_nodes {
464 let padding_size = (max_nodes - current_nodes) * num_features;
465 all_features.extend(std::iter::repeat_n(0.0, padding_size));
466 }
467 }
468
469 DenseTensor::from_vec(
470 all_features,
471 vec![self.graphs.len() * max_nodes, num_features],
472 )
473 }
474
475 pub fn batch_adjacency(&self) -> GraphAdjacencyMatrix {
477 if self.graphs.is_empty() {
478 return GraphAdjacencyMatrix::from_edge_list(&[], 0, false).unwrap();
479 }
480
481 let total_nodes: usize = self.graphs.iter().map(|(_, adj)| adj.num_nodes).sum();
484 let total_edges: usize = self.graphs.iter().map(|(_, adj)| adj.num_edges).sum();
485
486 let mut all_edges = Vec::with_capacity(total_edges);
488 let mut offset = 0;
489
490 for (_, adjacency) in &self.graphs {
491 let csr = adjacency.as_sparse_tensor();
492 for src in 0..adjacency.num_nodes {
493 let start = csr.row_offsets()[src];
494 let end = csr.row_offsets()[src + 1];
495 for j in start..end {
496 let dst = csr.col_indices()[j];
497 all_edges.push((src + offset, dst + offset));
498 }
499 }
500 offset += adjacency.num_nodes;
501 }
502
503 GraphAdjacencyMatrix::from_edge_list(
504 &all_edges,
505 total_nodes,
506 self.graphs.first().map(|(_, adj)| adj.is_directed).unwrap_or(false),
507 ).unwrap()
508 }
509
510 pub fn len(&self) -> usize {
512 self.graphs.len()
513 }
514
515 pub fn is_empty(&self) -> bool {
517 self.graphs.is_empty()
518 }
519
520 pub fn get(&self, idx: usize) -> Option<&(DenseTensor, GraphAdjacencyMatrix)> {
522 self.graphs.get(idx)
523 }
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::Graph;

    /// Directed two-node graph with one edge from the first node to the second.
    fn two_node_graph(first: Vec<f64>, second: Vec<f64>) -> Graph<Vec<f64>, f64> {
        let mut graph = Graph::<Vec<f64>, f64>::directed();
        let a = graph.add_node(first).unwrap();
        let b = graph.add_node(second).unwrap();
        let _ = graph.add_edge(a, b, 1.0);
        graph
    }

    #[test]
    fn test_adjacency_matrix_creation() {
        let triangle: [(usize, usize); 3] = [(0, 1), (1, 2), (2, 0)];
        let adj = GraphAdjacencyMatrix::from_edge_list(&triangle, 3, true).unwrap();

        assert_eq!(adj.num_nodes, 3);
        assert_eq!(adj.num_edges, 3);
        assert!(adj.is_directed);
    }

    #[test]
    fn test_graph_to_tensor_conversion() {
        let mut graph = Graph::<Vec<f64>, f64>::directed();

        let mut ids = Vec::new();
        for features in [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]] {
            ids.push(graph.add_node(features.to_vec()).unwrap());
        }
        for &(from, to) in &[(0usize, 1usize), (1, 2), (2, 0)] {
            let _ = graph.add_edge(ids[from], ids[to], 1.0);
        }

        let (features, adjacency) = graph.to_tensor_representation().unwrap();

        assert_eq!(features.shape(), &[3, 2]);
        assert_eq!(adjacency.num_nodes, 3);
        assert_eq!(adjacency.num_edges, 3);
    }

    #[test]
    fn test_feature_extractor() {
        let mut graph = Graph::<String, f64>::directed();
        let a = graph.add_node(String::from("node0")).unwrap();
        let b = graph.add_node(String::from("node1")).unwrap();
        let _ = graph.add_edge(a, b, 1.0);

        let features = graph
            .feature_extractor()
            .extract_node_features_scalar(|label| label.len() as f64)
            .unwrap();

        assert_eq!(features.shape(), &[2, 1]);
    }

    #[test]
    fn test_graph_reconstruction() {
        let triangle: [(usize, usize); 3] = [(0, 1), (1, 2), (2, 0)];
        let adj = GraphAdjacencyMatrix::from_edge_list(&triangle, 3, true).unwrap();

        let graph: Graph<usize, f64> = GraphReconstructor::new(true)
            .from_adjacency(&adj, |i| i, |_src, _dst, w| w)
            .unwrap();

        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.edge_count(), 3);
    }

    #[test]
    fn test_normalized_adjacency() {
        let path: [(usize, usize); 4] = [(0, 1), (1, 0), (1, 2), (2, 1)];
        let adj = GraphAdjacencyMatrix::from_edge_list(&path, 3, true).unwrap();

        let normalized = adj.normalized_with_self_loops().unwrap();

        assert!(normalized.num_edges > adj.num_edges);
    }

    #[test]
    fn test_batch_creation() {
        let first = two_node_graph(vec![1.0, 0.0], vec![0.0, 1.0]);
        let second = two_node_graph(vec![1.0, 1.0], vec![0.0, 0.0]);

        let batch = GraphBatch::new(&[first, second]).unwrap();

        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }
}