datasynth_graph/exporters/
pytorch_geometric.rs1use std::collections::HashMap;
22use std::fs::{self, File};
23use std::io::Write;
24use std::path::Path;
25
26use crate::exporters::common::{CommonExportConfig, CommonGraphMetadata};
27use crate::exporters::npy_writer;
28use crate::models::Graph;
29
30#[derive(Debug, Clone, Default)]
32pub struct PyGExportConfig {
33 pub common: CommonExportConfig,
35 pub one_hot_categoricals: bool,
37}
38
39pub type PyGMetadata = CommonGraphMetadata;
41
42pub struct PyGExporter {
44 config: PyGExportConfig,
45}
46
47impl PyGExporter {
48 pub fn new(config: PyGExportConfig) -> Self {
50 Self { config }
51 }
52
53 pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<PyGMetadata> {
55 fs::create_dir_all(output_dir)?;
56
57 let mut files = Vec::new();
58 let mut statistics = HashMap::new();
59
60 self.export_edge_index(graph, output_dir)?;
62 files.push("edge_index.npy".to_string());
63
64 if self.config.common.export_node_features {
66 let dim = self.export_node_features(graph, output_dir)?;
67 files.push("node_features.npy".to_string());
68 statistics.insert("node_feature_dim".to_string(), dim as f64);
69 }
70
71 if self.config.common.export_edge_features {
73 let dim = self.export_edge_features(graph, output_dir)?;
74 files.push("edge_features.npy".to_string());
75 statistics.insert("edge_feature_dim".to_string(), dim as f64);
76 }
77
78 if self.config.common.export_node_labels {
80 self.export_node_labels(graph, output_dir)?;
81 files.push("node_labels.npy".to_string());
82 }
83
84 if self.config.common.export_edge_labels {
86 self.export_edge_labels(graph, output_dir)?;
87 files.push("edge_labels.npy".to_string());
88 }
89
90 if self.config.common.export_masks {
92 self.export_masks(graph, output_dir)?;
93 files.push("train_mask.npy".to_string());
94 files.push("val_mask.npy".to_string());
95 files.push("test_mask.npy".to_string());
96 }
97
98 let node_types: HashMap<String, usize> = graph
100 .nodes_by_type
101 .keys()
102 .enumerate()
103 .map(|(i, t)| (t.as_str().to_string(), i))
104 .collect();
105
106 let edge_types: HashMap<String, usize> = graph
107 .edges_by_type
108 .keys()
109 .enumerate()
110 .map(|(i, t)| (t.as_str().to_string(), i))
111 .collect();
112
113 statistics.insert("density".to_string(), graph.metadata.density);
115 statistics.insert(
116 "anomalous_node_ratio".to_string(),
117 graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
118 );
119 statistics.insert(
120 "anomalous_edge_ratio".to_string(),
121 graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
122 );
123
124 let metadata = PyGMetadata {
126 name: graph.name.clone(),
127 num_nodes: graph.node_count(),
128 num_edges: graph.edge_count(),
129 node_feature_dim: graph.metadata.node_feature_dim,
130 edge_feature_dim: graph.metadata.edge_feature_dim,
131 num_node_classes: 2, num_edge_classes: 2,
133 node_types,
134 edge_types,
135 is_directed: true,
136 files,
137 statistics,
138 };
139
140 let metadata_path = output_dir.join("metadata.json");
142 let file = File::create(metadata_path)?;
143 serde_json::to_writer_pretty(file, &metadata)?;
144
145 self.write_loader_script(output_dir)?;
147
148 Ok(metadata)
149 }
150
151 fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
153 let (sources, targets) = graph.edge_index();
154
155 let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
157 node_ids.sort();
158 let id_to_idx: HashMap<_, _> = node_ids
159 .iter()
160 .enumerate()
161 .map(|(i, &id)| (id, i))
162 .collect();
163
164 let sources_remapped: Vec<i64> = sources
166 .iter()
167 .map(|id| *id_to_idx.get(id).unwrap_or(&0) as i64)
168 .collect();
169 let targets_remapped: Vec<i64> = targets
170 .iter()
171 .map(|id| *id_to_idx.get(id).unwrap_or(&0) as i64)
172 .collect();
173
174 let path = output_dir.join("edge_index.npy");
176 npy_writer::write_npy_2d_i64(&path, &[sources_remapped, targets_remapped])?;
177
178 Ok(())
179 }
180
181 fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
183 let features = graph.node_features();
184 let dim = features.first().map(|f| f.len()).unwrap_or(0);
185
186 let path = output_dir.join("node_features.npy");
187 npy_writer::write_npy_2d_f64(&path, &features)?;
188
189 Ok(dim)
190 }
191
192 fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
194 let features = graph.edge_features();
195 let dim = features.first().map(|f| f.len()).unwrap_or(0);
196
197 let path = output_dir.join("edge_features.npy");
198 npy_writer::write_npy_2d_f64(&path, &features)?;
199
200 Ok(dim)
201 }
202
203 fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
205 let labels: Vec<i64> = graph
206 .node_anomaly_mask()
207 .iter()
208 .map(|&b| if b { 1 } else { 0 })
209 .collect();
210
211 let path = output_dir.join("node_labels.npy");
212 npy_writer::write_npy_1d_i64(&path, &labels)?;
213
214 Ok(())
215 }
216
217 fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
219 let labels: Vec<i64> = graph
220 .edge_anomaly_mask()
221 .iter()
222 .map(|&b| if b { 1 } else { 0 })
223 .collect();
224
225 let path = output_dir.join("edge_labels.npy");
226 npy_writer::write_npy_1d_i64(&path, &labels)?;
227
228 Ok(())
229 }
230
231 fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
233 npy_writer::export_masks(
234 output_dir,
235 graph.node_count(),
236 self.config.common.seed,
237 self.config.common.train_ratio,
238 self.config.common.val_ratio,
239 )
240 }
241
242 fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
244 let script = r#"#!/usr/bin/env python3
245"""
246PyTorch Geometric Data Loader
247
248Auto-generated loader for graph data exported from synth-graph.
249"""
250
251import json
252import numpy as np
253import torch
254from pathlib import Path
255
256try:
257 from torch_geometric.data import Data
258 HAS_PYG = True
259except ImportError:
260 HAS_PYG = False
261 print("Warning: torch_geometric not installed. Install with: pip install torch-geometric")
262
263
264def load_graph(data_dir: str = ".") -> "Data":
265 """Load graph data into a PyTorch Geometric Data object."""
266 data_dir = Path(data_dir)
267
268 # Load metadata
269 with open(data_dir / "metadata.json") as f:
270 metadata = json.load(f)
271
272 # Load edge index
273 edge_index = torch.from_numpy(np.load(data_dir / "edge_index.npy")).long()
274
275 # Load node features (if available)
276 x = None
277 node_features_path = data_dir / "node_features.npy"
278 if node_features_path.exists():
279 x = torch.from_numpy(np.load(node_features_path)).float()
280
281 # Load edge features (if available)
282 edge_attr = None
283 edge_features_path = data_dir / "edge_features.npy"
284 if edge_features_path.exists():
285 edge_attr = torch.from_numpy(np.load(edge_features_path)).float()
286
287 # Load node labels (if available)
288 y = None
289 node_labels_path = data_dir / "node_labels.npy"
290 if node_labels_path.exists():
291 y = torch.from_numpy(np.load(node_labels_path)).long()
292
293 # Load masks (if available)
294 train_mask = None
295 val_mask = None
296 test_mask = None
297
298 if (data_dir / "train_mask.npy").exists():
299 train_mask = torch.from_numpy(np.load(data_dir / "train_mask.npy")).bool()
300 if (data_dir / "val_mask.npy").exists():
301 val_mask = torch.from_numpy(np.load(data_dir / "val_mask.npy")).bool()
302 if (data_dir / "test_mask.npy").exists():
303 test_mask = torch.from_numpy(np.load(data_dir / "test_mask.npy")).bool()
304
305 if not HAS_PYG:
306 return {
307 "edge_index": edge_index,
308 "x": x,
309 "edge_attr": edge_attr,
310 "y": y,
311 "train_mask": train_mask,
312 "val_mask": val_mask,
313 "test_mask": test_mask,
314 "metadata": metadata,
315 }
316
317 # Create PyG Data object
318 data = Data(
319 x=x,
320 edge_index=edge_index,
321 edge_attr=edge_attr,
322 y=y,
323 train_mask=train_mask,
324 val_mask=val_mask,
325 test_mask=test_mask,
326 )
327
328 # Store metadata
329 data.metadata = metadata
330
331 return data
332
333
334def print_summary(data_dir: str = "."):
335 """Print summary of the graph data."""
336 data = load_graph(data_dir)
337
338 if isinstance(data, dict):
339 metadata = data["metadata"]
340 print(f"Graph: {metadata['name']}")
341 print(f"Nodes: {metadata['num_nodes']}")
342 print(f"Edges: {metadata['num_edges']}")
343 print(f"Node features: {data['x'].shape if data['x'] is not None else 'None'}")
344 print(f"Edge features: {data['edge_attr'].shape if data['edge_attr'] is not None else 'None'}")
345 else:
346 print(f"Graph: {data.metadata['name']}")
347 print(f"Nodes: {data.num_nodes}")
348 print(f"Edges: {data.num_edges}")
349 print(f"Node features: {data.x.shape if data.x is not None else 'None'}")
350 print(f"Edge features: {data.edge_attr.shape if data.edge_attr is not None else 'None'}")
351 if data.y is not None:
352 print(f"Anomalous nodes: {data.y.sum().item()}")
353 if data.train_mask is not None:
354 print(f"Train/Val/Test: {data.train_mask.sum()}/{data.val_mask.sum()}/{data.test_mask.sum()}")
355
356
357if __name__ == "__main__":
358 import sys
359 data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
360 print_summary(data_dir)
361"#;
362
363 let path = output_dir.join("load_graph.py");
364 let mut file = File::create(path)?;
365 file.write_all(script.as_bytes())?;
366
367 Ok(())
368 }
369}
370
#[cfg(test)]
// Unit tests: a panic on unexpected Err is the desired failure mode here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::test_helpers::create_test_graph;
    use tempfile::tempdir;

    #[test]
    fn test_pyg_export() {
        let graph = create_test_graph();
        let out_dir = tempdir().unwrap();

        let exported = PyGExporter::new(PyGExportConfig::default())
            .export(&graph, out_dir.path())
            .unwrap();

        assert_eq!(exported.num_nodes, 2);
        assert_eq!(exported.num_edges, 1);

        let expected_files = [
            "edge_index.npy",
            "node_features.npy",
            "metadata.json",
            "load_graph.py",
        ];
        for name in expected_files {
            assert!(out_dir.path().join(name).exists(), "missing {name}");
        }
    }
}