Skip to main content

datasynth_graph/exporters/
pytorch_geometric.rs

1//! PyTorch Geometric exporter.
2//!
3//! Exports graph data in formats compatible with PyTorch Geometric:
4//! - NumPy arrays (.npy) for easy Python loading
5//! - JSON metadata for graph information
6//!
7//! The exported data can be loaded in Python with:
8//! ```python
9//! import numpy as np
10//! import torch
11//! from torch_geometric.data import Data
12//!
13//! node_features = torch.from_numpy(np.load('node_features.npy'))
14//! edge_index = torch.from_numpy(np.load('edge_index.npy'))
15//! edge_attr = torch.from_numpy(np.load('edge_features.npy'))
16//! y = torch.from_numpy(np.load('node_labels.npy'))
17//!
18//! data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr, y=y)
19//! ```
20
21use std::collections::HashMap;
22use std::fs::{self, File};
23use std::io::{BufWriter, Write};
24use std::path::Path;
25
26use crate::exporters::common::{CommonExportConfig, CommonGraphMetadata};
27use crate::models::Graph;
28
29/// Configuration for PyTorch Geometric export.
30#[derive(Debug, Clone, Default)]
31pub struct PyGExportConfig {
32    /// Common export settings (features, labels, masks, splits, seed).
33    pub common: CommonExportConfig,
34    /// Export categorical features as one-hot.
35    pub one_hot_categoricals: bool,
36}
37
38/// Metadata about the exported PyG data.
39pub type PyGMetadata = CommonGraphMetadata;
40
41/// PyTorch Geometric exporter.
42pub struct PyGExporter {
43    config: PyGExportConfig,
44}
45
46impl PyGExporter {
47    /// Creates a new PyG exporter.
48    pub fn new(config: PyGExportConfig) -> Self {
49        Self { config }
50    }
51
52    /// Exports a graph to PyTorch Geometric format.
53    pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<PyGMetadata> {
54        fs::create_dir_all(output_dir)?;
55
56        let mut files = Vec::new();
57        let mut statistics = HashMap::new();
58
59        // Export edge index
60        self.export_edge_index(graph, output_dir)?;
61        files.push("edge_index.npy".to_string());
62
63        // Export node features
64        if self.config.common.export_node_features {
65            let dim = self.export_node_features(graph, output_dir)?;
66            files.push("node_features.npy".to_string());
67            statistics.insert("node_feature_dim".to_string(), dim as f64);
68        }
69
70        // Export edge features
71        if self.config.common.export_edge_features {
72            let dim = self.export_edge_features(graph, output_dir)?;
73            files.push("edge_features.npy".to_string());
74            statistics.insert("edge_feature_dim".to_string(), dim as f64);
75        }
76
77        // Export node labels
78        if self.config.common.export_node_labels {
79            self.export_node_labels(graph, output_dir)?;
80            files.push("node_labels.npy".to_string());
81        }
82
83        // Export edge labels
84        if self.config.common.export_edge_labels {
85            self.export_edge_labels(graph, output_dir)?;
86            files.push("edge_labels.npy".to_string());
87        }
88
89        // Export masks
90        if self.config.common.export_masks {
91            self.export_masks(graph, output_dir)?;
92            files.push("train_mask.npy".to_string());
93            files.push("val_mask.npy".to_string());
94            files.push("test_mask.npy".to_string());
95        }
96
97        // Compute node/edge type mappings
98        let node_types: HashMap<String, usize> = graph
99            .nodes_by_type
100            .keys()
101            .enumerate()
102            .map(|(i, t)| (t.as_str().to_string(), i))
103            .collect();
104
105        let edge_types: HashMap<String, usize> = graph
106            .edges_by_type
107            .keys()
108            .enumerate()
109            .map(|(i, t)| (t.as_str().to_string(), i))
110            .collect();
111
112        // Compute statistics
113        statistics.insert("density".to_string(), graph.metadata.density);
114        statistics.insert(
115            "anomalous_node_ratio".to_string(),
116            graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
117        );
118        statistics.insert(
119            "anomalous_edge_ratio".to_string(),
120            graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
121        );
122
123        // Create metadata
124        let metadata = PyGMetadata {
125            name: graph.name.clone(),
126            num_nodes: graph.node_count(),
127            num_edges: graph.edge_count(),
128            node_feature_dim: graph.metadata.node_feature_dim,
129            edge_feature_dim: graph.metadata.edge_feature_dim,
130            num_node_classes: 2, // Normal/Anomaly
131            num_edge_classes: 2,
132            node_types,
133            edge_types,
134            is_directed: true,
135            files,
136            statistics,
137        };
138
139        // Write metadata
140        let metadata_path = output_dir.join("metadata.json");
141        let file = File::create(metadata_path)?;
142        serde_json::to_writer_pretty(file, &metadata)?;
143
144        // Write Python loader script
145        self.write_loader_script(output_dir)?;
146
147        Ok(metadata)
148    }
149
150    /// Exports edge index as [2, num_edges] array.
151    fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
152        let (sources, targets) = graph.edge_index();
153
154        // Create node ID to index mapping
155        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
156        node_ids.sort();
157        let id_to_idx: HashMap<_, _> = node_ids
158            .iter()
159            .enumerate()
160            .map(|(i, &id)| (id, i))
161            .collect();
162
163        // Remap edge indices
164        let sources_remapped: Vec<i64> = sources
165            .iter()
166            .map(|id| *id_to_idx.get(id).unwrap_or(&0) as i64)
167            .collect();
168        let targets_remapped: Vec<i64> = targets
169            .iter()
170            .map(|id| *id_to_idx.get(id).unwrap_or(&0) as i64)
171            .collect();
172
173        // Write as NPY format
174        let path = output_dir.join("edge_index.npy");
175        self.write_npy_2d_i64(&path, &[sources_remapped, targets_remapped])?;
176
177        Ok(())
178    }
179
180    /// Exports node features.
181    fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
182        let features = graph.node_features();
183        let dim = features.first().map(|f| f.len()).unwrap_or(0);
184
185        let path = output_dir.join("node_features.npy");
186        self.write_npy_2d_f64(&path, &features)?;
187
188        Ok(dim)
189    }
190
191    /// Exports edge features.
192    fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
193        let features = graph.edge_features();
194        let dim = features.first().map(|f| f.len()).unwrap_or(0);
195
196        let path = output_dir.join("edge_features.npy");
197        self.write_npy_2d_f64(&path, &features)?;
198
199        Ok(dim)
200    }
201
202    /// Exports node labels (anomaly flags).
203    fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
204        let labels: Vec<i64> = graph
205            .node_anomaly_mask()
206            .iter()
207            .map(|&b| if b { 1 } else { 0 })
208            .collect();
209
210        let path = output_dir.join("node_labels.npy");
211        self.write_npy_1d_i64(&path, &labels)?;
212
213        Ok(())
214    }
215
216    /// Exports edge labels (anomaly flags).
217    fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
218        let labels: Vec<i64> = graph
219            .edge_anomaly_mask()
220            .iter()
221            .map(|&b| if b { 1 } else { 0 })
222            .collect();
223
224        let path = output_dir.join("edge_labels.npy");
225        self.write_npy_1d_i64(&path, &labels)?;
226
227        Ok(())
228    }
229
230    /// Exports train/val/test masks.
231    fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
232        let n = graph.node_count();
233        let mut rng = SimpleRng::new(self.config.common.seed);
234
235        let train_size = (n as f64 * self.config.common.train_ratio) as usize;
236        let val_size = (n as f64 * self.config.common.val_ratio) as usize;
237
238        // Create shuffled indices
239        let mut indices: Vec<usize> = (0..n).collect();
240        for i in (1..n).rev() {
241            let j = (rng.next() % (i as u64 + 1)) as usize;
242            indices.swap(i, j);
243        }
244
245        // Create masks
246        let mut train_mask = vec![false; n];
247        let mut val_mask = vec![false; n];
248        let mut test_mask = vec![false; n];
249
250        for (i, &idx) in indices.iter().enumerate() {
251            if i < train_size {
252                train_mask[idx] = true;
253            } else if i < train_size + val_size {
254                val_mask[idx] = true;
255            } else {
256                test_mask[idx] = true;
257            }
258        }
259
260        // Write masks
261        self.write_npy_1d_bool(&output_dir.join("train_mask.npy"), &train_mask)?;
262        self.write_npy_1d_bool(&output_dir.join("val_mask.npy"), &val_mask)?;
263        self.write_npy_1d_bool(&output_dir.join("test_mask.npy"), &test_mask)?;
264
265        Ok(())
266    }
267
268    /// Writes a 1D array of i64 in NPY format.
269    fn write_npy_1d_i64(&self, path: &Path, data: &[i64]) -> std::io::Result<()> {
270        let file = File::create(path)?;
271        let mut writer = BufWriter::new(file);
272
273        // NPY header
274        let shape = format!("({},)", data.len());
275        self.write_npy_header(&mut writer, "<i8", &shape)?;
276
277        // Data
278        for &val in data {
279            writer.write_all(&val.to_le_bytes())?;
280        }
281
282        Ok(())
283    }
284
285    /// Writes a 1D array of bool in NPY format.
286    fn write_npy_1d_bool(&self, path: &Path, data: &[bool]) -> std::io::Result<()> {
287        let file = File::create(path)?;
288        let mut writer = BufWriter::new(file);
289
290        // NPY header
291        let shape = format!("({},)", data.len());
292        self.write_npy_header(&mut writer, "|b1", &shape)?;
293
294        // Data
295        for &val in data {
296            writer.write_all(&[if val { 1u8 } else { 0u8 }])?;
297        }
298
299        Ok(())
300    }
301
302    /// Writes a 2D array of i64 in NPY format.
303    fn write_npy_2d_i64(&self, path: &Path, data: &[Vec<i64>]) -> std::io::Result<()> {
304        let file = File::create(path)?;
305        let mut writer = BufWriter::new(file);
306
307        let rows = data.len();
308        let cols = data.first().map(|r| r.len()).unwrap_or(0);
309
310        // NPY header
311        let shape = format!("({}, {})", rows, cols);
312        self.write_npy_header(&mut writer, "<i8", &shape)?;
313
314        // Data (row-major)
315        for row in data {
316            for &val in row {
317                writer.write_all(&val.to_le_bytes())?;
318            }
319        }
320
321        Ok(())
322    }
323
324    /// Writes a 2D array of f64 in NPY format.
325    fn write_npy_2d_f64(&self, path: &Path, data: &[Vec<f64>]) -> std::io::Result<()> {
326        let file = File::create(path)?;
327        let mut writer = BufWriter::new(file);
328
329        let rows = data.len();
330        let cols = data.first().map(|r| r.len()).unwrap_or(0);
331
332        // NPY header
333        let shape = format!("({}, {})", rows, cols);
334        self.write_npy_header(&mut writer, "<f8", &shape)?;
335
336        // Data (row-major)
337        for row in data {
338            for &val in row {
339                writer.write_all(&val.to_le_bytes())?;
340            }
341            // Pad short rows with zeros
342            for _ in row.len()..cols {
343                writer.write_all(&0.0_f64.to_le_bytes())?;
344            }
345        }
346
347        Ok(())
348    }
349
350    /// Writes NPY header.
351    fn write_npy_header<W: Write>(
352        &self,
353        writer: &mut W,
354        dtype: &str,
355        shape: &str,
356    ) -> std::io::Result<()> {
357        // Magic number and version
358        writer.write_all(&[0x93])?; // \x93
359        writer.write_all(b"NUMPY")?;
360        writer.write_all(&[0x01, 0x00])?; // Version 1.0
361
362        // Header dict
363        let header = format!(
364            "{{'descr': '{}', 'fortran_order': False, 'shape': {} }}",
365            dtype, shape
366        );
367
368        // Pad header to multiple of 64 bytes (including magic, version, header_len)
369        let header_len = header.len();
370        let total_len = 10 + header_len + 1; // magic(6) + version(2) + header_len(2) + header + newline
371        let padding = (64 - (total_len % 64)) % 64;
372        let padded_len = header_len + 1 + padding;
373
374        writer.write_all(&(padded_len as u16).to_le_bytes())?;
375        writer.write_all(header.as_bytes())?;
376        for _ in 0..padding {
377            writer.write_all(b" ")?;
378        }
379        writer.write_all(b"\n")?;
380
381        Ok(())
382    }
383
384    /// Writes a Python loader script.
385    fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
386        let script = r#"#!/usr/bin/env python3
387"""
388PyTorch Geometric Data Loader
389
390Auto-generated loader for graph data exported from synth-graph.
391"""
392
393import json
394import numpy as np
395import torch
396from pathlib import Path
397
398try:
399    from torch_geometric.data import Data
400    HAS_PYG = True
401except ImportError:
402    HAS_PYG = False
403    print("Warning: torch_geometric not installed. Install with: pip install torch-geometric")
404
405
406def load_graph(data_dir: str = ".") -> "Data":
407    """Load graph data into a PyTorch Geometric Data object."""
408    data_dir = Path(data_dir)
409
410    # Load metadata
411    with open(data_dir / "metadata.json") as f:
412        metadata = json.load(f)
413
414    # Load edge index
415    edge_index = torch.from_numpy(np.load(data_dir / "edge_index.npy")).long()
416
417    # Load node features (if available)
418    x = None
419    node_features_path = data_dir / "node_features.npy"
420    if node_features_path.exists():
421        x = torch.from_numpy(np.load(node_features_path)).float()
422
423    # Load edge features (if available)
424    edge_attr = None
425    edge_features_path = data_dir / "edge_features.npy"
426    if edge_features_path.exists():
427        edge_attr = torch.from_numpy(np.load(edge_features_path)).float()
428
429    # Load node labels (if available)
430    y = None
431    node_labels_path = data_dir / "node_labels.npy"
432    if node_labels_path.exists():
433        y = torch.from_numpy(np.load(node_labels_path)).long()
434
435    # Load masks (if available)
436    train_mask = None
437    val_mask = None
438    test_mask = None
439
440    if (data_dir / "train_mask.npy").exists():
441        train_mask = torch.from_numpy(np.load(data_dir / "train_mask.npy")).bool()
442    if (data_dir / "val_mask.npy").exists():
443        val_mask = torch.from_numpy(np.load(data_dir / "val_mask.npy")).bool()
444    if (data_dir / "test_mask.npy").exists():
445        test_mask = torch.from_numpy(np.load(data_dir / "test_mask.npy")).bool()
446
447    if not HAS_PYG:
448        return {
449            "edge_index": edge_index,
450            "x": x,
451            "edge_attr": edge_attr,
452            "y": y,
453            "train_mask": train_mask,
454            "val_mask": val_mask,
455            "test_mask": test_mask,
456            "metadata": metadata,
457        }
458
459    # Create PyG Data object
460    data = Data(
461        x=x,
462        edge_index=edge_index,
463        edge_attr=edge_attr,
464        y=y,
465        train_mask=train_mask,
466        val_mask=val_mask,
467        test_mask=test_mask,
468    )
469
470    # Store metadata
471    data.metadata = metadata
472
473    return data
474
475
476def print_summary(data_dir: str = "."):
477    """Print summary of the graph data."""
478    data = load_graph(data_dir)
479
480    if isinstance(data, dict):
481        metadata = data["metadata"]
482        print(f"Graph: {metadata['name']}")
483        print(f"Nodes: {metadata['num_nodes']}")
484        print(f"Edges: {metadata['num_edges']}")
485        print(f"Node features: {data['x'].shape if data['x'] is not None else 'None'}")
486        print(f"Edge features: {data['edge_attr'].shape if data['edge_attr'] is not None else 'None'}")
487    else:
488        print(f"Graph: {data.metadata['name']}")
489        print(f"Nodes: {data.num_nodes}")
490        print(f"Edges: {data.num_edges}")
491        print(f"Node features: {data.x.shape if data.x is not None else 'None'}")
492        print(f"Edge features: {data.edge_attr.shape if data.edge_attr is not None else 'None'}")
493        if data.y is not None:
494            print(f"Anomalous nodes: {data.y.sum().item()}")
495        if data.train_mask is not None:
496            print(f"Train/Val/Test: {data.train_mask.sum()}/{data.val_mask.sum()}/{data.test_mask.sum()}")
497
498
499if __name__ == "__main__":
500    import sys
501    data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
502    print_summary(data_dir)
503"#;
504
505        let path = output_dir.join("load_graph.py");
506        let mut file = File::create(path)?;
507        file.write_all(script.as_bytes())?;
508
509        Ok(())
510    }
511}
512
/// Minimal xorshift64 pseudo-random generator used for reproducible splits.
///
/// Deterministic for a given seed; not suitable for anything security-related.
struct SimpleRng {
    /// Current generator state. Never zero, because zero is a fixed point of
    /// xorshift and would make every output zero.
    state: u64,
}

impl SimpleRng {
    /// Seeds the generator, substituting 1 for a zero seed.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 1,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advances the state with the (13, 7, 17) xorshift triple and returns it.
    fn next(&mut self) -> u64 {
        let s = self.state;
        let s = s ^ (s << 13);
        let s = s ^ (s >> 7);
        let s = s ^ (s << 17);
        self.state = s;
        s
    }
}
534
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::create_test_graph;
    use tempfile::tempdir;

    /// Returns the data section of an NPY v1.0 byte buffer, using the
    /// little-endian `u16` header length stored at bytes 8..10.
    fn npy_payload(bytes: &[u8]) -> &[u8] {
        let header_len = u16::from_le_bytes([bytes[8], bytes[9]]) as usize;
        &bytes[10 + header_len..]
    }

    /// Decodes a buffer of little-endian 8-byte values into `i64`s.
    fn decode_i64(bytes: &[u8]) -> Vec<i64> {
        bytes
            .chunks_exact(8)
            .map(|chunk| {
                let mut raw = [0u8; 8];
                raw.copy_from_slice(chunk);
                i64::from_le_bytes(raw)
            })
            .collect()
    }

    /// Decodes a buffer of little-endian 8-byte values into `f64`s.
    fn decode_f64(bytes: &[u8]) -> Vec<f64> {
        bytes
            .chunks_exact(8)
            .map(|chunk| {
                let mut raw = [0u8; 8];
                raw.copy_from_slice(chunk);
                f64::from_le_bytes(raw)
            })
            .collect()
    }

    #[test]
    fn test_pyg_export() {
        let graph = create_test_graph();
        let dir = tempdir().unwrap();

        let exporter = PyGExporter::new(PyGExportConfig::default());
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert_eq!(metadata.num_nodes, 2);
        assert_eq!(metadata.num_edges, 1);
        assert!(dir.path().join("edge_index.npy").exists());
        assert!(dir.path().join("node_features.npy").exists());
        assert!(dir.path().join("metadata.json").exists());
        assert!(dir.path().join("load_graph.py").exists());
    }

    #[test]
    fn npy_header_is_aligned_and_self_describing() {
        let exporter = PyGExporter::new(PyGExportConfig::default());
        for &shape in ["(0,)", "(3,)", "(2, 5)", "(123456, 789)"].iter() {
            let mut buf = Vec::new();
            exporter.write_npy_header(&mut buf, "<f8", shape).unwrap();

            assert_eq!(&buf[..8], b"\x93NUMPY\x01\x00");
            assert_eq!(buf.len() % 64, 0, "header for {} is not 64-byte aligned", shape);
            assert_eq!(buf.last(), Some(&b'\n'));
            // The length field must cover exactly the rest of the header.
            assert!(npy_payload(&buf).is_empty());
        }
    }

    #[test]
    fn npy_writers_round_trip() {
        let dir = tempdir().unwrap();
        let exporter = PyGExporter::new(PyGExportConfig::default());

        let ints = dir.path().join("ints.npy");
        exporter.write_npy_1d_i64(&ints, &[1, -2, 3]).unwrap();
        let bytes = std::fs::read(&ints).unwrap();
        assert!(String::from_utf8_lossy(&bytes).contains("'shape': (3,)"));
        assert_eq!(decode_i64(npy_payload(&bytes)), vec![1, -2, 3]);

        let flags = dir.path().join("flags.npy");
        exporter.write_npy_1d_bool(&flags, &[true, false, true]).unwrap();
        let bytes = std::fs::read(&flags).unwrap();
        assert_eq!(npy_payload(&bytes), &[1u8, 0, 1][..]);

        let matrix = dir.path().join("matrix.npy");
        exporter
            .write_npy_2d_f64(&matrix, &[vec![1.0, 2.0], vec![3.0, 4.0]])
            .unwrap();
        let bytes = std::fs::read(&matrix).unwrap();
        assert!(String::from_utf8_lossy(&bytes).contains("'shape': (2, 2)"));
        assert_eq!(decode_f64(npy_payload(&bytes)), vec![1.0, 2.0, 3.0, 4.0]);
    }
}