datasynth_graph/exporters/
pytorch_geometric.rs1use std::collections::HashMap;
22use std::fs::{self, File};
23use std::io::{BufWriter, Write};
24use std::path::Path;
25
26use crate::exporters::common::{CommonExportConfig, CommonGraphMetadata};
27use crate::models::Graph;
28
29#[derive(Debug, Clone, Default)]
31pub struct PyGExportConfig {
32 pub common: CommonExportConfig,
34 pub one_hot_categoricals: bool,
36}
37
38pub type PyGMetadata = CommonGraphMetadata;
40
41pub struct PyGExporter {
43 config: PyGExportConfig,
44}
45
46impl PyGExporter {
47 pub fn new(config: PyGExportConfig) -> Self {
49 Self { config }
50 }
51
52 pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<PyGMetadata> {
54 fs::create_dir_all(output_dir)?;
55
56 let mut files = Vec::new();
57 let mut statistics = HashMap::new();
58
59 self.export_edge_index(graph, output_dir)?;
61 files.push("edge_index.npy".to_string());
62
63 if self.config.common.export_node_features {
65 let dim = self.export_node_features(graph, output_dir)?;
66 files.push("node_features.npy".to_string());
67 statistics.insert("node_feature_dim".to_string(), dim as f64);
68 }
69
70 if self.config.common.export_edge_features {
72 let dim = self.export_edge_features(graph, output_dir)?;
73 files.push("edge_features.npy".to_string());
74 statistics.insert("edge_feature_dim".to_string(), dim as f64);
75 }
76
77 if self.config.common.export_node_labels {
79 self.export_node_labels(graph, output_dir)?;
80 files.push("node_labels.npy".to_string());
81 }
82
83 if self.config.common.export_edge_labels {
85 self.export_edge_labels(graph, output_dir)?;
86 files.push("edge_labels.npy".to_string());
87 }
88
89 if self.config.common.export_masks {
91 self.export_masks(graph, output_dir)?;
92 files.push("train_mask.npy".to_string());
93 files.push("val_mask.npy".to_string());
94 files.push("test_mask.npy".to_string());
95 }
96
97 let node_types: HashMap<String, usize> = graph
99 .nodes_by_type
100 .keys()
101 .enumerate()
102 .map(|(i, t)| (t.as_str().to_string(), i))
103 .collect();
104
105 let edge_types: HashMap<String, usize> = graph
106 .edges_by_type
107 .keys()
108 .enumerate()
109 .map(|(i, t)| (t.as_str().to_string(), i))
110 .collect();
111
112 statistics.insert("density".to_string(), graph.metadata.density);
114 statistics.insert(
115 "anomalous_node_ratio".to_string(),
116 graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
117 );
118 statistics.insert(
119 "anomalous_edge_ratio".to_string(),
120 graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
121 );
122
123 let metadata = PyGMetadata {
125 name: graph.name.clone(),
126 num_nodes: graph.node_count(),
127 num_edges: graph.edge_count(),
128 node_feature_dim: graph.metadata.node_feature_dim,
129 edge_feature_dim: graph.metadata.edge_feature_dim,
130 num_node_classes: 2, num_edge_classes: 2,
132 node_types,
133 edge_types,
134 is_directed: true,
135 files,
136 statistics,
137 };
138
139 let metadata_path = output_dir.join("metadata.json");
141 let file = File::create(metadata_path)?;
142 serde_json::to_writer_pretty(file, &metadata)?;
143
144 self.write_loader_script(output_dir)?;
146
147 Ok(metadata)
148 }
149
150 fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
152 let (sources, targets) = graph.edge_index();
153
154 let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
156 node_ids.sort();
157 let id_to_idx: HashMap<_, _> = node_ids
158 .iter()
159 .enumerate()
160 .map(|(i, &id)| (id, i))
161 .collect();
162
163 let sources_remapped: Vec<i64> = sources
165 .iter()
166 .map(|id| *id_to_idx.get(id).unwrap_or(&0) as i64)
167 .collect();
168 let targets_remapped: Vec<i64> = targets
169 .iter()
170 .map(|id| *id_to_idx.get(id).unwrap_or(&0) as i64)
171 .collect();
172
173 let path = output_dir.join("edge_index.npy");
175 self.write_npy_2d_i64(&path, &[sources_remapped, targets_remapped])?;
176
177 Ok(())
178 }
179
180 fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
182 let features = graph.node_features();
183 let dim = features.first().map(|f| f.len()).unwrap_or(0);
184
185 let path = output_dir.join("node_features.npy");
186 self.write_npy_2d_f64(&path, &features)?;
187
188 Ok(dim)
189 }
190
191 fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
193 let features = graph.edge_features();
194 let dim = features.first().map(|f| f.len()).unwrap_or(0);
195
196 let path = output_dir.join("edge_features.npy");
197 self.write_npy_2d_f64(&path, &features)?;
198
199 Ok(dim)
200 }
201
202 fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
204 let labels: Vec<i64> = graph
205 .node_anomaly_mask()
206 .iter()
207 .map(|&b| if b { 1 } else { 0 })
208 .collect();
209
210 let path = output_dir.join("node_labels.npy");
211 self.write_npy_1d_i64(&path, &labels)?;
212
213 Ok(())
214 }
215
216 fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
218 let labels: Vec<i64> = graph
219 .edge_anomaly_mask()
220 .iter()
221 .map(|&b| if b { 1 } else { 0 })
222 .collect();
223
224 let path = output_dir.join("edge_labels.npy");
225 self.write_npy_1d_i64(&path, &labels)?;
226
227 Ok(())
228 }
229
230 fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
232 let n = graph.node_count();
233 let mut rng = SimpleRng::new(self.config.common.seed);
234
235 let train_size = (n as f64 * self.config.common.train_ratio) as usize;
236 let val_size = (n as f64 * self.config.common.val_ratio) as usize;
237
238 let mut indices: Vec<usize> = (0..n).collect();
240 for i in (1..n).rev() {
241 let j = (rng.next() % (i as u64 + 1)) as usize;
242 indices.swap(i, j);
243 }
244
245 let mut train_mask = vec![false; n];
247 let mut val_mask = vec![false; n];
248 let mut test_mask = vec![false; n];
249
250 for (i, &idx) in indices.iter().enumerate() {
251 if i < train_size {
252 train_mask[idx] = true;
253 } else if i < train_size + val_size {
254 val_mask[idx] = true;
255 } else {
256 test_mask[idx] = true;
257 }
258 }
259
260 self.write_npy_1d_bool(&output_dir.join("train_mask.npy"), &train_mask)?;
262 self.write_npy_1d_bool(&output_dir.join("val_mask.npy"), &val_mask)?;
263 self.write_npy_1d_bool(&output_dir.join("test_mask.npy"), &test_mask)?;
264
265 Ok(())
266 }
267
268 fn write_npy_1d_i64(&self, path: &Path, data: &[i64]) -> std::io::Result<()> {
270 let file = File::create(path)?;
271 let mut writer = BufWriter::new(file);
272
273 let shape = format!("({},)", data.len());
275 self.write_npy_header(&mut writer, "<i8", &shape)?;
276
277 for &val in data {
279 writer.write_all(&val.to_le_bytes())?;
280 }
281
282 Ok(())
283 }
284
285 fn write_npy_1d_bool(&self, path: &Path, data: &[bool]) -> std::io::Result<()> {
287 let file = File::create(path)?;
288 let mut writer = BufWriter::new(file);
289
290 let shape = format!("({},)", data.len());
292 self.write_npy_header(&mut writer, "|b1", &shape)?;
293
294 for &val in data {
296 writer.write_all(&[if val { 1u8 } else { 0u8 }])?;
297 }
298
299 Ok(())
300 }
301
302 fn write_npy_2d_i64(&self, path: &Path, data: &[Vec<i64>]) -> std::io::Result<()> {
304 let file = File::create(path)?;
305 let mut writer = BufWriter::new(file);
306
307 let rows = data.len();
308 let cols = data.first().map(|r| r.len()).unwrap_or(0);
309
310 let shape = format!("({}, {})", rows, cols);
312 self.write_npy_header(&mut writer, "<i8", &shape)?;
313
314 for row in data {
316 for &val in row {
317 writer.write_all(&val.to_le_bytes())?;
318 }
319 }
320
321 Ok(())
322 }
323
324 fn write_npy_2d_f64(&self, path: &Path, data: &[Vec<f64>]) -> std::io::Result<()> {
326 let file = File::create(path)?;
327 let mut writer = BufWriter::new(file);
328
329 let rows = data.len();
330 let cols = data.first().map(|r| r.len()).unwrap_or(0);
331
332 let shape = format!("({}, {})", rows, cols);
334 self.write_npy_header(&mut writer, "<f8", &shape)?;
335
336 for row in data {
338 for &val in row {
339 writer.write_all(&val.to_le_bytes())?;
340 }
341 for _ in row.len()..cols {
343 writer.write_all(&0.0_f64.to_le_bytes())?;
344 }
345 }
346
347 Ok(())
348 }
349
350 fn write_npy_header<W: Write>(
352 &self,
353 writer: &mut W,
354 dtype: &str,
355 shape: &str,
356 ) -> std::io::Result<()> {
357 writer.write_all(&[0x93])?; writer.write_all(b"NUMPY")?;
360 writer.write_all(&[0x01, 0x00])?; let header = format!(
364 "{{'descr': '{}', 'fortran_order': False, 'shape': {} }}",
365 dtype, shape
366 );
367
368 let header_len = header.len();
370 let total_len = 10 + header_len + 1; let padding = (64 - (total_len % 64)) % 64;
372 let padded_len = header_len + 1 + padding;
373
374 writer.write_all(&(padded_len as u16).to_le_bytes())?;
375 writer.write_all(header.as_bytes())?;
376 for _ in 0..padding {
377 writer.write_all(b" ")?;
378 }
379 writer.write_all(b"\n")?;
380
381 Ok(())
382 }
383
384 fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
386 let script = r#"#!/usr/bin/env python3
387"""
388PyTorch Geometric Data Loader
389
390Auto-generated loader for graph data exported from synth-graph.
391"""
392
393import json
394import numpy as np
395import torch
396from pathlib import Path
397
398try:
399 from torch_geometric.data import Data
400 HAS_PYG = True
401except ImportError:
402 HAS_PYG = False
403 print("Warning: torch_geometric not installed. Install with: pip install torch-geometric")
404
405
406def load_graph(data_dir: str = ".") -> "Data":
407 """Load graph data into a PyTorch Geometric Data object."""
408 data_dir = Path(data_dir)
409
410 # Load metadata
411 with open(data_dir / "metadata.json") as f:
412 metadata = json.load(f)
413
414 # Load edge index
415 edge_index = torch.from_numpy(np.load(data_dir / "edge_index.npy")).long()
416
417 # Load node features (if available)
418 x = None
419 node_features_path = data_dir / "node_features.npy"
420 if node_features_path.exists():
421 x = torch.from_numpy(np.load(node_features_path)).float()
422
423 # Load edge features (if available)
424 edge_attr = None
425 edge_features_path = data_dir / "edge_features.npy"
426 if edge_features_path.exists():
427 edge_attr = torch.from_numpy(np.load(edge_features_path)).float()
428
429 # Load node labels (if available)
430 y = None
431 node_labels_path = data_dir / "node_labels.npy"
432 if node_labels_path.exists():
433 y = torch.from_numpy(np.load(node_labels_path)).long()
434
435 # Load masks (if available)
436 train_mask = None
437 val_mask = None
438 test_mask = None
439
440 if (data_dir / "train_mask.npy").exists():
441 train_mask = torch.from_numpy(np.load(data_dir / "train_mask.npy")).bool()
442 if (data_dir / "val_mask.npy").exists():
443 val_mask = torch.from_numpy(np.load(data_dir / "val_mask.npy")).bool()
444 if (data_dir / "test_mask.npy").exists():
445 test_mask = torch.from_numpy(np.load(data_dir / "test_mask.npy")).bool()
446
447 if not HAS_PYG:
448 return {
449 "edge_index": edge_index,
450 "x": x,
451 "edge_attr": edge_attr,
452 "y": y,
453 "train_mask": train_mask,
454 "val_mask": val_mask,
455 "test_mask": test_mask,
456 "metadata": metadata,
457 }
458
459 # Create PyG Data object
460 data = Data(
461 x=x,
462 edge_index=edge_index,
463 edge_attr=edge_attr,
464 y=y,
465 train_mask=train_mask,
466 val_mask=val_mask,
467 test_mask=test_mask,
468 )
469
470 # Store metadata
471 data.metadata = metadata
472
473 return data
474
475
476def print_summary(data_dir: str = "."):
477 """Print summary of the graph data."""
478 data = load_graph(data_dir)
479
480 if isinstance(data, dict):
481 metadata = data["metadata"]
482 print(f"Graph: {metadata['name']}")
483 print(f"Nodes: {metadata['num_nodes']}")
484 print(f"Edges: {metadata['num_edges']}")
485 print(f"Node features: {data['x'].shape if data['x'] is not None else 'None'}")
486 print(f"Edge features: {data['edge_attr'].shape if data['edge_attr'] is not None else 'None'}")
487 else:
488 print(f"Graph: {data.metadata['name']}")
489 print(f"Nodes: {data.num_nodes}")
490 print(f"Edges: {data.num_edges}")
491 print(f"Node features: {data.x.shape if data.x is not None else 'None'}")
492 print(f"Edge features: {data.edge_attr.shape if data.edge_attr is not None else 'None'}")
493 if data.y is not None:
494 print(f"Anomalous nodes: {data.y.sum().item()}")
495 if data.train_mask is not None:
496 print(f"Train/Val/Test: {data.train_mask.sum()}/{data.val_mask.sum()}/{data.test_mask.sum()}")
497
498
499if __name__ == "__main__":
500 import sys
501 data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
502 print_summary(data_dir)
503"#;
504
505 let path = output_dir.join("load_graph.py");
506 let mut file = File::create(path)?;
507 file.write_all(script.as_bytes())?;
508
509 Ok(())
510 }
511}
512
/// Small deterministic xorshift-style pseudo-random generator over `u64`.
struct SimpleRng {
    /// Current generator state; `new` never lets it start at zero.
    state: u64,
}

impl SimpleRng {
    /// Seeds the generator. A seed of 0 is replaced by 1, because a zero state
    /// would stay zero through every shift-and-xor step.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 1,
            nonzero => nonzero,
        };
        SimpleRng { state }
    }

    /// Advances the state with three shift-and-xor steps (left 13, right 7,
    /// left 17) and returns the new state.
    fn next(&mut self) -> u64 {
        let step1 = self.state ^ (self.state << 13);
        let step2 = step1 ^ (step1 >> 7);
        let step3 = step2 ^ (step2 << 17);
        self.state = step3;
        step3
    }
}
534
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::create_test_graph;
    use tempfile::tempdir;

    /// Exports the helper test graph with the default configuration and checks
    /// the reported node/edge counts and that the expected files were created.
    #[test]
    fn test_pyg_export() {
        let graph = create_test_graph();
        let out_dir = tempdir().unwrap();

        let metadata = PyGExporter::new(PyGExportConfig::default())
            .export(&graph, out_dir.path())
            .unwrap();

        assert_eq!(metadata.num_nodes, 2);
        assert_eq!(metadata.num_edges, 1);

        for expected in [
            "edge_index.npy",
            "node_features.npy",
            "metadata.json",
            "load_graph.py",
        ] {
            assert!(out_dir.path().join(expected).exists());
        }
    }
}