1use std::collections::HashMap;
22use std::fs::{self, File};
23use std::io::{BufWriter, Write};
24use std::path::Path;
25
26use serde::{Deserialize, Serialize};
27
28use crate::models::Graph;
29
/// Options controlling which artifacts [`PyGExporter::export`] writes and how the
/// train/validation/test node masks are generated.
///
/// `PartialEq` is derived so configurations can be compared directly (for
/// example in tests or when checking whether a cached export is still valid).
#[derive(Debug, Clone, PartialEq)]
pub struct PyGExportConfig {
    /// Write `node_features.npy`.
    pub export_node_features: bool,
    /// Write `edge_features.npy`.
    pub export_edge_features: bool,
    /// Write `node_labels.npy`.
    pub export_node_labels: bool,
    /// Write `edge_labels.npy`.
    pub export_edge_labels: bool,
    /// NOTE(review): this flag is not read by any code visible in this file;
    /// confirm whether it is consumed elsewhere or is still unimplemented.
    pub one_hot_categoricals: bool,
    /// Write `train_mask.npy`, `val_mask.npy` and `test_mask.npy`.
    pub export_masks: bool,
    /// Fraction of nodes assigned to the training mask. The node count is
    /// multiplied by this ratio and truncated to an integer.
    pub train_ratio: f64,
    /// Fraction of nodes assigned to the validation mask (same truncation as
    /// `train_ratio`). Nodes in neither split end up in the test mask.
    pub val_ratio: f64,
    /// Seed for the shuffle that decides which node lands in which split.
    pub seed: u64,
}

impl Default for PyGExportConfig {
    /// Enables every export except one-hot encoding and uses a 70% / 15% /
    /// remainder train/validation/test split with seed 42.
    fn default() -> Self {
        Self {
            export_node_features: true,
            export_edge_features: true,
            export_node_labels: true,
            export_edge_labels: true,
            one_hot_categoricals: false,
            export_masks: true,
            train_ratio: 0.7,
            val_ratio: 0.15,
            seed: 42,
        }
    }
}
68
/// Summary of an export, returned by `PyGExporter::export` and also serialised
/// to `metadata.json` in the output directory.
///
/// Field order matters: serde emits JSON keys in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyGMetadata {
    /// Name of the exported graph (copied from the graph's `name`).
    pub name: String,
    /// Value of `graph.node_count()` at export time.
    pub num_nodes: usize,
    /// Value of `graph.edge_count()` at export time.
    pub num_edges: usize,
    /// Copied from `graph.metadata.node_feature_dim`.
    pub node_feature_dim: usize,
    /// Copied from `graph.metadata.edge_feature_dim`.
    pub edge_feature_dim: usize,
    /// The exporter always writes 2 here; node labels are 0/1 values derived
    /// from the graph's node anomaly mask.
    pub num_node_classes: usize,
    /// The exporter always writes 2 here; edge labels are 0/1 values derived
    /// from the graph's edge anomaly mask.
    pub num_edge_classes: usize,
    /// Node type name mapped to an index obtained by enumerating the keys of
    /// `graph.nodes_by_type`.
    pub node_types: HashMap<String, usize>,
    /// Edge type name mapped to an index obtained by enumerating the keys of
    /// `graph.edges_by_type`.
    pub edge_types: HashMap<String, usize>,
    /// The exporter always writes `true` here.
    pub is_directed: bool,
    /// Names of the `.npy` files produced by the export. `metadata.json` and
    /// `load_graph.py` are written too but are not listed.
    pub files: Vec<String>,
    /// Named scalar statistics: `density`, `anomalous_node_ratio`,
    /// `anomalous_edge_ratio`, plus `node_feature_dim` / `edge_feature_dim`
    /// when the corresponding feature matrices were exported.
    pub statistics: HashMap<String, f64>,
}
97
/// Writes a `Graph` to disk as a set of NumPy `.npy` arrays, a `metadata.json`
/// summary and a `load_graph.py` helper script.
pub struct PyGExporter {
    /// Switches and split parameters consulted while exporting.
    config: PyGExportConfig,
}
102
103impl PyGExporter {
104 pub fn new(config: PyGExportConfig) -> Self {
106 Self { config }
107 }
108
109 pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<PyGMetadata> {
111 fs::create_dir_all(output_dir)?;
112
113 let mut files = Vec::new();
114 let mut statistics = HashMap::new();
115
116 self.export_edge_index(graph, output_dir)?;
118 files.push("edge_index.npy".to_string());
119
120 if self.config.export_node_features {
122 let dim = self.export_node_features(graph, output_dir)?;
123 files.push("node_features.npy".to_string());
124 statistics.insert("node_feature_dim".to_string(), dim as f64);
125 }
126
127 if self.config.export_edge_features {
129 let dim = self.export_edge_features(graph, output_dir)?;
130 files.push("edge_features.npy".to_string());
131 statistics.insert("edge_feature_dim".to_string(), dim as f64);
132 }
133
134 if self.config.export_node_labels {
136 self.export_node_labels(graph, output_dir)?;
137 files.push("node_labels.npy".to_string());
138 }
139
140 if self.config.export_edge_labels {
142 self.export_edge_labels(graph, output_dir)?;
143 files.push("edge_labels.npy".to_string());
144 }
145
146 if self.config.export_masks {
148 self.export_masks(graph, output_dir)?;
149 files.push("train_mask.npy".to_string());
150 files.push("val_mask.npy".to_string());
151 files.push("test_mask.npy".to_string());
152 }
153
154 let node_types: HashMap<String, usize> = graph
156 .nodes_by_type
157 .keys()
158 .enumerate()
159 .map(|(i, t)| (t.as_str().to_string(), i))
160 .collect();
161
162 let edge_types: HashMap<String, usize> = graph
163 .edges_by_type
164 .keys()
165 .enumerate()
166 .map(|(i, t)| (t.as_str().to_string(), i))
167 .collect();
168
169 statistics.insert("density".to_string(), graph.metadata.density);
171 statistics.insert(
172 "anomalous_node_ratio".to_string(),
173 graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
174 );
175 statistics.insert(
176 "anomalous_edge_ratio".to_string(),
177 graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
178 );
179
180 let metadata = PyGMetadata {
182 name: graph.name.clone(),
183 num_nodes: graph.node_count(),
184 num_edges: graph.edge_count(),
185 node_feature_dim: graph.metadata.node_feature_dim,
186 edge_feature_dim: graph.metadata.edge_feature_dim,
187 num_node_classes: 2, num_edge_classes: 2,
189 node_types,
190 edge_types,
191 is_directed: true,
192 files,
193 statistics,
194 };
195
196 let metadata_path = output_dir.join("metadata.json");
198 let file = File::create(metadata_path)?;
199 serde_json::to_writer_pretty(file, &metadata)?;
200
201 self.write_loader_script(output_dir)?;
203
204 Ok(metadata)
205 }
206
207 fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
209 let (sources, targets) = graph.edge_index();
210
211 let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
213 node_ids.sort();
214 let id_to_idx: HashMap<_, _> = node_ids
215 .iter()
216 .enumerate()
217 .map(|(i, &id)| (id, i))
218 .collect();
219
220 let sources_remapped: Vec<i64> = sources
222 .iter()
223 .map(|id| *id_to_idx.get(id).unwrap_or(&0) as i64)
224 .collect();
225 let targets_remapped: Vec<i64> = targets
226 .iter()
227 .map(|id| *id_to_idx.get(id).unwrap_or(&0) as i64)
228 .collect();
229
230 let path = output_dir.join("edge_index.npy");
232 self.write_npy_2d_i64(&path, &[sources_remapped, targets_remapped])?;
233
234 Ok(())
235 }
236
237 fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
239 let features = graph.node_features();
240 let dim = features.first().map(|f| f.len()).unwrap_or(0);
241
242 let path = output_dir.join("node_features.npy");
243 self.write_npy_2d_f64(&path, &features)?;
244
245 Ok(dim)
246 }
247
248 fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
250 let features = graph.edge_features();
251 let dim = features.first().map(|f| f.len()).unwrap_or(0);
252
253 let path = output_dir.join("edge_features.npy");
254 self.write_npy_2d_f64(&path, &features)?;
255
256 Ok(dim)
257 }
258
259 fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
261 let labels: Vec<i64> = graph
262 .node_anomaly_mask()
263 .iter()
264 .map(|&b| if b { 1 } else { 0 })
265 .collect();
266
267 let path = output_dir.join("node_labels.npy");
268 self.write_npy_1d_i64(&path, &labels)?;
269
270 Ok(())
271 }
272
273 fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
275 let labels: Vec<i64> = graph
276 .edge_anomaly_mask()
277 .iter()
278 .map(|&b| if b { 1 } else { 0 })
279 .collect();
280
281 let path = output_dir.join("edge_labels.npy");
282 self.write_npy_1d_i64(&path, &labels)?;
283
284 Ok(())
285 }
286
287 fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
289 let n = graph.node_count();
290 let mut rng = SimpleRng::new(self.config.seed);
291
292 let train_size = (n as f64 * self.config.train_ratio) as usize;
293 let val_size = (n as f64 * self.config.val_ratio) as usize;
294
295 let mut indices: Vec<usize> = (0..n).collect();
297 for i in (1..n).rev() {
298 let j = (rng.next() % (i as u64 + 1)) as usize;
299 indices.swap(i, j);
300 }
301
302 let mut train_mask = vec![false; n];
304 let mut val_mask = vec![false; n];
305 let mut test_mask = vec![false; n];
306
307 for (i, &idx) in indices.iter().enumerate() {
308 if i < train_size {
309 train_mask[idx] = true;
310 } else if i < train_size + val_size {
311 val_mask[idx] = true;
312 } else {
313 test_mask[idx] = true;
314 }
315 }
316
317 self.write_npy_1d_bool(&output_dir.join("train_mask.npy"), &train_mask)?;
319 self.write_npy_1d_bool(&output_dir.join("val_mask.npy"), &val_mask)?;
320 self.write_npy_1d_bool(&output_dir.join("test_mask.npy"), &test_mask)?;
321
322 Ok(())
323 }
324
325 fn write_npy_1d_i64(&self, path: &Path, data: &[i64]) -> std::io::Result<()> {
327 let file = File::create(path)?;
328 let mut writer = BufWriter::new(file);
329
330 let shape = format!("({},)", data.len());
332 self.write_npy_header(&mut writer, "<i8", &shape)?;
333
334 for &val in data {
336 writer.write_all(&val.to_le_bytes())?;
337 }
338
339 Ok(())
340 }
341
342 fn write_npy_1d_bool(&self, path: &Path, data: &[bool]) -> std::io::Result<()> {
344 let file = File::create(path)?;
345 let mut writer = BufWriter::new(file);
346
347 let shape = format!("({},)", data.len());
349 self.write_npy_header(&mut writer, "|b1", &shape)?;
350
351 for &val in data {
353 writer.write_all(&[if val { 1u8 } else { 0u8 }])?;
354 }
355
356 Ok(())
357 }
358
359 fn write_npy_2d_i64(&self, path: &Path, data: &[Vec<i64>]) -> std::io::Result<()> {
361 let file = File::create(path)?;
362 let mut writer = BufWriter::new(file);
363
364 let rows = data.len();
365 let cols = data.first().map(|r| r.len()).unwrap_or(0);
366
367 let shape = format!("({}, {})", rows, cols);
369 self.write_npy_header(&mut writer, "<i8", &shape)?;
370
371 for row in data {
373 for &val in row {
374 writer.write_all(&val.to_le_bytes())?;
375 }
376 }
377
378 Ok(())
379 }
380
381 fn write_npy_2d_f64(&self, path: &Path, data: &[Vec<f64>]) -> std::io::Result<()> {
383 let file = File::create(path)?;
384 let mut writer = BufWriter::new(file);
385
386 let rows = data.len();
387 let cols = data.first().map(|r| r.len()).unwrap_or(0);
388
389 let shape = format!("({}, {})", rows, cols);
391 self.write_npy_header(&mut writer, "<f8", &shape)?;
392
393 for row in data {
395 for &val in row {
396 writer.write_all(&val.to_le_bytes())?;
397 }
398 for _ in row.len()..cols {
400 writer.write_all(&0.0_f64.to_le_bytes())?;
401 }
402 }
403
404 Ok(())
405 }
406
407 fn write_npy_header<W: Write>(
409 &self,
410 writer: &mut W,
411 dtype: &str,
412 shape: &str,
413 ) -> std::io::Result<()> {
414 writer.write_all(&[0x93])?; writer.write_all(b"NUMPY")?;
417 writer.write_all(&[0x01, 0x00])?; let header = format!(
421 "{{'descr': '{}', 'fortran_order': False, 'shape': {} }}",
422 dtype, shape
423 );
424
425 let header_len = header.len();
427 let total_len = 10 + header_len + 1; let padding = (64 - (total_len % 64)) % 64;
429 let padded_len = header_len + 1 + padding;
430
431 writer.write_all(&(padded_len as u16).to_le_bytes())?;
432 writer.write_all(header.as_bytes())?;
433 for _ in 0..padding {
434 writer.write_all(b" ")?;
435 }
436 writer.write_all(b"\n")?;
437
438 Ok(())
439 }
440
441 fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
443 let script = r#"#!/usr/bin/env python3
444"""
445PyTorch Geometric Data Loader
446
447Auto-generated loader for graph data exported from synth-graph.
448"""
449
450import json
451import numpy as np
452import torch
453from pathlib import Path
454
455try:
456 from torch_geometric.data import Data
457 HAS_PYG = True
458except ImportError:
459 HAS_PYG = False
460 print("Warning: torch_geometric not installed. Install with: pip install torch-geometric")
461
462
463def load_graph(data_dir: str = ".") -> "Data":
464 """Load graph data into a PyTorch Geometric Data object."""
465 data_dir = Path(data_dir)
466
467 # Load metadata
468 with open(data_dir / "metadata.json") as f:
469 metadata = json.load(f)
470
471 # Load edge index
472 edge_index = torch.from_numpy(np.load(data_dir / "edge_index.npy")).long()
473
474 # Load node features (if available)
475 x = None
476 node_features_path = data_dir / "node_features.npy"
477 if node_features_path.exists():
478 x = torch.from_numpy(np.load(node_features_path)).float()
479
480 # Load edge features (if available)
481 edge_attr = None
482 edge_features_path = data_dir / "edge_features.npy"
483 if edge_features_path.exists():
484 edge_attr = torch.from_numpy(np.load(edge_features_path)).float()
485
486 # Load node labels (if available)
487 y = None
488 node_labels_path = data_dir / "node_labels.npy"
489 if node_labels_path.exists():
490 y = torch.from_numpy(np.load(node_labels_path)).long()
491
492 # Load masks (if available)
493 train_mask = None
494 val_mask = None
495 test_mask = None
496
497 if (data_dir / "train_mask.npy").exists():
498 train_mask = torch.from_numpy(np.load(data_dir / "train_mask.npy")).bool()
499 if (data_dir / "val_mask.npy").exists():
500 val_mask = torch.from_numpy(np.load(data_dir / "val_mask.npy")).bool()
501 if (data_dir / "test_mask.npy").exists():
502 test_mask = torch.from_numpy(np.load(data_dir / "test_mask.npy")).bool()
503
504 if not HAS_PYG:
505 return {
506 "edge_index": edge_index,
507 "x": x,
508 "edge_attr": edge_attr,
509 "y": y,
510 "train_mask": train_mask,
511 "val_mask": val_mask,
512 "test_mask": test_mask,
513 "metadata": metadata,
514 }
515
516 # Create PyG Data object
517 data = Data(
518 x=x,
519 edge_index=edge_index,
520 edge_attr=edge_attr,
521 y=y,
522 train_mask=train_mask,
523 val_mask=val_mask,
524 test_mask=test_mask,
525 )
526
527 # Store metadata
528 data.metadata = metadata
529
530 return data
531
532
533def print_summary(data_dir: str = "."):
534 """Print summary of the graph data."""
535 data = load_graph(data_dir)
536
537 if isinstance(data, dict):
538 metadata = data["metadata"]
539 print(f"Graph: {metadata['name']}")
540 print(f"Nodes: {metadata['num_nodes']}")
541 print(f"Edges: {metadata['num_edges']}")
542 print(f"Node features: {data['x'].shape if data['x'] is not None else 'None'}")
543 print(f"Edge features: {data['edge_attr'].shape if data['edge_attr'] is not None else 'None'}")
544 else:
545 print(f"Graph: {data.metadata['name']}")
546 print(f"Nodes: {data.num_nodes}")
547 print(f"Edges: {data.num_edges}")
548 print(f"Node features: {data.x.shape if data.x is not None else 'None'}")
549 print(f"Edge features: {data.edge_attr.shape if data.edge_attr is not None else 'None'}")
550 if data.y is not None:
551 print(f"Anomalous nodes: {data.y.sum().item()}")
552 if data.train_mask is not None:
553 print(f"Train/Val/Test: {data.train_mask.sum()}/{data.val_mask.sum()}/{data.test_mask.sum()}")
554
555
556if __name__ == "__main__":
557 import sys
558 data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
559 print_summary(data_dir)
560"#;
561
562 let path = output_dir.join("load_graph.py");
563 let mut file = File::create(path)?;
564 file.write_all(script.as_bytes())?;
565
566 Ok(())
567 }
568}
569
/// Small deterministic pseudo-random generator (64-bit xorshift with shifts
/// 13, 7 and 17). The same seed always produces the same sequence; it is not
/// suitable for anything security-sensitive.
struct SimpleRng {
    /// Current generator state; never zero after construction.
    state: u64,
}

impl SimpleRng {
    /// Creates a generator from `seed`. A zero seed is replaced by 1, because
    /// an all-zero state would produce zero forever.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 1,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advances the generator one step and returns the new state.
    fn next(&mut self) -> u64 {
        let after_left = self.state ^ (self.state << 13);
        let after_right = after_left ^ (after_left >> 7);
        let mixed = after_right ^ (after_right << 17);
        self.state = mixed;
        mixed
    }
}
591
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{EdgeType, GraphEdge, GraphNode, GraphType, NodeType};
    use tempfile::tempdir;

    /// Builds a two-account transaction graph joined by a single edge, with
    /// one feature on each node and on the edge.
    fn create_test_graph() -> Graph {
        let mut graph = Graph::new("test", GraphType::Transaction);

        let cash_node =
            GraphNode::new(0, NodeType::Account, "1000".to_string(), "Cash".to_string())
                .with_feature(0.5);
        let cash = graph.add_node(cash_node);

        let payable_node =
            GraphNode::new(0, NodeType::Account, "2000".to_string(), "AP".to_string())
                .with_feature(0.8);
        let payable = graph.add_node(payable_node);

        let transfer = GraphEdge::new(0, cash, payable, EdgeType::Transaction)
            .with_weight(1000.0)
            .with_feature(6.9);
        graph.add_edge(transfer);

        graph.compute_statistics();
        graph
    }

    /// A default export reports the right counts and produces the core files.
    #[test]
    fn test_pyg_export() {
        let graph = create_test_graph();
        let dir = tempdir().unwrap();

        let metadata = PyGExporter::new(PyGExportConfig::default())
            .export(&graph, dir.path())
            .unwrap();

        assert_eq!(metadata.num_nodes, 2);
        assert_eq!(metadata.num_edges, 1);

        let expected_files = [
            "edge_index.npy",
            "node_features.npy",
            "metadata.json",
            "load_graph.py",
        ];
        for &file_name in expected_files.iter() {
            assert!(dir.path().join(file_name).exists());
        }
    }
}