Skip to main content

datasynth_graph/exporters/
pytorch_geometric.rs

1//! PyTorch Geometric exporter.
2//!
3//! Exports graph data in formats compatible with PyTorch Geometric:
4//! - NumPy arrays (.npy) for easy Python loading
5//! - JSON metadata for graph information
6//!
7//! The exported data can be loaded in Python with:
8//! ```python
9//! import numpy as np
10//! import torch
11//! from torch_geometric.data import Data
12//!
13//! node_features = torch.from_numpy(np.load('node_features.npy'))
14//! edge_index = torch.from_numpy(np.load('edge_index.npy'))
15//! edge_attr = torch.from_numpy(np.load('edge_features.npy'))
16//! y = torch.from_numpy(np.load('node_labels.npy'))
17//!
18//! data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr, y=y)
19//! ```
20
21use std::collections::HashMap;
22use std::fs::{self, File};
23use std::io::Write;
24use std::path::Path;
25
26use crate::exporters::common::{CommonExportConfig, CommonGraphMetadata};
27use crate::exporters::npy_writer;
28use crate::models::Graph;
29
/// Configuration for PyTorch Geometric export.
///
/// Derives `Default`, so `PyGExportConfig::default()` combines the default
/// [`CommonExportConfig`] with `one_hot_categoricals` set to `false`.
#[derive(Debug, Clone, Default)]
pub struct PyGExportConfig {
    /// Common export settings (features, labels, masks, splits, seed).
    pub common: CommonExportConfig,
    /// Export categorical features as one-hot.
    ///
    /// NOTE(review): no code visible in this file reads this flag — confirm
    /// whether it is consumed elsewhere or is not yet implemented.
    pub one_hot_categoricals: bool,
}
38
/// Metadata about the exported PyG data.
///
/// Alias for [`CommonGraphMetadata`]. A value of this type is returned by
/// [`PyGExporter::export`] and is also what gets serialized to `metadata.json`.
pub type PyGMetadata = CommonGraphMetadata;
41
/// PyTorch Geometric exporter.
///
/// Wraps a [`PyGExportConfig`]; construct it with [`PyGExporter::new`] and call
/// [`PyGExporter::export`] to write a graph to a directory.
pub struct PyGExporter {
    /// Settings that decide which optional files `export` writes and how the
    /// train/val/test masks are generated.
    config: PyGExportConfig,
}
46
47impl PyGExporter {
48    /// Creates a new PyG exporter.
49    pub fn new(config: PyGExportConfig) -> Self {
50        Self { config }
51    }
52
53    /// Exports a graph to PyTorch Geometric format.
54    pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<PyGMetadata> {
55        fs::create_dir_all(output_dir)?;
56
57        let mut files = Vec::new();
58        let mut statistics = HashMap::new();
59
60        // Export edge index
61        self.export_edge_index(graph, output_dir)?;
62        files.push("edge_index.npy".to_string());
63
64        // Export node features
65        if self.config.common.export_node_features {
66            let dim = self.export_node_features(graph, output_dir)?;
67            files.push("node_features.npy".to_string());
68            statistics.insert("node_feature_dim".to_string(), dim as f64);
69        }
70
71        // Export edge features
72        if self.config.common.export_edge_features {
73            let dim = self.export_edge_features(graph, output_dir)?;
74            files.push("edge_features.npy".to_string());
75            statistics.insert("edge_feature_dim".to_string(), dim as f64);
76        }
77
78        // Export node labels
79        if self.config.common.export_node_labels {
80            self.export_node_labels(graph, output_dir)?;
81            files.push("node_labels.npy".to_string());
82        }
83
84        // Export edge labels
85        if self.config.common.export_edge_labels {
86            self.export_edge_labels(graph, output_dir)?;
87            files.push("edge_labels.npy".to_string());
88        }
89
90        // Export masks
91        if self.config.common.export_masks {
92            self.export_masks(graph, output_dir)?;
93            files.push("train_mask.npy".to_string());
94            files.push("val_mask.npy".to_string());
95            files.push("test_mask.npy".to_string());
96        }
97
98        // Compute node/edge type mappings
99        let node_types: HashMap<String, usize> = graph
100            .nodes_by_type
101            .keys()
102            .enumerate()
103            .map(|(i, t)| (t.as_str().to_string(), i))
104            .collect();
105
106        let edge_types: HashMap<String, usize> = graph
107            .edges_by_type
108            .keys()
109            .enumerate()
110            .map(|(i, t)| (t.as_str().to_string(), i))
111            .collect();
112
113        // Compute statistics
114        statistics.insert("density".to_string(), graph.metadata.density);
115        statistics.insert(
116            "anomalous_node_ratio".to_string(),
117            graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
118        );
119        statistics.insert(
120            "anomalous_edge_ratio".to_string(),
121            graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
122        );
123
124        // Create metadata
125        let metadata = PyGMetadata {
126            name: graph.name.clone(),
127            num_nodes: graph.node_count(),
128            num_edges: graph.edge_count(),
129            node_feature_dim: graph.metadata.node_feature_dim,
130            edge_feature_dim: graph.metadata.edge_feature_dim,
131            num_node_classes: 2, // Normal/Anomaly
132            num_edge_classes: 2,
133            node_types,
134            edge_types,
135            is_directed: true,
136            files,
137            statistics,
138        };
139
140        // Write metadata
141        let metadata_path = output_dir.join("metadata.json");
142        let file = File::create(metadata_path)?;
143        serde_json::to_writer_pretty(file, &metadata)?;
144
145        // Write Python loader script
146        self.write_loader_script(output_dir)?;
147
148        Ok(metadata)
149    }
150
151    /// Exports edge index as [2, num_edges] array.
152    fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
153        let (sources, targets) = graph.edge_index();
154
155        // Create node ID to index mapping
156        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
157        node_ids.sort();
158        let id_to_idx: HashMap<_, _> = node_ids
159            .iter()
160            .enumerate()
161            .map(|(i, &id)| (id, i))
162            .collect();
163
164        // Remap edge indices
165        let sources_remapped: Vec<i64> = sources
166            .iter()
167            .map(|id| *id_to_idx.get(id).unwrap_or(&0) as i64)
168            .collect();
169        let targets_remapped: Vec<i64> = targets
170            .iter()
171            .map(|id| *id_to_idx.get(id).unwrap_or(&0) as i64)
172            .collect();
173
174        // Write as NPY format
175        let path = output_dir.join("edge_index.npy");
176        npy_writer::write_npy_2d_i64(&path, &[sources_remapped, targets_remapped])?;
177
178        Ok(())
179    }
180
181    /// Exports node features.
182    fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
183        let features = graph.node_features();
184        let dim = features.first().map(|f| f.len()).unwrap_or(0);
185
186        let path = output_dir.join("node_features.npy");
187        npy_writer::write_npy_2d_f64(&path, &features)?;
188
189        Ok(dim)
190    }
191
192    /// Exports edge features.
193    fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
194        let features = graph.edge_features();
195        let dim = features.first().map(|f| f.len()).unwrap_or(0);
196
197        let path = output_dir.join("edge_features.npy");
198        npy_writer::write_npy_2d_f64(&path, &features)?;
199
200        Ok(dim)
201    }
202
203    /// Exports node labels (anomaly flags).
204    fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
205        let labels: Vec<i64> = graph
206            .node_anomaly_mask()
207            .iter()
208            .map(|&b| if b { 1 } else { 0 })
209            .collect();
210
211        let path = output_dir.join("node_labels.npy");
212        npy_writer::write_npy_1d_i64(&path, &labels)?;
213
214        Ok(())
215    }
216
217    /// Exports edge labels (anomaly flags).
218    fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
219        let labels: Vec<i64> = graph
220            .edge_anomaly_mask()
221            .iter()
222            .map(|&b| if b { 1 } else { 0 })
223            .collect();
224
225        let path = output_dir.join("edge_labels.npy");
226        npy_writer::write_npy_1d_i64(&path, &labels)?;
227
228        Ok(())
229    }
230
231    /// Exports train/val/test masks.
232    fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
233        npy_writer::export_masks(
234            output_dir,
235            graph.node_count(),
236            self.config.common.seed,
237            self.config.common.train_ratio,
238            self.config.common.val_ratio,
239        )
240    }
241
242    /// Writes a Python loader script.
243    fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
244        let script = r#"#!/usr/bin/env python3
245"""
246PyTorch Geometric Data Loader
247
248Auto-generated loader for graph data exported from synth-graph.
249"""
250
251import json
252import numpy as np
253import torch
254from pathlib import Path
255
256try:
257    from torch_geometric.data import Data
258    HAS_PYG = True
259except ImportError:
260    HAS_PYG = False
261    print("Warning: torch_geometric not installed. Install with: pip install torch-geometric")
262
263
264def load_graph(data_dir: str = ".") -> "Data":
265    """Load graph data into a PyTorch Geometric Data object."""
266    data_dir = Path(data_dir)
267
268    # Load metadata
269    with open(data_dir / "metadata.json") as f:
270        metadata = json.load(f)
271
272    # Load edge index
273    edge_index = torch.from_numpy(np.load(data_dir / "edge_index.npy")).long()
274
275    # Load node features (if available)
276    x = None
277    node_features_path = data_dir / "node_features.npy"
278    if node_features_path.exists():
279        x = torch.from_numpy(np.load(node_features_path)).float()
280
281    # Load edge features (if available)
282    edge_attr = None
283    edge_features_path = data_dir / "edge_features.npy"
284    if edge_features_path.exists():
285        edge_attr = torch.from_numpy(np.load(edge_features_path)).float()
286
287    # Load node labels (if available)
288    y = None
289    node_labels_path = data_dir / "node_labels.npy"
290    if node_labels_path.exists():
291        y = torch.from_numpy(np.load(node_labels_path)).long()
292
293    # Load masks (if available)
294    train_mask = None
295    val_mask = None
296    test_mask = None
297
298    if (data_dir / "train_mask.npy").exists():
299        train_mask = torch.from_numpy(np.load(data_dir / "train_mask.npy")).bool()
300    if (data_dir / "val_mask.npy").exists():
301        val_mask = torch.from_numpy(np.load(data_dir / "val_mask.npy")).bool()
302    if (data_dir / "test_mask.npy").exists():
303        test_mask = torch.from_numpy(np.load(data_dir / "test_mask.npy")).bool()
304
305    if not HAS_PYG:
306        return {
307            "edge_index": edge_index,
308            "x": x,
309            "edge_attr": edge_attr,
310            "y": y,
311            "train_mask": train_mask,
312            "val_mask": val_mask,
313            "test_mask": test_mask,
314            "metadata": metadata,
315        }
316
317    # Create PyG Data object
318    data = Data(
319        x=x,
320        edge_index=edge_index,
321        edge_attr=edge_attr,
322        y=y,
323        train_mask=train_mask,
324        val_mask=val_mask,
325        test_mask=test_mask,
326    )
327
328    # Store metadata
329    data.metadata = metadata
330
331    return data
332
333
334def print_summary(data_dir: str = "."):
335    """Print summary of the graph data."""
336    data = load_graph(data_dir)
337
338    if isinstance(data, dict):
339        metadata = data["metadata"]
340        print(f"Graph: {metadata['name']}")
341        print(f"Nodes: {metadata['num_nodes']}")
342        print(f"Edges: {metadata['num_edges']}")
343        print(f"Node features: {data['x'].shape if data['x'] is not None else 'None'}")
344        print(f"Edge features: {data['edge_attr'].shape if data['edge_attr'] is not None else 'None'}")
345    else:
346        print(f"Graph: {data.metadata['name']}")
347        print(f"Nodes: {data.num_nodes}")
348        print(f"Edges: {data.num_edges}")
349        print(f"Node features: {data.x.shape if data.x is not None else 'None'}")
350        print(f"Edge features: {data.edge_attr.shape if data.edge_attr is not None else 'None'}")
351        if data.y is not None:
352            print(f"Anomalous nodes: {data.y.sum().item()}")
353        if data.train_mask is not None:
354            print(f"Train/Val/Test: {data.train_mask.sum()}/{data.val_mask.sum()}/{data.test_mask.sum()}")
355
356
357if __name__ == "__main__":
358    import sys
359    data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
360    print_summary(data_dir)
361"#;
362
363        let path = output_dir.join("load_graph.py");
364        let mut file = File::create(path)?;
365        file.write_all(script.as_bytes())?;
366
367        Ok(())
368    }
369}
370
#[cfg(test)]
// Unwrapping is fine in tests: a failed setup step should abort with a panic.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::test_helpers::create_test_graph;
    use tempfile::tempdir;

    /// Smoke test: exporting the shared test graph with the default
    /// configuration reports its node/edge counts and produces the core files.
    #[test]
    fn test_pyg_export() {
        let out = tempdir().unwrap();
        let graph = create_test_graph();

        let metadata = PyGExporter::new(PyGExportConfig::default())
            .export(&graph, out.path())
            .unwrap();

        assert_eq!(metadata.num_nodes, 2);
        assert_eq!(metadata.num_edges, 1);

        let expected_files = [
            "edge_index.npy",
            "node_features.npy",
            "metadata.json",
            "load_graph.py",
        ];
        for name in expected_files {
            assert!(out.path().join(name).exists());
        }
    }
}
393}