Skip to main content

datasynth_graph/exporters/
pytorch_geometric.rs

1//! PyTorch Geometric exporter.
2//!
3//! Exports graph data in formats compatible with PyTorch Geometric:
4//! - NumPy arrays (.npy) for easy Python loading
5//! - JSON metadata for graph information
6//!
7//! The exported data can be loaded in Python with:
8//! ```python
9//! import numpy as np
10//! import torch
11//! from torch_geometric.data import Data
12//!
13//! node_features = torch.from_numpy(np.load('node_features.npy'))
14//! edge_index = torch.from_numpy(np.load('edge_index.npy'))
//! edge_attr = torch.from_numpy(np.load('edge_features.npy'))
//! y = torch.from_numpy(np.load('node_labels.npy'))
17//!
18//! data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr, y=y)
19//! ```
20
21use std::collections::HashMap;
22use std::fs::{self, File};
23use std::io::Write;
24use std::path::Path;
25
26use crate::exporters::common::{CommonExportConfig, CommonGraphMetadata};
27use crate::exporters::npy_writer;
28use crate::models::Graph;
29
30/// Configuration for PyTorch Geometric export.
31#[derive(Debug, Clone, Default)]
32pub struct PyGExportConfig {
33    /// Common export settings (features, labels, masks, splits, seed).
34    pub common: CommonExportConfig,
35    /// Export categorical features as one-hot.
36    pub one_hot_categoricals: bool,
37}
38
39/// Metadata about the exported PyG data.
40pub type PyGMetadata = CommonGraphMetadata;
41
42/// PyTorch Geometric exporter.
43pub struct PyGExporter {
44    config: PyGExportConfig,
45}
46
47impl PyGExporter {
48    /// Creates a new PyG exporter.
49    pub fn new(config: PyGExportConfig) -> Self {
50        Self { config }
51    }
52
53    /// Exports a graph to PyTorch Geometric format.
54    pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<PyGMetadata> {
55        fs::create_dir_all(output_dir)?;
56
57        let mut files = Vec::new();
58        let mut statistics = HashMap::new();
59
60        // Export edge index
61        self.export_edge_index(graph, output_dir)?;
62        files.push("edge_index.npy".to_string());
63
64        // Export node features
65        if self.config.common.export_node_features {
66            let dim = self.export_node_features(graph, output_dir)?;
67            files.push("node_features.npy".to_string());
68            statistics.insert("node_feature_dim".to_string(), dim as f64);
69        }
70
71        // Export edge features
72        if self.config.common.export_edge_features {
73            let dim = self.export_edge_features(graph, output_dir)?;
74            files.push("edge_features.npy".to_string());
75            statistics.insert("edge_feature_dim".to_string(), dim as f64);
76        }
77
78        // Export node labels
79        if self.config.common.export_node_labels {
80            self.export_node_labels(graph, output_dir)?;
81            files.push("node_labels.npy".to_string());
82        }
83
84        // Export edge labels
85        if self.config.common.export_edge_labels {
86            self.export_edge_labels(graph, output_dir)?;
87            files.push("edge_labels.npy".to_string());
88        }
89
90        // Export masks
91        if self.config.common.export_masks {
92            self.export_masks(graph, output_dir)?;
93            files.push("train_mask.npy".to_string());
94            files.push("val_mask.npy".to_string());
95            files.push("test_mask.npy".to_string());
96        }
97
98        // Compute node/edge type mappings
99        let node_types: HashMap<String, usize> = graph
100            .nodes_by_type
101            .keys()
102            .enumerate()
103            .map(|(i, t)| (t.as_str().to_string(), i))
104            .collect();
105
106        let edge_types: HashMap<String, usize> = graph
107            .edges_by_type
108            .keys()
109            .enumerate()
110            .map(|(i, t)| (t.as_str().to_string(), i))
111            .collect();
112
113        // Compute statistics
114        statistics.insert("density".to_string(), graph.metadata.density);
115        statistics.insert(
116            "anomalous_node_ratio".to_string(),
117            graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
118        );
119        statistics.insert(
120            "anomalous_edge_ratio".to_string(),
121            graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
122        );
123
124        // Create metadata
125        let metadata = PyGMetadata {
126            name: graph.name.clone(),
127            num_nodes: graph.node_count(),
128            num_edges: graph.edge_count(),
129            node_feature_dim: graph.metadata.node_feature_dim,
130            edge_feature_dim: graph.metadata.edge_feature_dim,
131            num_node_classes: 2, // Normal/Anomaly
132            num_edge_classes: 2,
133            node_types,
134            edge_types,
135            is_directed: true,
136            files,
137            statistics,
138        };
139
140        // Write metadata
141        let metadata_path = output_dir.join("metadata.json");
142        let file = File::create(metadata_path)?;
143        serde_json::to_writer_pretty(file, &metadata)?;
144
145        // Write Python loader script
146        self.write_loader_script(output_dir)?;
147
148        Ok(metadata)
149    }
150
151    /// Exports edge index as [2, num_edges] array.
152    fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
153        let (sources, targets) = graph.edge_index();
154
155        // Create node ID to index mapping
156        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
157        node_ids.sort();
158        let id_to_idx: HashMap<_, _> = node_ids
159            .iter()
160            .enumerate()
161            .map(|(i, &id)| (id, i))
162            .collect();
163
164        // Remap edge indices, skipping edges with missing node IDs
165        let mut sources_remapped: Vec<i64> = Vec::with_capacity(sources.len());
166        let mut targets_remapped: Vec<i64> = Vec::with_capacity(targets.len());
167        let mut skipped_edges = 0usize;
168
169        for (src, dst) in sources.iter().zip(targets.iter()) {
170            match (id_to_idx.get(src), id_to_idx.get(dst)) {
171                (Some(&s), Some(&d)) => {
172                    sources_remapped.push(s as i64);
173                    targets_remapped.push(d as i64);
174                }
175                _ => {
176                    skipped_edges += 1;
177                }
178            }
179        }
180        if skipped_edges > 0 {
181            tracing::warn!(
182                "PyTorch Geometric export: skipped {} edges with missing node IDs",
183                skipped_edges
184            );
185        }
186
187        // Write as NPY format
188        let path = output_dir.join("edge_index.npy");
189        npy_writer::write_npy_2d_i64(&path, &[sources_remapped, targets_remapped])?;
190
191        Ok(())
192    }
193
194    /// Exports node features.
195    fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
196        let features = graph.node_features();
197        let dim = features.first().map(|f| f.len()).unwrap_or(0);
198
199        let path = output_dir.join("node_features.npy");
200        npy_writer::write_npy_2d_f64(&path, &features)?;
201
202        Ok(dim)
203    }
204
205    /// Exports edge features.
206    fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
207        let features = graph.edge_features();
208        let dim = features.first().map(|f| f.len()).unwrap_or(0);
209
210        let path = output_dir.join("edge_features.npy");
211        npy_writer::write_npy_2d_f64(&path, &features)?;
212
213        Ok(dim)
214    }
215
216    /// Exports node labels (anomaly flags).
217    fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
218        let labels: Vec<i64> = graph
219            .node_anomaly_mask()
220            .iter()
221            .map(|&b| if b { 1 } else { 0 })
222            .collect();
223
224        let path = output_dir.join("node_labels.npy");
225        npy_writer::write_npy_1d_i64(&path, &labels)?;
226
227        Ok(())
228    }
229
230    /// Exports edge labels (anomaly flags).
231    fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
232        let labels: Vec<i64> = graph
233            .edge_anomaly_mask()
234            .iter()
235            .map(|&b| if b { 1 } else { 0 })
236            .collect();
237
238        let path = output_dir.join("edge_labels.npy");
239        npy_writer::write_npy_1d_i64(&path, &labels)?;
240
241        Ok(())
242    }
243
244    /// Exports train/val/test masks.
245    fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
246        npy_writer::export_masks(
247            output_dir,
248            graph.node_count(),
249            self.config.common.seed,
250            self.config.common.train_ratio,
251            self.config.common.val_ratio,
252        )
253    }
254
255    /// Writes a Python loader script.
256    fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
257        let script = r#"#!/usr/bin/env python3
258"""
259PyTorch Geometric Data Loader
260
261Auto-generated loader for graph data exported from synth-graph.
262"""
263
264import json
265import numpy as np
266import torch
267from pathlib import Path
268
269try:
270    from torch_geometric.data import Data
271    HAS_PYG = True
272except ImportError:
273    HAS_PYG = False
274    print("Warning: torch_geometric not installed. Install with: pip install torch-geometric")
275
276
277def load_graph(data_dir: str = ".") -> "Data":
278    """Load graph data into a PyTorch Geometric Data object."""
279    data_dir = Path(data_dir)
280
281    # Load metadata
282    with open(data_dir / "metadata.json") as f:
283        metadata = json.load(f)
284
285    # Load edge index
286    edge_index = torch.from_numpy(np.load(data_dir / "edge_index.npy")).long()
287
288    # Load node features (if available)
289    x = None
290    node_features_path = data_dir / "node_features.npy"
291    if node_features_path.exists():
292        x = torch.from_numpy(np.load(node_features_path)).float()
293
294    # Load edge features (if available)
295    edge_attr = None
296    edge_features_path = data_dir / "edge_features.npy"
297    if edge_features_path.exists():
298        edge_attr = torch.from_numpy(np.load(edge_features_path)).float()
299
300    # Load node labels (if available)
301    y = None
302    node_labels_path = data_dir / "node_labels.npy"
303    if node_labels_path.exists():
304        y = torch.from_numpy(np.load(node_labels_path)).long()
305
306    # Load masks (if available)
307    train_mask = None
308    val_mask = None
309    test_mask = None
310
311    if (data_dir / "train_mask.npy").exists():
312        train_mask = torch.from_numpy(np.load(data_dir / "train_mask.npy")).bool()
313    if (data_dir / "val_mask.npy").exists():
314        val_mask = torch.from_numpy(np.load(data_dir / "val_mask.npy")).bool()
315    if (data_dir / "test_mask.npy").exists():
316        test_mask = torch.from_numpy(np.load(data_dir / "test_mask.npy")).bool()
317
318    if not HAS_PYG:
319        return {
320            "edge_index": edge_index,
321            "x": x,
322            "edge_attr": edge_attr,
323            "y": y,
324            "train_mask": train_mask,
325            "val_mask": val_mask,
326            "test_mask": test_mask,
327            "metadata": metadata,
328        }
329
330    # Create PyG Data object
331    data = Data(
332        x=x,
333        edge_index=edge_index,
334        edge_attr=edge_attr,
335        y=y,
336        train_mask=train_mask,
337        val_mask=val_mask,
338        test_mask=test_mask,
339    )
340
341    # Store metadata
342    data.metadata = metadata
343
344    return data
345
346
347def print_summary(data_dir: str = "."):
348    """Print summary of the graph data."""
349    data = load_graph(data_dir)
350
351    if isinstance(data, dict):
352        metadata = data["metadata"]
353        print(f"Graph: {metadata['name']}")
354        print(f"Nodes: {metadata['num_nodes']}")
355        print(f"Edges: {metadata['num_edges']}")
356        print(f"Node features: {data['x'].shape if data['x'] is not None else 'None'}")
357        print(f"Edge features: {data['edge_attr'].shape if data['edge_attr'] is not None else 'None'}")
358    else:
359        print(f"Graph: {data.metadata['name']}")
360        print(f"Nodes: {data.num_nodes}")
361        print(f"Edges: {data.num_edges}")
362        print(f"Node features: {data.x.shape if data.x is not None else 'None'}")
363        print(f"Edge features: {data.edge_attr.shape if data.edge_attr is not None else 'None'}")
364        if data.y is not None:
365            print(f"Anomalous nodes: {data.y.sum().item()}")
366        if data.train_mask is not None:
367            print(f"Train/Val/Test: {data.train_mask.sum()}/{data.val_mask.sum()}/{data.test_mask.sum()}")
368
369
370if __name__ == "__main__":
371    import sys
372    data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
373    print_summary(data_dir)
374"#;
375
376        let path = output_dir.join("load_graph.py");
377        let mut file = File::create(path)?;
378        file.write_all(script.as_bytes())?;
379
380        Ok(())
381    }
382}
383
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::test_helpers::create_test_graph;
    use tempfile::tempdir;

    /// Exports the helper graph with the default config, then checks the
    /// reported node/edge counts and that the expected files were created.
    #[test]
    fn test_pyg_export() {
        let sample = create_test_graph();
        let out = tempdir().unwrap();

        let summary = PyGExporter::new(PyGExportConfig::default())
            .export(&sample, out.path())
            .unwrap();

        assert_eq!(summary.num_nodes, 2);
        assert_eq!(summary.num_edges, 1);

        let expected_files = [
            "edge_index.npy",
            "node_features.npy",
            "metadata.json",
            "load_graph.py",
        ];
        for name in expected_files.iter() {
            assert!(out.path().join(name).exists());
        }
    }
}