Skip to main content

datasynth_graph/exporters/
dgl.rs

1//! Deep Graph Library (DGL) exporter.
2//!
3//! Exports graph data in formats compatible with DGL:
4//! - NumPy arrays (.npy) for node/edge features and labels
5//! - COO format edge index [num_edges, 2] (differs from PyG's [2, num_edges])
6//! - JSON metadata for graph information
7//!
8//! The exported data can be loaded in Python with:
9//! ```python
10//! import numpy as np
11//! import torch
12//! import dgl
13//!
14//! node_features = torch.from_numpy(np.load('node_features.npy'))
15//! edge_index = np.load('edge_index.npy')  # [num_edges, 2] COO format
16//! src, dst = edge_index[:, 0], edge_index[:, 1]
17//!
18//! g = dgl.graph((src, dst))
19//! g.ndata['feat'] = node_features
20//! ```
21//!
22//! For heterogeneous graphs, DGL uses separate arrays per node/edge type.
23
24use std::collections::HashMap;
25use std::fs::{self, File};
26use std::io::{BufWriter, Write};
27use std::path::Path;
28
29use serde::{Deserialize, Serialize};
30
31use crate::exporters::common::{CommonExportConfig, CommonGraphMetadata};
32use crate::models::Graph;
33
/// Configuration for DGL export.
///
/// Consumed by [`DGLExporter`]. The `Default` implementation selects a homogeneous
/// export with the pickle helper script enabled.
#[derive(Debug, Clone)]
pub struct DGLExportConfig {
    /// Common export settings (features, labels, masks, splits, seed).
    pub common: CommonExportConfig,
    /// Export as heterogeneous graph. When `true`, `export` additionally writes
    /// `node_type_indices.npy` and `edge_type_indices.npy`.
    pub heterogeneous: bool,
    /// Include the Python pickle helper script (`pickle_helper.py`) in the output.
    pub include_pickle_script: bool,
}
44
45impl Default for DGLExportConfig {
46    fn default() -> Self {
47        Self {
48            common: CommonExportConfig::default(),
49            heterogeneous: false,
50            include_pickle_script: true,
51        }
52    }
53}
54
/// Metadata about the exported DGL data.
///
/// Returned by `DGLExporter::export` and also serialized to `metadata.json` in the
/// output directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DGLMetadata {
    /// Common graph metadata fields.
    ///
    /// Flattened by serde, so these fields appear at the top level of the JSON object
    /// rather than under a `common` key.
    #[serde(flatten)]
    pub common: CommonGraphMetadata,
    /// Whether export is heterogeneous.
    pub is_heterogeneous: bool,
    /// Edge index format; `export` always sets this to `"COO"`.
    pub edge_format: String,
}
66
/// DGL graph exporter.
///
/// Holds the export configuration; all output is produced by `export`.
pub struct DGLExporter {
    // Settings chosen at construction time; read by the export methods.
    config: DGLExportConfig,
}
71
72impl DGLExporter {
73    /// Creates a new DGL exporter.
74    pub fn new(config: DGLExportConfig) -> Self {
75        Self { config }
76    }
77
78    /// Exports a graph to DGL format.
79    pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<DGLMetadata> {
80        fs::create_dir_all(output_dir)?;
81
82        let mut files = Vec::new();
83        let mut statistics = HashMap::new();
84
85        // Export edge index in COO format [num_edges, 2]
86        self.export_edge_index(graph, output_dir)?;
87        files.push("edge_index.npy".to_string());
88
89        // Export node features
90        if self.config.common.export_node_features {
91            let dim = self.export_node_features(graph, output_dir)?;
92            files.push("node_features.npy".to_string());
93            statistics.insert("node_feature_dim".to_string(), dim as f64);
94        }
95
96        // Export edge features
97        if self.config.common.export_edge_features {
98            let dim = self.export_edge_features(graph, output_dir)?;
99            files.push("edge_features.npy".to_string());
100            statistics.insert("edge_feature_dim".to_string(), dim as f64);
101        }
102
103        // Export node labels
104        if self.config.common.export_node_labels {
105            self.export_node_labels(graph, output_dir)?;
106            files.push("node_labels.npy".to_string());
107        }
108
109        // Export edge labels
110        if self.config.common.export_edge_labels {
111            self.export_edge_labels(graph, output_dir)?;
112            files.push("edge_labels.npy".to_string());
113        }
114
115        // Export masks
116        if self.config.common.export_masks {
117            self.export_masks(graph, output_dir)?;
118            files.push("train_mask.npy".to_string());
119            files.push("val_mask.npy".to_string());
120            files.push("test_mask.npy".to_string());
121        }
122
123        // Export node type indices (for heterogeneous support)
124        if self.config.heterogeneous {
125            self.export_node_types(graph, output_dir)?;
126            files.push("node_type_indices.npy".to_string());
127            self.export_edge_types(graph, output_dir)?;
128            files.push("edge_type_indices.npy".to_string());
129        }
130
131        // Compute node/edge type mappings with counts
132        let node_types: HashMap<String, usize> = graph
133            .nodes_by_type
134            .iter()
135            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
136            .collect();
137
138        let edge_types: HashMap<String, usize> = graph
139            .edges_by_type
140            .iter()
141            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
142            .collect();
143
144        // Compute statistics
145        statistics.insert("density".to_string(), graph.metadata.density);
146        statistics.insert(
147            "anomalous_node_ratio".to_string(),
148            graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
149        );
150        statistics.insert(
151            "anomalous_edge_ratio".to_string(),
152            graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
153        );
154
155        // Create metadata
156        let metadata = DGLMetadata {
157            common: CommonGraphMetadata {
158                name: graph.name.clone(),
159                num_nodes: graph.node_count(),
160                num_edges: graph.edge_count(),
161                node_feature_dim: graph.metadata.node_feature_dim,
162                edge_feature_dim: graph.metadata.edge_feature_dim,
163                num_node_classes: 2, // Normal/Anomaly
164                num_edge_classes: 2,
165                node_types,
166                edge_types,
167                is_directed: true,
168                files,
169                statistics,
170            },
171            is_heterogeneous: self.config.heterogeneous,
172            edge_format: "COO".to_string(),
173        };
174
175        // Write metadata
176        let metadata_path = output_dir.join("metadata.json");
177        let file = File::create(metadata_path)?;
178        serde_json::to_writer_pretty(file, &metadata)?;
179
180        // Write Python loader script
181        self.write_loader_script(output_dir)?;
182
183        // Write pickle helper script if configured
184        if self.config.include_pickle_script {
185            self.write_pickle_script(output_dir)?;
186        }
187
188        Ok(metadata)
189    }
190
191    /// Exports edge index as COO format [num_edges, 2] array.
192    fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
193        let (sources, targets) = graph.edge_index();
194
195        // Create node ID to index mapping
196        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
197        node_ids.sort();
198        let id_to_idx: HashMap<_, _> = node_ids
199            .iter()
200            .enumerate()
201            .map(|(i, &id)| (id, i))
202            .collect();
203
204        // Create COO format: [num_edges, 2] where each row is (src, dst)
205        let num_edges = sources.len();
206        let coo_data: Vec<Vec<i64>> = (0..num_edges)
207            .map(|i| {
208                let src = *id_to_idx.get(&sources[i]).unwrap_or(&0) as i64;
209                let dst = *id_to_idx.get(&targets[i]).unwrap_or(&0) as i64;
210                vec![src, dst]
211            })
212            .collect();
213
214        // Write as NPY format [num_edges, 2]
215        let path = output_dir.join("edge_index.npy");
216        self.write_npy_2d_i64(&path, &coo_data)?;
217
218        Ok(())
219    }
220
221    /// Exports node features.
222    fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
223        let features = graph.node_features();
224        let dim = features.first().map(|f| f.len()).unwrap_or(0);
225
226        let path = output_dir.join("node_features.npy");
227        self.write_npy_2d_f64(&path, &features)?;
228
229        Ok(dim)
230    }
231
232    /// Exports edge features.
233    fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
234        let features = graph.edge_features();
235        let dim = features.first().map(|f| f.len()).unwrap_or(0);
236
237        let path = output_dir.join("edge_features.npy");
238        self.write_npy_2d_f64(&path, &features)?;
239
240        Ok(dim)
241    }
242
243    /// Exports node labels (anomaly flags).
244    fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
245        let labels: Vec<i64> = graph
246            .node_anomaly_mask()
247            .iter()
248            .map(|&b| if b { 1 } else { 0 })
249            .collect();
250
251        let path = output_dir.join("node_labels.npy");
252        self.write_npy_1d_i64(&path, &labels)?;
253
254        Ok(())
255    }
256
257    /// Exports edge labels (anomaly flags).
258    fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
259        let labels: Vec<i64> = graph
260            .edge_anomaly_mask()
261            .iter()
262            .map(|&b| if b { 1 } else { 0 })
263            .collect();
264
265        let path = output_dir.join("edge_labels.npy");
266        self.write_npy_1d_i64(&path, &labels)?;
267
268        Ok(())
269    }
270
271    /// Exports train/val/test masks.
272    fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
273        let n = graph.node_count();
274        let mut rng = SimpleRng::new(self.config.common.seed);
275
276        let train_size = (n as f64 * self.config.common.train_ratio) as usize;
277        let val_size = (n as f64 * self.config.common.val_ratio) as usize;
278
279        // Create shuffled indices
280        let mut indices: Vec<usize> = (0..n).collect();
281        for i in (1..n).rev() {
282            let j = (rng.next() % (i as u64 + 1)) as usize;
283            indices.swap(i, j);
284        }
285
286        // Create masks
287        let mut train_mask = vec![false; n];
288        let mut val_mask = vec![false; n];
289        let mut test_mask = vec![false; n];
290
291        for (i, &idx) in indices.iter().enumerate() {
292            if i < train_size {
293                train_mask[idx] = true;
294            } else if i < train_size + val_size {
295                val_mask[idx] = true;
296            } else {
297                test_mask[idx] = true;
298            }
299        }
300
301        // Write masks
302        self.write_npy_1d_bool(&output_dir.join("train_mask.npy"), &train_mask)?;
303        self.write_npy_1d_bool(&output_dir.join("val_mask.npy"), &val_mask)?;
304        self.write_npy_1d_bool(&output_dir.join("test_mask.npy"), &test_mask)?;
305
306        Ok(())
307    }
308
309    /// Exports node type indices for heterogeneous graphs.
310    fn export_node_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
311        // Create type mapping
312        let type_to_idx: HashMap<_, _> = graph
313            .nodes_by_type
314            .keys()
315            .enumerate()
316            .map(|(i, t)| (t.clone(), i as i64))
317            .collect();
318
319        // Get node IDs in sorted order for consistent indexing
320        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
321        node_ids.sort();
322
323        // Map each node to its type index
324        let type_indices: Vec<i64> = node_ids
325            .iter()
326            .map(|id| {
327                let node = graph.nodes.get(id).unwrap();
328                *type_to_idx.get(&node.node_type).unwrap_or(&0)
329            })
330            .collect();
331
332        let path = output_dir.join("node_type_indices.npy");
333        self.write_npy_1d_i64(&path, &type_indices)?;
334
335        Ok(())
336    }
337
338    /// Exports edge type indices for heterogeneous graphs.
339    fn export_edge_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
340        // Create type mapping
341        let type_to_idx: HashMap<_, _> = graph
342            .edges_by_type
343            .keys()
344            .enumerate()
345            .map(|(i, t)| (t.clone(), i as i64))
346            .collect();
347
348        // Get edge IDs in sorted order for consistent indexing
349        let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
350        edge_ids.sort();
351
352        // Map each edge to its type index
353        let type_indices: Vec<i64> = edge_ids
354            .iter()
355            .map(|id| {
356                let edge = graph.edges.get(id).unwrap();
357                *type_to_idx.get(&edge.edge_type).unwrap_or(&0)
358            })
359            .collect();
360
361        let path = output_dir.join("edge_type_indices.npy");
362        self.write_npy_1d_i64(&path, &type_indices)?;
363
364        Ok(())
365    }
366
367    /// Writes a 1D array of i64 in NPY format.
368    fn write_npy_1d_i64(&self, path: &Path, data: &[i64]) -> std::io::Result<()> {
369        let file = File::create(path)?;
370        let mut writer = BufWriter::new(file);
371
372        // NPY header
373        let shape = format!("({},)", data.len());
374        self.write_npy_header(&mut writer, "<i8", &shape)?;
375
376        // Data
377        for &val in data {
378            writer.write_all(&val.to_le_bytes())?;
379        }
380
381        Ok(())
382    }
383
384    /// Writes a 1D array of bool in NPY format.
385    fn write_npy_1d_bool(&self, path: &Path, data: &[bool]) -> std::io::Result<()> {
386        let file = File::create(path)?;
387        let mut writer = BufWriter::new(file);
388
389        // NPY header
390        let shape = format!("({},)", data.len());
391        self.write_npy_header(&mut writer, "|b1", &shape)?;
392
393        // Data
394        for &val in data {
395            writer.write_all(&[if val { 1u8 } else { 0u8 }])?;
396        }
397
398        Ok(())
399    }
400
401    /// Writes a 2D array of i64 in NPY format.
402    fn write_npy_2d_i64(&self, path: &Path, data: &[Vec<i64>]) -> std::io::Result<()> {
403        let file = File::create(path)?;
404        let mut writer = BufWriter::new(file);
405
406        let rows = data.len();
407        let cols = data.first().map(|r| r.len()).unwrap_or(0);
408
409        // NPY header
410        let shape = format!("({}, {})", rows, cols);
411        self.write_npy_header(&mut writer, "<i8", &shape)?;
412
413        // Data (row-major)
414        for row in data {
415            for &val in row {
416                writer.write_all(&val.to_le_bytes())?;
417            }
418            // Pad short rows if needed
419            for _ in row.len()..cols {
420                writer.write_all(&0_i64.to_le_bytes())?;
421            }
422        }
423
424        Ok(())
425    }
426
427    /// Writes a 2D array of f64 in NPY format.
428    fn write_npy_2d_f64(&self, path: &Path, data: &[Vec<f64>]) -> std::io::Result<()> {
429        let file = File::create(path)?;
430        let mut writer = BufWriter::new(file);
431
432        let rows = data.len();
433        let cols = data.first().map(|r| r.len()).unwrap_or(0);
434
435        // NPY header
436        let shape = format!("({}, {})", rows, cols);
437        self.write_npy_header(&mut writer, "<f8", &shape)?;
438
439        // Data (row-major)
440        for row in data {
441            for &val in row {
442                writer.write_all(&val.to_le_bytes())?;
443            }
444            // Pad short rows with zeros
445            for _ in row.len()..cols {
446                writer.write_all(&0.0_f64.to_le_bytes())?;
447            }
448        }
449
450        Ok(())
451    }
452
453    /// Writes NPY header.
454    fn write_npy_header<W: Write>(
455        &self,
456        writer: &mut W,
457        dtype: &str,
458        shape: &str,
459    ) -> std::io::Result<()> {
460        // Magic number and version
461        writer.write_all(&[0x93])?; // \x93
462        writer.write_all(b"NUMPY")?;
463        writer.write_all(&[0x01, 0x00])?; // Version 1.0
464
465        // Header dict
466        let header = format!(
467            "{{'descr': '{}', 'fortran_order': False, 'shape': {} }}",
468            dtype, shape
469        );
470
471        // Pad header to multiple of 64 bytes (including magic, version, header_len)
472        let header_len = header.len();
473        let total_len = 10 + header_len + 1; // magic(6) + version(2) + header_len(2) + header + newline
474        let padding = (64 - (total_len % 64)) % 64;
475        let padded_len = header_len + 1 + padding;
476
477        writer.write_all(&(padded_len as u16).to_le_bytes())?;
478        writer.write_all(header.as_bytes())?;
479        for _ in 0..padding {
480            writer.write_all(b" ")?;
481        }
482        writer.write_all(b"\n")?;
483
484        Ok(())
485    }
486
487    /// Writes a Python loader script for DGL.
488    fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
489        let script = r#"#!/usr/bin/env python3
490"""
491DGL (Deep Graph Library) Data Loader
492
493Auto-generated loader for graph data exported from synth-graph.
494Supports both homogeneous and heterogeneous graph loading.
495"""
496
497import json
498import numpy as np
499from pathlib import Path
500
501try:
502    import torch
503    HAS_TORCH = True
504except ImportError:
505    HAS_TORCH = False
506    print("Warning: torch not installed. Install with: pip install torch")
507
508try:
509    import dgl
510    HAS_DGL = True
511except ImportError:
512    HAS_DGL = False
513    print("Warning: dgl not installed. Install with: pip install dgl")
514
515
516def load_graph(data_dir: str = ".") -> "dgl.DGLGraph":
517    """Load graph data into a DGL graph object.
518
519    Args:
520        data_dir: Directory containing the exported graph data.
521
522    Returns:
523        DGL graph with node features, edge features, and labels attached.
524    """
525    data_dir = Path(data_dir)
526
527    # Load metadata
528    with open(data_dir / "metadata.json") as f:
529        metadata = json.load(f)
530
531    # Load edge index (COO format: [num_edges, 2])
532    edge_index = np.load(data_dir / "edge_index.npy")
533    src = edge_index[:, 0]
534    dst = edge_index[:, 1]
535
536    num_nodes = metadata["num_nodes"]
537
538    if not HAS_DGL:
539        # Return dict if DGL not available
540        result = {
541            "src": src,
542            "dst": dst,
543            "num_nodes": num_nodes,
544            "metadata": metadata,
545        }
546
547        # Load optional arrays
548        if (data_dir / "node_features.npy").exists():
549            result["node_features"] = np.load(data_dir / "node_features.npy")
550        if (data_dir / "edge_features.npy").exists():
551            result["edge_features"] = np.load(data_dir / "edge_features.npy")
552        if (data_dir / "node_labels.npy").exists():
553            result["node_labels"] = np.load(data_dir / "node_labels.npy")
554        if (data_dir / "edge_labels.npy").exists():
555            result["edge_labels"] = np.load(data_dir / "edge_labels.npy")
556        if (data_dir / "train_mask.npy").exists():
557            result["train_mask"] = np.load(data_dir / "train_mask.npy")
558            result["val_mask"] = np.load(data_dir / "val_mask.npy")
559            result["test_mask"] = np.load(data_dir / "test_mask.npy")
560
561        return result
562
563    # Create DGL graph
564    g = dgl.graph((src, dst), num_nodes=num_nodes)
565
566    # Load and attach node features
567    node_features_path = data_dir / "node_features.npy"
568    if node_features_path.exists():
569        node_features = np.load(node_features_path)
570        if HAS_TORCH:
571            g.ndata['feat'] = torch.from_numpy(node_features).float()
572        else:
573            g.ndata['feat'] = node_features
574
575    # Load and attach edge features
576    edge_features_path = data_dir / "edge_features.npy"
577    if edge_features_path.exists():
578        edge_features = np.load(edge_features_path)
579        if HAS_TORCH:
580            g.edata['feat'] = torch.from_numpy(edge_features).float()
581        else:
582            g.edata['feat'] = edge_features
583
584    # Load and attach node labels
585    node_labels_path = data_dir / "node_labels.npy"
586    if node_labels_path.exists():
587        node_labels = np.load(node_labels_path)
588        if HAS_TORCH:
589            g.ndata['label'] = torch.from_numpy(node_labels).long()
590        else:
591            g.ndata['label'] = node_labels
592
593    # Load and attach edge labels
594    edge_labels_path = data_dir / "edge_labels.npy"
595    if edge_labels_path.exists():
596        edge_labels = np.load(edge_labels_path)
597        if HAS_TORCH:
598            g.edata['label'] = torch.from_numpy(edge_labels).long()
599        else:
600            g.edata['label'] = edge_labels
601
602    # Load and attach masks
603    if (data_dir / "train_mask.npy").exists():
604        train_mask = np.load(data_dir / "train_mask.npy")
605        val_mask = np.load(data_dir / "val_mask.npy")
606        test_mask = np.load(data_dir / "test_mask.npy")
607
608        if HAS_TORCH:
609            g.ndata['train_mask'] = torch.from_numpy(train_mask).bool()
610            g.ndata['val_mask'] = torch.from_numpy(val_mask).bool()
611            g.ndata['test_mask'] = torch.from_numpy(test_mask).bool()
612        else:
613            g.ndata['train_mask'] = train_mask
614            g.ndata['val_mask'] = val_mask
615            g.ndata['test_mask'] = test_mask
616
617    # Store metadata as graph attribute
618    g.metadata = metadata
619
620    return g
621
622
623def load_heterogeneous_graph(data_dir: str = ".") -> "dgl.DGLHeteroGraph":
624    """Load graph data into a DGL heterogeneous graph.
625
626    This function handles graphs with multiple node and edge types.
627
628    Args:
629        data_dir: Directory containing the exported graph data.
630
631    Returns:
632        DGL heterogeneous graph.
633    """
634    data_dir = Path(data_dir)
635
636    # Load metadata
637    with open(data_dir / "metadata.json") as f:
638        metadata = json.load(f)
639
640    if not metadata.get("is_heterogeneous", False):
641        print("Warning: Graph was not exported as heterogeneous. Using homogeneous loader.")
642        return load_graph(data_dir)
643
644    if not HAS_DGL:
645        raise ImportError("DGL is required for heterogeneous graph loading")
646
647    # Load edge index and type indices
648    edge_index = np.load(data_dir / "edge_index.npy")
649    edge_types = np.load(data_dir / "edge_type_indices.npy")
650    node_types = np.load(data_dir / "node_type_indices.npy")
651
652    # Get type names from metadata
653    node_type_names = list(metadata["node_types"].keys())
654    edge_type_names = list(metadata["edge_types"].keys())
655
656    # Build edge dict for heterogeneous graph
657    edge_dict = {}
658    for etype_idx, etype_name in enumerate(edge_type_names):
659        mask = edge_types == etype_idx
660        if mask.any():
661            src = edge_index[mask, 0]
662            dst = edge_index[mask, 1]
663            # For heterogeneous, we need to specify (src_type, edge_type, dst_type)
664            # Using simplified convention: (node_type, edge_type, node_type)
665            edge_dict[(node_type_names[0] if node_type_names else 'node',
666                      etype_name,
667                      node_type_names[0] if node_type_names else 'node')] = (src, dst)
668
669    # Create heterogeneous graph
670    g = dgl.heterograph(edge_dict) if edge_dict else dgl.graph(([], []))
671    g.metadata = metadata
672
673    return g
674
675
676def print_summary(data_dir: str = "."):
677    """Print summary of the graph data."""
678    data_dir = Path(data_dir)
679
680    with open(data_dir / "metadata.json") as f:
681        metadata = json.load(f)
682
683    print(f"Graph: {metadata['name']}")
684    print(f"Format: DGL ({metadata['edge_format']} edge format)")
685    print(f"Nodes: {metadata['num_nodes']}")
686    print(f"Edges: {metadata['num_edges']}")
687    print(f"Node feature dim: {metadata['node_feature_dim']}")
688    print(f"Edge feature dim: {metadata['edge_feature_dim']}")
689    print(f"Directed: {metadata['is_directed']}")
690    print(f"Heterogeneous: {metadata['is_heterogeneous']}")
691
692    if metadata['node_types']:
693        print(f"Node types: {metadata['node_types']}")
694    if metadata['edge_types']:
695        print(f"Edge types: {metadata['edge_types']}")
696
697    if metadata['statistics']:
698        print("\nStatistics:")
699        for key, value in metadata['statistics'].items():
700            print(f"  {key}: {value:.4f}")
701
702    if HAS_DGL:
703        print("\nLoading graph...")
704        g = load_graph(data_dir)
705        if hasattr(g, 'num_nodes'):
706            print(f"DGL graph loaded: {g.num_nodes()} nodes, {g.num_edges()} edges")
707            if 'label' in g.ndata:
708                print(f"Anomalous nodes: {g.ndata['label'].sum().item()}")
709
710
711if __name__ == "__main__":
712    import sys
713    data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
714    print_summary(data_dir)
715"#;
716
717        let path = output_dir.join("load_graph.py");
718        let mut file = File::create(path)?;
719        file.write_all(script.as_bytes())?;
720
721        Ok(())
722    }
723
724    /// Writes a helper script for saving/loading DGL graphs as pickle.
725    fn write_pickle_script(&self, output_dir: &Path) -> std::io::Result<()> {
726        let script = r#"#!/usr/bin/env python3
727"""
728DGL Graph Pickle Helper
729
730Utility to save and load DGL graphs as pickle files for faster subsequent loading.
731"""
732
733import pickle
734from pathlib import Path
735
736try:
737    import dgl
738    HAS_DGL = True
739except ImportError:
740    HAS_DGL = False
741
742
743def save_dgl_graph(graph, output_path: str):
744    """Save a DGL graph to a pickle file.
745
746    Args:
747        graph: DGL graph to save.
748        output_path: Path to save the pickle file.
749    """
750    output_path = Path(output_path)
751
752    # Save graph data
753    graph_data = {
754        'num_nodes': graph.num_nodes(),
755        'edges': graph.edges(),
756        'ndata': {k: v.numpy() if hasattr(v, 'numpy') else v
757                  for k, v in graph.ndata.items()},
758        'edata': {k: v.numpy() if hasattr(v, 'numpy') else v
759                  for k, v in graph.edata.items()},
760        'metadata': getattr(graph, 'metadata', {}),
761    }
762
763    with open(output_path, 'wb') as f:
764        pickle.dump(graph_data, f, protocol=pickle.HIGHEST_PROTOCOL)
765
766    print(f"Saved graph to {output_path}")
767
768
769def load_dgl_graph(input_path: str) -> "dgl.DGLGraph":
770    """Load a DGL graph from a pickle file.
771
772    Args:
773        input_path: Path to the pickle file.
774
775    Returns:
776        DGL graph.
777    """
778    if not HAS_DGL:
779        raise ImportError("DGL is required to load graphs")
780
781    import torch
782
783    input_path = Path(input_path)
784
785    with open(input_path, 'rb') as f:
786        graph_data = pickle.load(f)
787
788    # Recreate graph
789    src, dst = graph_data['edges']
790    g = dgl.graph((src, dst), num_nodes=graph_data['num_nodes'])
791
792    # Restore node data
793    for k, v in graph_data['ndata'].items():
794        g.ndata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v
795
796    # Restore edge data
797    for k, v in graph_data['edata'].items():
798        g.edata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v
799
800    # Restore metadata
801    g.metadata = graph_data.get('metadata', {})
802
803    return g
804
805
806def convert_to_pickle(data_dir: str, output_path: str = None):
807    """Convert exported graph data to pickle format for faster loading.
808
809    Args:
810        data_dir: Directory containing the exported graph data.
811        output_path: Path for output pickle file. Defaults to data_dir/graph.pkl.
812    """
813    from load_graph import load_graph
814
815    data_dir = Path(data_dir)
816    output_path = Path(output_path) if output_path else data_dir / "graph.pkl"
817
818    print(f"Loading graph from {data_dir}...")
819    g = load_graph(str(data_dir))
820
821    if isinstance(g, dict):
822        print("Error: DGL not available, cannot convert to pickle")
823        return
824
825    save_dgl_graph(g, str(output_path))
826    print(f"Graph saved to {output_path}")
827
828
829if __name__ == "__main__":
830    import sys
831
832    if len(sys.argv) < 2:
833        print("Usage:")
834        print("  python pickle_helper.py convert <data_dir> [output_path]")
835        print("  python pickle_helper.py load <pickle_path>")
836        sys.exit(1)
837
838    command = sys.argv[1]
839
840    if command == "convert":
841        data_dir = sys.argv[2] if len(sys.argv) > 2 else "."
842        output_path = sys.argv[3] if len(sys.argv) > 3 else None
843        convert_to_pickle(data_dir, output_path)
844    elif command == "load":
845        pickle_path = sys.argv[2]
846        g = load_dgl_graph(pickle_path)
847        print(f"Loaded graph: {g.num_nodes()} nodes, {g.num_edges()} edges")
848    else:
849        print(f"Unknown command: {command}")
850"#;
851
852        let path = output_dir.join("pickle_helper.py");
853        let mut file = File::create(path)?;
854        file.write_all(script.as_bytes())?;
855
856        Ok(())
857    }
858}
859
/// Small deterministic pseudo-random generator using the xorshift64 update
/// (shift constants 13, 7, 17).
struct SimpleRng {
    state: u64,
}

impl SimpleRng {
    /// Creates a generator from `seed`.
    ///
    /// A zero seed is replaced by 1: an all-zero state would stay zero forever,
    /// because every update step only XORs shifted copies of the state.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 1,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advances the generator one step and returns the new state.
    fn next(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }
}
881
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::create_test_graph_with_company;
    use tempfile::tempdir;

    /// Asserts that every entry of `names` exists directly under `root`.
    fn assert_files_present(root: &Path, names: &[&str]) {
        for &name in names {
            assert!(root.join(name).exists(), "expected {} to exist", name);
        }
    }

    /// A default export reports the fixture's counts and writes the default file set.
    #[test]
    fn test_dgl_export_basic() {
        let fixture = create_test_graph_with_company();
        let out_dir = tempdir().unwrap();

        let metadata = DGLExporter::new(DGLExportConfig::default())
            .export(&fixture, out_dir.path())
            .unwrap();

        assert_eq!(metadata.common.num_nodes, 3);
        assert_eq!(metadata.common.num_edges, 2);
        assert_eq!(metadata.edge_format, "COO");
        assert_files_present(
            out_dir.path(),
            &[
                "edge_index.npy",
                "node_features.npy",
                "node_labels.npy",
                "metadata.json",
                "load_graph.py",
                "pickle_helper.py",
            ],
        );
    }

    /// Heterogeneous mode is flagged in the metadata and writes the type index arrays.
    #[test]
    fn test_dgl_export_heterogeneous() {
        let fixture = create_test_graph_with_company();
        let out_dir = tempdir().unwrap();

        let hetero_config = DGLExportConfig {
            heterogeneous: true,
            ..Default::default()
        };
        let metadata = DGLExporter::new(hetero_config)
            .export(&fixture, out_dir.path())
            .unwrap();

        assert!(metadata.is_heterogeneous);
        assert_files_present(
            out_dir.path(),
            &["node_type_indices.npy", "edge_type_indices.npy"],
        );
    }

    /// With default settings the three split masks are listed in the metadata and written.
    #[test]
    fn test_dgl_export_masks() {
        let fixture = create_test_graph_with_company();
        let out_dir = tempdir().unwrap();

        let metadata = DGLExporter::new(DGLExportConfig::default())
            .export(&fixture, out_dir.path())
            .unwrap();

        let mask_files = ["train_mask.npy", "val_mask.npy", "test_mask.npy"];
        let listed = |name: &str| metadata.common.files.iter().any(|f| f.as_str() == name);

        for &mask in &mask_files {
            assert!(listed(mask), "expected {} to be listed in metadata", mask);
        }
        assert_files_present(out_dir.path(), &mask_files);
    }

    /// The edge index file is written and the persisted metadata declares COO layout.
    #[test]
    fn test_dgl_coo_format() {
        let fixture = create_test_graph_with_company();
        let out_dir = tempdir().unwrap();

        DGLExporter::new(DGLExportConfig::default())
            .export(&fixture, out_dir.path())
            .unwrap();

        // Only the file's presence is checked here; the array contents are not parsed.
        assert_files_present(out_dir.path(), &["edge_index.npy"]);

        // The edge format is read back from the metadata.json written by the export.
        let metadata_file = File::open(out_dir.path().join("metadata.json")).unwrap();
        let metadata: DGLMetadata = serde_json::from_reader(metadata_file).unwrap();
        assert_eq!(metadata.edge_format, "COO");
    }

    /// Turning off `export_masks` keeps the train mask out of both metadata and disk.
    #[test]
    fn test_dgl_export_no_masks() {
        let fixture = create_test_graph_with_company();
        let out_dir = tempdir().unwrap();

        let common = CommonExportConfig {
            export_masks: false,
            ..Default::default()
        };
        let maskless_config = DGLExportConfig {
            common,
            ..Default::default()
        };
        let metadata = DGLExporter::new(maskless_config)
            .export(&fixture, out_dir.path())
            .unwrap();

        let train_mask_listed = metadata
            .common
            .files
            .iter()
            .any(|f| f.as_str() == "train_mask.npy");
        assert!(!train_mask_listed);
        assert!(!out_dir.path().join("train_mask.npy").exists());
    }

    /// With every optional output disabled, only the edge index is listed in the metadata.
    #[test]
    fn test_dgl_export_minimal() {
        let fixture = create_test_graph_with_company();
        let out_dir = tempdir().unwrap();

        let common = CommonExportConfig {
            export_node_features: false,
            export_edge_features: false,
            export_node_labels: false,
            export_edge_labels: false,
            export_masks: false,
            ..Default::default()
        };
        let minimal_config = DGLExportConfig {
            common,
            include_pickle_script: false,
            ..Default::default()
        };
        let metadata = DGLExporter::new(minimal_config)
            .export(&fixture, out_dir.path())
            .unwrap();

        // The metadata file list holds a single entry (edge_index.npy).
        assert_eq!(metadata.common.files.len(), 1);

        // The loader script and metadata.json are still written alongside the edge index.
        assert_files_present(
            out_dir.path(),
            &["edge_index.npy", "load_graph.py", "metadata.json"],
        );

        // The pickle helper is skipped when `include_pickle_script` is off.
        assert!(!out_dir.path().join("pickle_helper.py").exists());
    }

    /// The exported statistics include density and both anomaly ratios.
    #[test]
    fn test_dgl_statistics() {
        let fixture = create_test_graph_with_company();
        let out_dir = tempdir().unwrap();

        let metadata = DGLExporter::new(DGLExportConfig::default())
            .export(&fixture, out_dir.path())
            .unwrap();
        let stats = &metadata.common.statistics;

        for &key in &["density", "anomalous_node_ratio", "anomalous_edge_ratio"] {
            assert!(stats.contains_key(key), "missing statistic {}", key);
        }

        // One of the three fixture nodes is anomalous.
        let node_ratio = *stats.get("anomalous_node_ratio").unwrap();
        assert!((node_ratio - 1.0 / 3.0).abs() < 0.01);
    }
}