datasynth_graph/exporters/
pytorch_geometric.rs1use std::collections::HashMap;
22use std::fs::{self, File};
23use std::io::Write;
24use std::path::Path;
25
26use crate::exporters::common::{CommonExportConfig, CommonGraphMetadata};
27use crate::exporters::npy_writer;
28use crate::models::Graph;
29
30#[derive(Debug, Clone, Default)]
32pub struct PyGExportConfig {
33 pub common: CommonExportConfig,
35 pub one_hot_categoricals: bool,
37}
38
39pub type PyGMetadata = CommonGraphMetadata;
41
42pub struct PyGExporter {
44 config: PyGExportConfig,
45}
46
47impl PyGExporter {
48 pub fn new(config: PyGExportConfig) -> Self {
50 Self { config }
51 }
52
53 pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<PyGMetadata> {
55 fs::create_dir_all(output_dir)?;
56
57 let mut files = Vec::new();
58 let mut statistics = HashMap::new();
59
60 self.export_edge_index(graph, output_dir)?;
62 files.push("edge_index.npy".to_string());
63
64 if self.config.common.export_node_features {
66 let dim = self.export_node_features(graph, output_dir)?;
67 files.push("node_features.npy".to_string());
68 statistics.insert("node_feature_dim".to_string(), dim as f64);
69 }
70
71 if self.config.common.export_edge_features {
73 let dim = self.export_edge_features(graph, output_dir)?;
74 files.push("edge_features.npy".to_string());
75 statistics.insert("edge_feature_dim".to_string(), dim as f64);
76 }
77
78 if self.config.common.export_node_labels {
80 self.export_node_labels(graph, output_dir)?;
81 files.push("node_labels.npy".to_string());
82 }
83
84 if self.config.common.export_edge_labels {
86 self.export_edge_labels(graph, output_dir)?;
87 files.push("edge_labels.npy".to_string());
88 }
89
90 if self.config.common.export_masks {
92 self.export_masks(graph, output_dir)?;
93 files.push("train_mask.npy".to_string());
94 files.push("val_mask.npy".to_string());
95 files.push("test_mask.npy".to_string());
96 }
97
98 let node_types: HashMap<String, usize> = graph
100 .nodes_by_type
101 .keys()
102 .enumerate()
103 .map(|(i, t)| (t.as_str().to_string(), i))
104 .collect();
105
106 let edge_types: HashMap<String, usize> = graph
107 .edges_by_type
108 .keys()
109 .enumerate()
110 .map(|(i, t)| (t.as_str().to_string(), i))
111 .collect();
112
113 statistics.insert("density".to_string(), graph.metadata.density);
115 statistics.insert(
116 "anomalous_node_ratio".to_string(),
117 graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
118 );
119 statistics.insert(
120 "anomalous_edge_ratio".to_string(),
121 graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
122 );
123
124 let metadata = PyGMetadata {
126 name: graph.name.clone(),
127 num_nodes: graph.node_count(),
128 num_edges: graph.edge_count(),
129 node_feature_dim: graph.metadata.node_feature_dim,
130 edge_feature_dim: graph.metadata.edge_feature_dim,
131 num_node_classes: 2, num_edge_classes: 2,
133 node_types,
134 edge_types,
135 is_directed: true,
136 files,
137 statistics,
138 };
139
140 let metadata_path = output_dir.join("metadata.json");
142 let file = File::create(metadata_path)?;
143 serde_json::to_writer_pretty(file, &metadata)?;
144
145 self.write_loader_script(output_dir)?;
147
148 Ok(metadata)
149 }
150
151 fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
153 let (sources, targets) = graph.edge_index();
154
155 let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
157 node_ids.sort();
158 let id_to_idx: HashMap<_, _> = node_ids
159 .iter()
160 .enumerate()
161 .map(|(i, &id)| (id, i))
162 .collect();
163
164 let mut sources_remapped: Vec<i64> = Vec::with_capacity(sources.len());
166 let mut targets_remapped: Vec<i64> = Vec::with_capacity(targets.len());
167 let mut skipped_edges = 0usize;
168
169 for (src, dst) in sources.iter().zip(targets.iter()) {
170 match (id_to_idx.get(src), id_to_idx.get(dst)) {
171 (Some(&s), Some(&d)) => {
172 sources_remapped.push(s as i64);
173 targets_remapped.push(d as i64);
174 }
175 _ => {
176 skipped_edges += 1;
177 }
178 }
179 }
180 if skipped_edges > 0 {
181 tracing::warn!(
182 "PyTorch Geometric export: skipped {} edges with missing node IDs",
183 skipped_edges
184 );
185 }
186
187 let path = output_dir.join("edge_index.npy");
189 npy_writer::write_npy_2d_i64(&path, &[sources_remapped, targets_remapped])?;
190
191 Ok(())
192 }
193
194 fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
196 let features = graph.node_features();
197 let dim = features.first().map(|f| f.len()).unwrap_or(0);
198
199 let path = output_dir.join("node_features.npy");
200 npy_writer::write_npy_2d_f64(&path, &features)?;
201
202 Ok(dim)
203 }
204
205 fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
207 let features = graph.edge_features();
208 let dim = features.first().map(|f| f.len()).unwrap_or(0);
209
210 let path = output_dir.join("edge_features.npy");
211 npy_writer::write_npy_2d_f64(&path, &features)?;
212
213 Ok(dim)
214 }
215
216 fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
218 let labels: Vec<i64> = graph
219 .node_anomaly_mask()
220 .iter()
221 .map(|&b| if b { 1 } else { 0 })
222 .collect();
223
224 let path = output_dir.join("node_labels.npy");
225 npy_writer::write_npy_1d_i64(&path, &labels)?;
226
227 Ok(())
228 }
229
230 fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
232 let labels: Vec<i64> = graph
233 .edge_anomaly_mask()
234 .iter()
235 .map(|&b| if b { 1 } else { 0 })
236 .collect();
237
238 let path = output_dir.join("edge_labels.npy");
239 npy_writer::write_npy_1d_i64(&path, &labels)?;
240
241 Ok(())
242 }
243
244 fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
246 npy_writer::export_masks(
247 output_dir,
248 graph.node_count(),
249 self.config.common.seed,
250 self.config.common.train_ratio,
251 self.config.common.val_ratio,
252 )
253 }
254
255 fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
257 let script = r#"#!/usr/bin/env python3
258"""
259PyTorch Geometric Data Loader
260
261Auto-generated loader for graph data exported from synth-graph.
262"""
263
264import json
265import numpy as np
266import torch
267from pathlib import Path
268
269try:
270 from torch_geometric.data import Data
271 HAS_PYG = True
272except ImportError:
273 HAS_PYG = False
274 print("Warning: torch_geometric not installed. Install with: pip install torch-geometric")
275
276
277def load_graph(data_dir: str = ".") -> "Data":
278 """Load graph data into a PyTorch Geometric Data object."""
279 data_dir = Path(data_dir)
280
281 # Load metadata
282 with open(data_dir / "metadata.json") as f:
283 metadata = json.load(f)
284
285 # Load edge index
286 edge_index = torch.from_numpy(np.load(data_dir / "edge_index.npy")).long()
287
288 # Load node features (if available)
289 x = None
290 node_features_path = data_dir / "node_features.npy"
291 if node_features_path.exists():
292 x = torch.from_numpy(np.load(node_features_path)).float()
293
294 # Load edge features (if available)
295 edge_attr = None
296 edge_features_path = data_dir / "edge_features.npy"
297 if edge_features_path.exists():
298 edge_attr = torch.from_numpy(np.load(edge_features_path)).float()
299
300 # Load node labels (if available)
301 y = None
302 node_labels_path = data_dir / "node_labels.npy"
303 if node_labels_path.exists():
304 y = torch.from_numpy(np.load(node_labels_path)).long()
305
306 # Load masks (if available)
307 train_mask = None
308 val_mask = None
309 test_mask = None
310
311 if (data_dir / "train_mask.npy").exists():
312 train_mask = torch.from_numpy(np.load(data_dir / "train_mask.npy")).bool()
313 if (data_dir / "val_mask.npy").exists():
314 val_mask = torch.from_numpy(np.load(data_dir / "val_mask.npy")).bool()
315 if (data_dir / "test_mask.npy").exists():
316 test_mask = torch.from_numpy(np.load(data_dir / "test_mask.npy")).bool()
317
318 if not HAS_PYG:
319 return {
320 "edge_index": edge_index,
321 "x": x,
322 "edge_attr": edge_attr,
323 "y": y,
324 "train_mask": train_mask,
325 "val_mask": val_mask,
326 "test_mask": test_mask,
327 "metadata": metadata,
328 }
329
330 # Create PyG Data object
331 data = Data(
332 x=x,
333 edge_index=edge_index,
334 edge_attr=edge_attr,
335 y=y,
336 train_mask=train_mask,
337 val_mask=val_mask,
338 test_mask=test_mask,
339 )
340
341 # Store metadata
342 data.metadata = metadata
343
344 return data
345
346
347def print_summary(data_dir: str = "."):
348 """Print summary of the graph data."""
349 data = load_graph(data_dir)
350
351 if isinstance(data, dict):
352 metadata = data["metadata"]
353 print(f"Graph: {metadata['name']}")
354 print(f"Nodes: {metadata['num_nodes']}")
355 print(f"Edges: {metadata['num_edges']}")
356 print(f"Node features: {data['x'].shape if data['x'] is not None else 'None'}")
357 print(f"Edge features: {data['edge_attr'].shape if data['edge_attr'] is not None else 'None'}")
358 else:
359 print(f"Graph: {data.metadata['name']}")
360 print(f"Nodes: {data.num_nodes}")
361 print(f"Edges: {data.num_edges}")
362 print(f"Node features: {data.x.shape if data.x is not None else 'None'}")
363 print(f"Edge features: {data.edge_attr.shape if data.edge_attr is not None else 'None'}")
364 if data.y is not None:
365 print(f"Anomalous nodes: {data.y.sum().item()}")
366 if data.train_mask is not None:
367 print(f"Train/Val/Test: {data.train_mask.sum()}/{data.val_mask.sum()}/{data.test_mask.sum()}")
368
369
370if __name__ == "__main__":
371 import sys
372 data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
373 print_summary(data_dir)
374"#;
375
376 let path = output_dir.join("load_graph.py");
377 let mut file = File::create(path)?;
378 file.write_all(script.as_bytes())?;
379
380 Ok(())
381 }
382}
383
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::test_helpers::create_test_graph;
    use tempfile::tempdir;

    /// Exports the shared test graph with the default configuration and checks
    /// the reported node/edge counts and the files expected on disk.
    #[test]
    fn test_pyg_export() {
        let out = tempdir().unwrap();
        let graph = create_test_graph();

        let metadata = PyGExporter::new(PyGExportConfig::default())
            .export(&graph, out.path())
            .unwrap();

        assert_eq!(metadata.num_nodes, 2);
        assert_eq!(metadata.num_edges, 1);

        let expected_files = [
            "edge_index.npy",
            "node_features.npy",
            "metadata.json",
            "load_graph.py",
        ];
        for name in expected_files.iter() {
            assert!(out.path().join(name).exists());
        }
    }
}