datasynth_graph/exporters/
pytorch_geometric.rs

1//! PyTorch Geometric exporter.
2//!
3//! Exports graph data in formats compatible with PyTorch Geometric:
4//! - NumPy arrays (.npy) for easy Python loading
5//! - JSON metadata for graph information
6//!
7//! The exported data can be loaded in Python with:
8//! ```python
9//! import numpy as np
10//! import torch
11//! from torch_geometric.data import Data
12//!
13//! node_features = torch.from_numpy(np.load('node_features.npy'))
14//! edge_index = torch.from_numpy(np.load('edge_index.npy'))
//! edge_attr = torch.from_numpy(np.load('edge_features.npy'))
//! y = torch.from_numpy(np.load('node_labels.npy'))
17//!
18//! data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr, y=y)
19//! ```
20
21use std::collections::HashMap;
22use std::fs::{self, File};
23use std::io::{BufWriter, Write};
24use std::path::Path;
25
26use serde::{Deserialize, Serialize};
27
28use crate::models::Graph;
29
/// Options controlling which artefacts the PyTorch Geometric exporter writes.
#[derive(Debug, Clone)]
pub struct PyGExportConfig {
    /// Write `node_features.npy`.
    pub export_node_features: bool,
    /// Write `edge_features.npy`.
    pub export_edge_features: bool,
    /// Write `node_labels.npy` (1 = anomalous, 0 = normal).
    pub export_node_labels: bool,
    /// Write `edge_labels.npy` (1 = anomalous, 0 = normal).
    pub export_edge_labels: bool,
    /// Request one-hot encoding of categorical features.
    ///
    /// NOTE(review): this flag is not read anywhere in this file; confirm
    /// where (or whether) it is honoured.
    pub one_hot_categoricals: bool,
    /// Write `train_mask.npy`, `val_mask.npy` and `test_mask.npy`.
    pub export_masks: bool,
    /// Fraction of nodes placed in the training mask.
    pub train_ratio: f64,
    /// Fraction of nodes placed in the validation mask; nodes that fall in
    /// neither the training nor the validation share go to the test mask.
    pub val_ratio: f64,
    /// Seed for the shuffle that decides the train/val/test assignment.
    pub seed: u64,
}

impl Default for PyGExportConfig {
    /// Enables every export except one-hot encoding, with a 70 % / 15 %
    /// train/validation split and seed 42.
    fn default() -> Self {
        const TRAIN_RATIO: f64 = 0.7;
        const VAL_RATIO: f64 = 0.15;
        const SPLIT_SEED: u64 = 42;

        PyGExportConfig {
            // Artefact toggles.
            export_node_features: true,
            export_edge_features: true,
            export_node_labels: true,
            export_edge_labels: true,
            one_hot_categoricals: false,
            export_masks: true,
            // Split parameters.
            train_ratio: TRAIN_RATIO,
            val_ratio: VAL_RATIO,
            seed: SPLIT_SEED,
        }
    }
}
68
/// Summary of one export, returned by `PyGExporter::export` and also
/// serialized to `metadata.json` in the output directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyGMetadata {
    /// Graph name, copied from the exported graph.
    pub name: String,
    /// Number of nodes in the exported graph.
    pub num_nodes: usize,
    /// Number of edges in the exported graph.
    pub num_edges: usize,
    /// Node feature dimension as recorded in the graph's own metadata.
    pub node_feature_dim: usize,
    /// Edge feature dimension as recorded in the graph's own metadata.
    pub edge_feature_dim: usize,
    /// Number of node classes; the exporter always sets this to 2
    /// (normal / anomaly).
    pub num_node_classes: usize,
    /// Number of edge classes; the exporter always sets this to 2.
    pub num_edge_classes: usize,
    /// Node type name -> integer index assigned by the exporter.
    pub node_types: HashMap<String, usize>,
    /// Edge type name -> integer index assigned by the exporter.
    pub edge_types: HashMap<String, usize>,
    /// Whether the graph is directed; the exporter currently always
    /// writes `true`.
    pub is_directed: bool,
    /// Names of the `.npy` files written by the export, relative to the
    /// output directory.
    pub files: Vec<String>,
    /// Named numeric statistics, e.g. `density`, `anomalous_node_ratio`,
    /// `anomalous_edge_ratio` and the exported feature dimensions.
    pub statistics: HashMap<String, f64>,
}
97
/// Writes a `Graph` to disk as NPY arrays, a JSON metadata file and a
/// Python loader script that PyTorch Geometric users can run directly.
pub struct PyGExporter {
    /// Export options applied on every `export` call.
    config: PyGExportConfig,
}
102
103impl PyGExporter {
104    /// Creates a new PyG exporter.
105    pub fn new(config: PyGExportConfig) -> Self {
106        Self { config }
107    }
108
109    /// Exports a graph to PyTorch Geometric format.
110    pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<PyGMetadata> {
111        fs::create_dir_all(output_dir)?;
112
113        let mut files = Vec::new();
114        let mut statistics = HashMap::new();
115
116        // Export edge index
117        self.export_edge_index(graph, output_dir)?;
118        files.push("edge_index.npy".to_string());
119
120        // Export node features
121        if self.config.export_node_features {
122            let dim = self.export_node_features(graph, output_dir)?;
123            files.push("node_features.npy".to_string());
124            statistics.insert("node_feature_dim".to_string(), dim as f64);
125        }
126
127        // Export edge features
128        if self.config.export_edge_features {
129            let dim = self.export_edge_features(graph, output_dir)?;
130            files.push("edge_features.npy".to_string());
131            statistics.insert("edge_feature_dim".to_string(), dim as f64);
132        }
133
134        // Export node labels
135        if self.config.export_node_labels {
136            self.export_node_labels(graph, output_dir)?;
137            files.push("node_labels.npy".to_string());
138        }
139
140        // Export edge labels
141        if self.config.export_edge_labels {
142            self.export_edge_labels(graph, output_dir)?;
143            files.push("edge_labels.npy".to_string());
144        }
145
146        // Export masks
147        if self.config.export_masks {
148            self.export_masks(graph, output_dir)?;
149            files.push("train_mask.npy".to_string());
150            files.push("val_mask.npy".to_string());
151            files.push("test_mask.npy".to_string());
152        }
153
154        // Compute node/edge type mappings
155        let node_types: HashMap<String, usize> = graph
156            .nodes_by_type
157            .keys()
158            .enumerate()
159            .map(|(i, t)| (t.as_str().to_string(), i))
160            .collect();
161
162        let edge_types: HashMap<String, usize> = graph
163            .edges_by_type
164            .keys()
165            .enumerate()
166            .map(|(i, t)| (t.as_str().to_string(), i))
167            .collect();
168
169        // Compute statistics
170        statistics.insert("density".to_string(), graph.metadata.density);
171        statistics.insert(
172            "anomalous_node_ratio".to_string(),
173            graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
174        );
175        statistics.insert(
176            "anomalous_edge_ratio".to_string(),
177            graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
178        );
179
180        // Create metadata
181        let metadata = PyGMetadata {
182            name: graph.name.clone(),
183            num_nodes: graph.node_count(),
184            num_edges: graph.edge_count(),
185            node_feature_dim: graph.metadata.node_feature_dim,
186            edge_feature_dim: graph.metadata.edge_feature_dim,
187            num_node_classes: 2, // Normal/Anomaly
188            num_edge_classes: 2,
189            node_types,
190            edge_types,
191            is_directed: true,
192            files,
193            statistics,
194        };
195
196        // Write metadata
197        let metadata_path = output_dir.join("metadata.json");
198        let file = File::create(metadata_path)?;
199        serde_json::to_writer_pretty(file, &metadata)?;
200
201        // Write Python loader script
202        self.write_loader_script(output_dir)?;
203
204        Ok(metadata)
205    }
206
207    /// Exports edge index as [2, num_edges] array.
208    fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
209        let (sources, targets) = graph.edge_index();
210
211        // Create node ID to index mapping
212        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
213        node_ids.sort();
214        let id_to_idx: HashMap<_, _> = node_ids
215            .iter()
216            .enumerate()
217            .map(|(i, &id)| (id, i))
218            .collect();
219
220        // Remap edge indices
221        let sources_remapped: Vec<i64> = sources
222            .iter()
223            .map(|id| *id_to_idx.get(id).unwrap_or(&0) as i64)
224            .collect();
225        let targets_remapped: Vec<i64> = targets
226            .iter()
227            .map(|id| *id_to_idx.get(id).unwrap_or(&0) as i64)
228            .collect();
229
230        // Write as NPY format
231        let path = output_dir.join("edge_index.npy");
232        self.write_npy_2d_i64(&path, &[sources_remapped, targets_remapped])?;
233
234        Ok(())
235    }
236
237    /// Exports node features.
238    fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
239        let features = graph.node_features();
240        let dim = features.first().map(|f| f.len()).unwrap_or(0);
241
242        let path = output_dir.join("node_features.npy");
243        self.write_npy_2d_f64(&path, &features)?;
244
245        Ok(dim)
246    }
247
248    /// Exports edge features.
249    fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
250        let features = graph.edge_features();
251        let dim = features.first().map(|f| f.len()).unwrap_or(0);
252
253        let path = output_dir.join("edge_features.npy");
254        self.write_npy_2d_f64(&path, &features)?;
255
256        Ok(dim)
257    }
258
259    /// Exports node labels (anomaly flags).
260    fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
261        let labels: Vec<i64> = graph
262            .node_anomaly_mask()
263            .iter()
264            .map(|&b| if b { 1 } else { 0 })
265            .collect();
266
267        let path = output_dir.join("node_labels.npy");
268        self.write_npy_1d_i64(&path, &labels)?;
269
270        Ok(())
271    }
272
273    /// Exports edge labels (anomaly flags).
274    fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
275        let labels: Vec<i64> = graph
276            .edge_anomaly_mask()
277            .iter()
278            .map(|&b| if b { 1 } else { 0 })
279            .collect();
280
281        let path = output_dir.join("edge_labels.npy");
282        self.write_npy_1d_i64(&path, &labels)?;
283
284        Ok(())
285    }
286
287    /// Exports train/val/test masks.
288    fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
289        let n = graph.node_count();
290        let mut rng = SimpleRng::new(self.config.seed);
291
292        let train_size = (n as f64 * self.config.train_ratio) as usize;
293        let val_size = (n as f64 * self.config.val_ratio) as usize;
294
295        // Create shuffled indices
296        let mut indices: Vec<usize> = (0..n).collect();
297        for i in (1..n).rev() {
298            let j = (rng.next() % (i as u64 + 1)) as usize;
299            indices.swap(i, j);
300        }
301
302        // Create masks
303        let mut train_mask = vec![false; n];
304        let mut val_mask = vec![false; n];
305        let mut test_mask = vec![false; n];
306
307        for (i, &idx) in indices.iter().enumerate() {
308            if i < train_size {
309                train_mask[idx] = true;
310            } else if i < train_size + val_size {
311                val_mask[idx] = true;
312            } else {
313                test_mask[idx] = true;
314            }
315        }
316
317        // Write masks
318        self.write_npy_1d_bool(&output_dir.join("train_mask.npy"), &train_mask)?;
319        self.write_npy_1d_bool(&output_dir.join("val_mask.npy"), &val_mask)?;
320        self.write_npy_1d_bool(&output_dir.join("test_mask.npy"), &test_mask)?;
321
322        Ok(())
323    }
324
325    /// Writes a 1D array of i64 in NPY format.
326    fn write_npy_1d_i64(&self, path: &Path, data: &[i64]) -> std::io::Result<()> {
327        let file = File::create(path)?;
328        let mut writer = BufWriter::new(file);
329
330        // NPY header
331        let shape = format!("({},)", data.len());
332        self.write_npy_header(&mut writer, "<i8", &shape)?;
333
334        // Data
335        for &val in data {
336            writer.write_all(&val.to_le_bytes())?;
337        }
338
339        Ok(())
340    }
341
342    /// Writes a 1D array of bool in NPY format.
343    fn write_npy_1d_bool(&self, path: &Path, data: &[bool]) -> std::io::Result<()> {
344        let file = File::create(path)?;
345        let mut writer = BufWriter::new(file);
346
347        // NPY header
348        let shape = format!("({},)", data.len());
349        self.write_npy_header(&mut writer, "|b1", &shape)?;
350
351        // Data
352        for &val in data {
353            writer.write_all(&[if val { 1u8 } else { 0u8 }])?;
354        }
355
356        Ok(())
357    }
358
359    /// Writes a 2D array of i64 in NPY format.
360    fn write_npy_2d_i64(&self, path: &Path, data: &[Vec<i64>]) -> std::io::Result<()> {
361        let file = File::create(path)?;
362        let mut writer = BufWriter::new(file);
363
364        let rows = data.len();
365        let cols = data.first().map(|r| r.len()).unwrap_or(0);
366
367        // NPY header
368        let shape = format!("({}, {})", rows, cols);
369        self.write_npy_header(&mut writer, "<i8", &shape)?;
370
371        // Data (row-major)
372        for row in data {
373            for &val in row {
374                writer.write_all(&val.to_le_bytes())?;
375            }
376        }
377
378        Ok(())
379    }
380
381    /// Writes a 2D array of f64 in NPY format.
382    fn write_npy_2d_f64(&self, path: &Path, data: &[Vec<f64>]) -> std::io::Result<()> {
383        let file = File::create(path)?;
384        let mut writer = BufWriter::new(file);
385
386        let rows = data.len();
387        let cols = data.first().map(|r| r.len()).unwrap_or(0);
388
389        // NPY header
390        let shape = format!("({}, {})", rows, cols);
391        self.write_npy_header(&mut writer, "<f8", &shape)?;
392
393        // Data (row-major)
394        for row in data {
395            for &val in row {
396                writer.write_all(&val.to_le_bytes())?;
397            }
398            // Pad short rows with zeros
399            for _ in row.len()..cols {
400                writer.write_all(&0.0_f64.to_le_bytes())?;
401            }
402        }
403
404        Ok(())
405    }
406
407    /// Writes NPY header.
408    fn write_npy_header<W: Write>(
409        &self,
410        writer: &mut W,
411        dtype: &str,
412        shape: &str,
413    ) -> std::io::Result<()> {
414        // Magic number and version
415        writer.write_all(&[0x93])?; // \x93
416        writer.write_all(b"NUMPY")?;
417        writer.write_all(&[0x01, 0x00])?; // Version 1.0
418
419        // Header dict
420        let header = format!(
421            "{{'descr': '{}', 'fortran_order': False, 'shape': {} }}",
422            dtype, shape
423        );
424
425        // Pad header to multiple of 64 bytes (including magic, version, header_len)
426        let header_len = header.len();
427        let total_len = 10 + header_len + 1; // magic(6) + version(2) + header_len(2) + header + newline
428        let padding = (64 - (total_len % 64)) % 64;
429        let padded_len = header_len + 1 + padding;
430
431        writer.write_all(&(padded_len as u16).to_le_bytes())?;
432        writer.write_all(header.as_bytes())?;
433        for _ in 0..padding {
434            writer.write_all(b" ")?;
435        }
436        writer.write_all(b"\n")?;
437
438        Ok(())
439    }
440
441    /// Writes a Python loader script.
442    fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
443        let script = r#"#!/usr/bin/env python3
444"""
445PyTorch Geometric Data Loader
446
447Auto-generated loader for graph data exported from synth-graph.
448"""
449
450import json
451import numpy as np
452import torch
453from pathlib import Path
454
455try:
456    from torch_geometric.data import Data
457    HAS_PYG = True
458except ImportError:
459    HAS_PYG = False
460    print("Warning: torch_geometric not installed. Install with: pip install torch-geometric")
461
462
463def load_graph(data_dir: str = ".") -> "Data":
464    """Load graph data into a PyTorch Geometric Data object."""
465    data_dir = Path(data_dir)
466
467    # Load metadata
468    with open(data_dir / "metadata.json") as f:
469        metadata = json.load(f)
470
471    # Load edge index
472    edge_index = torch.from_numpy(np.load(data_dir / "edge_index.npy")).long()
473
474    # Load node features (if available)
475    x = None
476    node_features_path = data_dir / "node_features.npy"
477    if node_features_path.exists():
478        x = torch.from_numpy(np.load(node_features_path)).float()
479
480    # Load edge features (if available)
481    edge_attr = None
482    edge_features_path = data_dir / "edge_features.npy"
483    if edge_features_path.exists():
484        edge_attr = torch.from_numpy(np.load(edge_features_path)).float()
485
486    # Load node labels (if available)
487    y = None
488    node_labels_path = data_dir / "node_labels.npy"
489    if node_labels_path.exists():
490        y = torch.from_numpy(np.load(node_labels_path)).long()
491
492    # Load masks (if available)
493    train_mask = None
494    val_mask = None
495    test_mask = None
496
497    if (data_dir / "train_mask.npy").exists():
498        train_mask = torch.from_numpy(np.load(data_dir / "train_mask.npy")).bool()
499    if (data_dir / "val_mask.npy").exists():
500        val_mask = torch.from_numpy(np.load(data_dir / "val_mask.npy")).bool()
501    if (data_dir / "test_mask.npy").exists():
502        test_mask = torch.from_numpy(np.load(data_dir / "test_mask.npy")).bool()
503
504    if not HAS_PYG:
505        return {
506            "edge_index": edge_index,
507            "x": x,
508            "edge_attr": edge_attr,
509            "y": y,
510            "train_mask": train_mask,
511            "val_mask": val_mask,
512            "test_mask": test_mask,
513            "metadata": metadata,
514        }
515
516    # Create PyG Data object
517    data = Data(
518        x=x,
519        edge_index=edge_index,
520        edge_attr=edge_attr,
521        y=y,
522        train_mask=train_mask,
523        val_mask=val_mask,
524        test_mask=test_mask,
525    )
526
527    # Store metadata
528    data.metadata = metadata
529
530    return data
531
532
533def print_summary(data_dir: str = "."):
534    """Print summary of the graph data."""
535    data = load_graph(data_dir)
536
537    if isinstance(data, dict):
538        metadata = data["metadata"]
539        print(f"Graph: {metadata['name']}")
540        print(f"Nodes: {metadata['num_nodes']}")
541        print(f"Edges: {metadata['num_edges']}")
542        print(f"Node features: {data['x'].shape if data['x'] is not None else 'None'}")
543        print(f"Edge features: {data['edge_attr'].shape if data['edge_attr'] is not None else 'None'}")
544    else:
545        print(f"Graph: {data.metadata['name']}")
546        print(f"Nodes: {data.num_nodes}")
547        print(f"Edges: {data.num_edges}")
548        print(f"Node features: {data.x.shape if data.x is not None else 'None'}")
549        print(f"Edge features: {data.edge_attr.shape if data.edge_attr is not None else 'None'}")
550        if data.y is not None:
551            print(f"Anomalous nodes: {data.y.sum().item()}")
552        if data.train_mask is not None:
553            print(f"Train/Val/Test: {data.train_mask.sum()}/{data.val_mask.sum()}/{data.test_mask.sum()}")
554
555
556if __name__ == "__main__":
557    import sys
558    data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
559    print_summary(data_dir)
560"#;
561
562        let path = output_dir.join("load_graph.py");
563        let mut file = File::create(path)?;
564        file.write_all(script.as_bytes())?;
565
566        Ok(())
567    }
568}
569
/// Minimal deterministic pseudo-random generator (xorshift64).
///
/// Used only to shuffle node positions for the train/val/test split.
struct SimpleRng {
    state: u64,
}

impl SimpleRng {
    /// Creates a generator from `seed`.
    ///
    /// A zero seed is replaced by 1, because an all-zero state would stay
    /// zero after every step.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 1,
            nonzero => nonzero,
        };
        SimpleRng { state }
    }

    /// Advances the generator and returns the new 64-bit state.
    fn next(&mut self) -> u64 {
        let step_one = self.state ^ (self.state << 13);
        let step_two = step_one ^ (step_one >> 7);
        let step_three = step_two ^ (step_two << 17);
        self.state = step_three;
        step_three
    }
}
591
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{EdgeType, GraphEdge, GraphNode, GraphType, NodeType};
    use tempfile::tempdir;

    /// Builds a two-account transaction graph joined by a single edge.
    fn create_test_graph() -> Graph {
        let mut graph = Graph::new("test", GraphType::Transaction);

        let cash = GraphNode::new(0, NodeType::Account, "1000".to_string(), "Cash".to_string())
            .with_feature(0.5);
        let cash_id = graph.add_node(cash);

        let payables = GraphNode::new(0, NodeType::Account, "2000".to_string(), "AP".to_string())
            .with_feature(0.8);
        let payables_id = graph.add_node(payables);

        let payment = GraphEdge::new(0, cash_id, payables_id, EdgeType::Transaction)
            .with_weight(1000.0)
            .with_feature(6.9);
        graph.add_edge(payment);

        graph.compute_statistics();
        graph
    }

    /// A default export reports the right counts and leaves the key files on disk.
    #[test]
    fn test_pyg_export() {
        let graph = create_test_graph();
        let out_dir = tempdir().unwrap();
        let out_path = out_dir.path();

        let summary = PyGExporter::new(PyGExportConfig::default())
            .export(&graph, out_path)
            .unwrap();

        assert_eq!(summary.num_nodes, 2);
        assert_eq!(summary.num_edges, 1);
        assert!(out_path.join("edge_index.npy").exists());
        assert!(out_path.join("node_features.npy").exists());
        assert!(out_path.join("metadata.json").exists());
        assert!(out_path.join("load_graph.py").exists());
    }
}
635}