Skip to main content

datasynth_graph/exporters/
dgl.rs

1//! Deep Graph Library (DGL) exporter.
2//!
3//! Exports graph data in formats compatible with DGL:
4//! - NumPy arrays (.npy) for node/edge features and labels
5//! - COO format edge index [num_edges, 2] (differs from PyG's [2, num_edges])
6//! - JSON metadata for graph information
7//!
8//! The exported data can be loaded in Python with:
9//! ```python
10//! import numpy as np
11//! import torch
12//! import dgl
13//!
14//! node_features = torch.from_numpy(np.load('node_features.npy'))
15//! edge_index = np.load('edge_index.npy')  # [num_edges, 2] COO format
16//! src, dst = edge_index[:, 0], edge_index[:, 1]
17//!
//! g = dgl.graph((src, dst), num_nodes=node_features.shape[0])  # keep isolated nodes
19//! g.ndata['feat'] = node_features
20//! ```
21//!
//! With heterogeneous export enabled, the exporter additionally writes per-node and
//! per-edge type indices (`node_type_indices.npy`, `edge_type_indices.npy`).
23
24use std::collections::HashMap;
25use std::fs::{self, File};
26use std::io::{BufWriter, Write};
27use std::path::Path;
28
29use serde::{Deserialize, Serialize};
30
31use crate::models::Graph;
32
/// Configuration for DGL export.
///
/// Each `export_*` flag toggles one group of `.npy` files written by
/// [`DGLExporter::export`]; `edge_index.npy`, `metadata.json` and `load_graph.py`
/// are written regardless of these flags.
#[derive(Debug, Clone)]
pub struct DGLExportConfig {
    /// Export node features (`node_features.npy`, float64).
    pub export_node_features: bool,
    /// Export edge features (`edge_features.npy`, float64).
    pub export_edge_features: bool,
    /// Export node labels (`node_labels.npy`: anomaly flags stored as int64 0/1).
    pub export_node_labels: bool,
    /// Export edge labels (`edge_labels.npy`: anomaly flags stored as int64 0/1).
    pub export_edge_labels: bool,
    /// Export node-level boolean masks (`train_mask.npy`, `val_mask.npy`, `test_mask.npy`).
    pub export_masks: bool,
    /// Fraction of nodes placed in the training mask (the resulting count is truncated
    /// to an integer).
    pub train_ratio: f64,
    /// Fraction of nodes placed in the validation mask (also truncated). Nodes assigned
    /// to neither train nor validation end up in the test mask.
    pub val_ratio: f64,
    /// Seed for the shuffle that assigns nodes to splits (a seed of 0 is treated as 1).
    pub seed: u64,
    /// When `true`, also export `node_type_indices.npy` / `edge_type_indices.npy` and
    /// flag the metadata as heterogeneous.
    pub heterogeneous: bool,
    /// Write the `pickle_helper.py` script alongside the export.
    pub include_pickle_script: bool,
}
57
58impl Default for DGLExportConfig {
59    fn default() -> Self {
60        Self {
61            export_node_features: true,
62            export_edge_features: true,
63            export_node_labels: true,
64            export_edge_labels: true,
65            export_masks: true,
66            train_ratio: 0.7,
67            val_ratio: 0.15,
68            seed: 42,
69            heterogeneous: false,
70            include_pickle_script: true,
71        }
72    }
73}
74
/// Metadata about the exported DGL data.
///
/// Returned by [`DGLExporter::export`] and serialized as pretty-printed JSON to
/// `metadata.json`, which the generated `load_graph.py` reads back.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DGLMetadata {
    /// Graph name (copied from the source graph).
    pub name: String,
    /// Number of nodes.
    pub num_nodes: usize,
    /// Number of edges.
    pub num_edges: usize,
    /// Node feature dimension as recorded in the source graph's metadata.
    ///
    /// NOTE(review): the width actually written to `node_features.npy` is stored
    /// separately as `statistics["node_feature_dim"]`; confirm the two agree.
    pub node_feature_dim: usize,
    /// Edge feature dimension as recorded in the source graph's metadata; the width
    /// actually written is stored as `statistics["edge_feature_dim"]`.
    pub edge_feature_dim: usize,
    /// Number of node classes; `DGLExporter::export` sets 2 (normal / anomalous).
    pub num_node_classes: usize,
    /// Number of edge classes; `DGLExporter::export` sets 2 (normal / anomalous).
    pub num_edge_classes: usize,
    /// Node type name -> number of nodes of that type.
    ///
    /// NOTE(review): this is a `HashMap`, so the serialized key order is unspecified,
    /// yet `load_heterogeneous_graph` in the generated loader maps type indices by
    /// key order.
    pub node_types: HashMap<String, usize>,
    /// Edge type name -> number of edges of that type (same key-order caveat as
    /// `node_types`).
    pub edge_types: HashMap<String, usize>,
    /// Whether the graph is directed; `DGLExporter::export` sets `true`.
    pub is_directed: bool,
    /// Whether type-index arrays were exported (mirrors `DGLExportConfig::heterogeneous`).
    pub is_heterogeneous: bool,
    /// Edge index layout; `DGLExporter::export` sets `"COO"` (`[num_edges, 2]`).
    pub edge_format: String,
    /// `.npy` array files written, relative to the output directory (does not list
    /// `metadata.json` or the generated Python scripts).
    pub files: Vec<String>,
    /// Scalar statistics: `density`, `anomalous_node_ratio`, `anomalous_edge_ratio`, plus
    /// `node_feature_dim` / `edge_feature_dim` when those features are exported.
    pub statistics: HashMap<String, f64>,
}
107
/// DGL graph exporter.
///
/// Holds a [`DGLExportConfig`] and, via [`DGLExporter::export`], writes a graph as
/// NumPy `.npy` arrays plus `metadata.json` and Python helper scripts.
pub struct DGLExporter {
    /// Export options applied on every `export` call.
    config: DGLExportConfig,
}
112
113impl DGLExporter {
114    /// Creates a new DGL exporter.
115    pub fn new(config: DGLExportConfig) -> Self {
116        Self { config }
117    }
118
119    /// Exports a graph to DGL format.
120    pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<DGLMetadata> {
121        fs::create_dir_all(output_dir)?;
122
123        let mut files = Vec::new();
124        let mut statistics = HashMap::new();
125
126        // Export edge index in COO format [num_edges, 2]
127        self.export_edge_index(graph, output_dir)?;
128        files.push("edge_index.npy".to_string());
129
130        // Export node features
131        if self.config.export_node_features {
132            let dim = self.export_node_features(graph, output_dir)?;
133            files.push("node_features.npy".to_string());
134            statistics.insert("node_feature_dim".to_string(), dim as f64);
135        }
136
137        // Export edge features
138        if self.config.export_edge_features {
139            let dim = self.export_edge_features(graph, output_dir)?;
140            files.push("edge_features.npy".to_string());
141            statistics.insert("edge_feature_dim".to_string(), dim as f64);
142        }
143
144        // Export node labels
145        if self.config.export_node_labels {
146            self.export_node_labels(graph, output_dir)?;
147            files.push("node_labels.npy".to_string());
148        }
149
150        // Export edge labels
151        if self.config.export_edge_labels {
152            self.export_edge_labels(graph, output_dir)?;
153            files.push("edge_labels.npy".to_string());
154        }
155
156        // Export masks
157        if self.config.export_masks {
158            self.export_masks(graph, output_dir)?;
159            files.push("train_mask.npy".to_string());
160            files.push("val_mask.npy".to_string());
161            files.push("test_mask.npy".to_string());
162        }
163
164        // Export node type indices (for heterogeneous support)
165        if self.config.heterogeneous {
166            self.export_node_types(graph, output_dir)?;
167            files.push("node_type_indices.npy".to_string());
168            self.export_edge_types(graph, output_dir)?;
169            files.push("edge_type_indices.npy".to_string());
170        }
171
172        // Compute node/edge type mappings with counts
173        let node_types: HashMap<String, usize> = graph
174            .nodes_by_type
175            .iter()
176            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
177            .collect();
178
179        let edge_types: HashMap<String, usize> = graph
180            .edges_by_type
181            .iter()
182            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
183            .collect();
184
185        // Compute statistics
186        statistics.insert("density".to_string(), graph.metadata.density);
187        statistics.insert(
188            "anomalous_node_ratio".to_string(),
189            graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
190        );
191        statistics.insert(
192            "anomalous_edge_ratio".to_string(),
193            graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
194        );
195
196        // Create metadata
197        let metadata = DGLMetadata {
198            name: graph.name.clone(),
199            num_nodes: graph.node_count(),
200            num_edges: graph.edge_count(),
201            node_feature_dim: graph.metadata.node_feature_dim,
202            edge_feature_dim: graph.metadata.edge_feature_dim,
203            num_node_classes: 2, // Normal/Anomaly
204            num_edge_classes: 2,
205            node_types,
206            edge_types,
207            is_directed: true,
208            is_heterogeneous: self.config.heterogeneous,
209            edge_format: "COO".to_string(),
210            files,
211            statistics,
212        };
213
214        // Write metadata
215        let metadata_path = output_dir.join("metadata.json");
216        let file = File::create(metadata_path)?;
217        serde_json::to_writer_pretty(file, &metadata)?;
218
219        // Write Python loader script
220        self.write_loader_script(output_dir)?;
221
222        // Write pickle helper script if configured
223        if self.config.include_pickle_script {
224            self.write_pickle_script(output_dir)?;
225        }
226
227        Ok(metadata)
228    }
229
230    /// Exports edge index as COO format [num_edges, 2] array.
231    fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
232        let (sources, targets) = graph.edge_index();
233
234        // Create node ID to index mapping
235        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
236        node_ids.sort();
237        let id_to_idx: HashMap<_, _> = node_ids
238            .iter()
239            .enumerate()
240            .map(|(i, &id)| (id, i))
241            .collect();
242
243        // Create COO format: [num_edges, 2] where each row is (src, dst)
244        let num_edges = sources.len();
245        let coo_data: Vec<Vec<i64>> = (0..num_edges)
246            .map(|i| {
247                let src = *id_to_idx.get(&sources[i]).unwrap_or(&0) as i64;
248                let dst = *id_to_idx.get(&targets[i]).unwrap_or(&0) as i64;
249                vec![src, dst]
250            })
251            .collect();
252
253        // Write as NPY format [num_edges, 2]
254        let path = output_dir.join("edge_index.npy");
255        self.write_npy_2d_i64(&path, &coo_data)?;
256
257        Ok(())
258    }
259
260    /// Exports node features.
261    fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
262        let features = graph.node_features();
263        let dim = features.first().map(|f| f.len()).unwrap_or(0);
264
265        let path = output_dir.join("node_features.npy");
266        self.write_npy_2d_f64(&path, &features)?;
267
268        Ok(dim)
269    }
270
271    /// Exports edge features.
272    fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
273        let features = graph.edge_features();
274        let dim = features.first().map(|f| f.len()).unwrap_or(0);
275
276        let path = output_dir.join("edge_features.npy");
277        self.write_npy_2d_f64(&path, &features)?;
278
279        Ok(dim)
280    }
281
282    /// Exports node labels (anomaly flags).
283    fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
284        let labels: Vec<i64> = graph
285            .node_anomaly_mask()
286            .iter()
287            .map(|&b| if b { 1 } else { 0 })
288            .collect();
289
290        let path = output_dir.join("node_labels.npy");
291        self.write_npy_1d_i64(&path, &labels)?;
292
293        Ok(())
294    }
295
296    /// Exports edge labels (anomaly flags).
297    fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
298        let labels: Vec<i64> = graph
299            .edge_anomaly_mask()
300            .iter()
301            .map(|&b| if b { 1 } else { 0 })
302            .collect();
303
304        let path = output_dir.join("edge_labels.npy");
305        self.write_npy_1d_i64(&path, &labels)?;
306
307        Ok(())
308    }
309
310    /// Exports train/val/test masks.
311    fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
312        let n = graph.node_count();
313        let mut rng = SimpleRng::new(self.config.seed);
314
315        let train_size = (n as f64 * self.config.train_ratio) as usize;
316        let val_size = (n as f64 * self.config.val_ratio) as usize;
317
318        // Create shuffled indices
319        let mut indices: Vec<usize> = (0..n).collect();
320        for i in (1..n).rev() {
321            let j = (rng.next() % (i as u64 + 1)) as usize;
322            indices.swap(i, j);
323        }
324
325        // Create masks
326        let mut train_mask = vec![false; n];
327        let mut val_mask = vec![false; n];
328        let mut test_mask = vec![false; n];
329
330        for (i, &idx) in indices.iter().enumerate() {
331            if i < train_size {
332                train_mask[idx] = true;
333            } else if i < train_size + val_size {
334                val_mask[idx] = true;
335            } else {
336                test_mask[idx] = true;
337            }
338        }
339
340        // Write masks
341        self.write_npy_1d_bool(&output_dir.join("train_mask.npy"), &train_mask)?;
342        self.write_npy_1d_bool(&output_dir.join("val_mask.npy"), &val_mask)?;
343        self.write_npy_1d_bool(&output_dir.join("test_mask.npy"), &test_mask)?;
344
345        Ok(())
346    }
347
348    /// Exports node type indices for heterogeneous graphs.
349    fn export_node_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
350        // Create type mapping
351        let type_to_idx: HashMap<_, _> = graph
352            .nodes_by_type
353            .keys()
354            .enumerate()
355            .map(|(i, t)| (t.clone(), i as i64))
356            .collect();
357
358        // Get node IDs in sorted order for consistent indexing
359        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
360        node_ids.sort();
361
362        // Map each node to its type index
363        let type_indices: Vec<i64> = node_ids
364            .iter()
365            .map(|id| {
366                let node = graph.nodes.get(id).unwrap();
367                *type_to_idx.get(&node.node_type).unwrap_or(&0)
368            })
369            .collect();
370
371        let path = output_dir.join("node_type_indices.npy");
372        self.write_npy_1d_i64(&path, &type_indices)?;
373
374        Ok(())
375    }
376
377    /// Exports edge type indices for heterogeneous graphs.
378    fn export_edge_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
379        // Create type mapping
380        let type_to_idx: HashMap<_, _> = graph
381            .edges_by_type
382            .keys()
383            .enumerate()
384            .map(|(i, t)| (t.clone(), i as i64))
385            .collect();
386
387        // Get edge IDs in sorted order for consistent indexing
388        let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
389        edge_ids.sort();
390
391        // Map each edge to its type index
392        let type_indices: Vec<i64> = edge_ids
393            .iter()
394            .map(|id| {
395                let edge = graph.edges.get(id).unwrap();
396                *type_to_idx.get(&edge.edge_type).unwrap_or(&0)
397            })
398            .collect();
399
400        let path = output_dir.join("edge_type_indices.npy");
401        self.write_npy_1d_i64(&path, &type_indices)?;
402
403        Ok(())
404    }
405
406    /// Writes a 1D array of i64 in NPY format.
407    fn write_npy_1d_i64(&self, path: &Path, data: &[i64]) -> std::io::Result<()> {
408        let file = File::create(path)?;
409        let mut writer = BufWriter::new(file);
410
411        // NPY header
412        let shape = format!("({},)", data.len());
413        self.write_npy_header(&mut writer, "<i8", &shape)?;
414
415        // Data
416        for &val in data {
417            writer.write_all(&val.to_le_bytes())?;
418        }
419
420        Ok(())
421    }
422
423    /// Writes a 1D array of bool in NPY format.
424    fn write_npy_1d_bool(&self, path: &Path, data: &[bool]) -> std::io::Result<()> {
425        let file = File::create(path)?;
426        let mut writer = BufWriter::new(file);
427
428        // NPY header
429        let shape = format!("({},)", data.len());
430        self.write_npy_header(&mut writer, "|b1", &shape)?;
431
432        // Data
433        for &val in data {
434            writer.write_all(&[if val { 1u8 } else { 0u8 }])?;
435        }
436
437        Ok(())
438    }
439
440    /// Writes a 2D array of i64 in NPY format.
441    fn write_npy_2d_i64(&self, path: &Path, data: &[Vec<i64>]) -> std::io::Result<()> {
442        let file = File::create(path)?;
443        let mut writer = BufWriter::new(file);
444
445        let rows = data.len();
446        let cols = data.first().map(|r| r.len()).unwrap_or(0);
447
448        // NPY header
449        let shape = format!("({}, {})", rows, cols);
450        self.write_npy_header(&mut writer, "<i8", &shape)?;
451
452        // Data (row-major)
453        for row in data {
454            for &val in row {
455                writer.write_all(&val.to_le_bytes())?;
456            }
457            // Pad short rows if needed
458            for _ in row.len()..cols {
459                writer.write_all(&0_i64.to_le_bytes())?;
460            }
461        }
462
463        Ok(())
464    }
465
466    /// Writes a 2D array of f64 in NPY format.
467    fn write_npy_2d_f64(&self, path: &Path, data: &[Vec<f64>]) -> std::io::Result<()> {
468        let file = File::create(path)?;
469        let mut writer = BufWriter::new(file);
470
471        let rows = data.len();
472        let cols = data.first().map(|r| r.len()).unwrap_or(0);
473
474        // NPY header
475        let shape = format!("({}, {})", rows, cols);
476        self.write_npy_header(&mut writer, "<f8", &shape)?;
477
478        // Data (row-major)
479        for row in data {
480            for &val in row {
481                writer.write_all(&val.to_le_bytes())?;
482            }
483            // Pad short rows with zeros
484            for _ in row.len()..cols {
485                writer.write_all(&0.0_f64.to_le_bytes())?;
486            }
487        }
488
489        Ok(())
490    }
491
492    /// Writes NPY header.
493    fn write_npy_header<W: Write>(
494        &self,
495        writer: &mut W,
496        dtype: &str,
497        shape: &str,
498    ) -> std::io::Result<()> {
499        // Magic number and version
500        writer.write_all(&[0x93])?; // \x93
501        writer.write_all(b"NUMPY")?;
502        writer.write_all(&[0x01, 0x00])?; // Version 1.0
503
504        // Header dict
505        let header = format!(
506            "{{'descr': '{}', 'fortran_order': False, 'shape': {} }}",
507            dtype, shape
508        );
509
510        // Pad header to multiple of 64 bytes (including magic, version, header_len)
511        let header_len = header.len();
512        let total_len = 10 + header_len + 1; // magic(6) + version(2) + header_len(2) + header + newline
513        let padding = (64 - (total_len % 64)) % 64;
514        let padded_len = header_len + 1 + padding;
515
516        writer.write_all(&(padded_len as u16).to_le_bytes())?;
517        writer.write_all(header.as_bytes())?;
518        for _ in 0..padding {
519            writer.write_all(b" ")?;
520        }
521        writer.write_all(b"\n")?;
522
523        Ok(())
524    }
525
526    /// Writes a Python loader script for DGL.
527    fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
528        let script = r#"#!/usr/bin/env python3
529"""
530DGL (Deep Graph Library) Data Loader
531
532Auto-generated loader for graph data exported from synth-graph.
533Supports both homogeneous and heterogeneous graph loading.
534"""
535
536import json
537import numpy as np
538from pathlib import Path
539
540try:
541    import torch
542    HAS_TORCH = True
543except ImportError:
544    HAS_TORCH = False
545    print("Warning: torch not installed. Install with: pip install torch")
546
547try:
548    import dgl
549    HAS_DGL = True
550except ImportError:
551    HAS_DGL = False
552    print("Warning: dgl not installed. Install with: pip install dgl")
553
554
555def load_graph(data_dir: str = ".") -> "dgl.DGLGraph":
556    """Load graph data into a DGL graph object.
557
558    Args:
559        data_dir: Directory containing the exported graph data.
560
561    Returns:
562        DGL graph with node features, edge features, and labels attached.
563    """
564    data_dir = Path(data_dir)
565
566    # Load metadata
567    with open(data_dir / "metadata.json") as f:
568        metadata = json.load(f)
569
570    # Load edge index (COO format: [num_edges, 2])
571    edge_index = np.load(data_dir / "edge_index.npy")
572    src = edge_index[:, 0]
573    dst = edge_index[:, 1]
574
575    num_nodes = metadata["num_nodes"]
576
577    if not HAS_DGL:
578        # Return dict if DGL not available
579        result = {
580            "src": src,
581            "dst": dst,
582            "num_nodes": num_nodes,
583            "metadata": metadata,
584        }
585
586        # Load optional arrays
587        if (data_dir / "node_features.npy").exists():
588            result["node_features"] = np.load(data_dir / "node_features.npy")
589        if (data_dir / "edge_features.npy").exists():
590            result["edge_features"] = np.load(data_dir / "edge_features.npy")
591        if (data_dir / "node_labels.npy").exists():
592            result["node_labels"] = np.load(data_dir / "node_labels.npy")
593        if (data_dir / "edge_labels.npy").exists():
594            result["edge_labels"] = np.load(data_dir / "edge_labels.npy")
595        if (data_dir / "train_mask.npy").exists():
596            result["train_mask"] = np.load(data_dir / "train_mask.npy")
597            result["val_mask"] = np.load(data_dir / "val_mask.npy")
598            result["test_mask"] = np.load(data_dir / "test_mask.npy")
599
600        return result
601
602    # Create DGL graph
603    g = dgl.graph((src, dst), num_nodes=num_nodes)
604
605    # Load and attach node features
606    node_features_path = data_dir / "node_features.npy"
607    if node_features_path.exists():
608        node_features = np.load(node_features_path)
609        if HAS_TORCH:
610            g.ndata['feat'] = torch.from_numpy(node_features).float()
611        else:
612            g.ndata['feat'] = node_features
613
614    # Load and attach edge features
615    edge_features_path = data_dir / "edge_features.npy"
616    if edge_features_path.exists():
617        edge_features = np.load(edge_features_path)
618        if HAS_TORCH:
619            g.edata['feat'] = torch.from_numpy(edge_features).float()
620        else:
621            g.edata['feat'] = edge_features
622
623    # Load and attach node labels
624    node_labels_path = data_dir / "node_labels.npy"
625    if node_labels_path.exists():
626        node_labels = np.load(node_labels_path)
627        if HAS_TORCH:
628            g.ndata['label'] = torch.from_numpy(node_labels).long()
629        else:
630            g.ndata['label'] = node_labels
631
632    # Load and attach edge labels
633    edge_labels_path = data_dir / "edge_labels.npy"
634    if edge_labels_path.exists():
635        edge_labels = np.load(edge_labels_path)
636        if HAS_TORCH:
637            g.edata['label'] = torch.from_numpy(edge_labels).long()
638        else:
639            g.edata['label'] = edge_labels
640
641    # Load and attach masks
642    if (data_dir / "train_mask.npy").exists():
643        train_mask = np.load(data_dir / "train_mask.npy")
644        val_mask = np.load(data_dir / "val_mask.npy")
645        test_mask = np.load(data_dir / "test_mask.npy")
646
647        if HAS_TORCH:
648            g.ndata['train_mask'] = torch.from_numpy(train_mask).bool()
649            g.ndata['val_mask'] = torch.from_numpy(val_mask).bool()
650            g.ndata['test_mask'] = torch.from_numpy(test_mask).bool()
651        else:
652            g.ndata['train_mask'] = train_mask
653            g.ndata['val_mask'] = val_mask
654            g.ndata['test_mask'] = test_mask
655
656    # Store metadata as graph attribute
657    g.metadata = metadata
658
659    return g
660
661
662def load_heterogeneous_graph(data_dir: str = ".") -> "dgl.DGLHeteroGraph":
663    """Load graph data into a DGL heterogeneous graph.
664
665    This function handles graphs with multiple node and edge types.
666
667    Args:
668        data_dir: Directory containing the exported graph data.
669
670    Returns:
671        DGL heterogeneous graph.
672    """
673    data_dir = Path(data_dir)
674
675    # Load metadata
676    with open(data_dir / "metadata.json") as f:
677        metadata = json.load(f)
678
679    if not metadata.get("is_heterogeneous", False):
680        print("Warning: Graph was not exported as heterogeneous. Using homogeneous loader.")
681        return load_graph(data_dir)
682
683    if not HAS_DGL:
684        raise ImportError("DGL is required for heterogeneous graph loading")
685
686    # Load edge index and type indices
687    edge_index = np.load(data_dir / "edge_index.npy")
688    edge_types = np.load(data_dir / "edge_type_indices.npy")
689    node_types = np.load(data_dir / "node_type_indices.npy")
690
691    # Get type names from metadata
692    node_type_names = list(metadata["node_types"].keys())
693    edge_type_names = list(metadata["edge_types"].keys())
694
695    # Build edge dict for heterogeneous graph
696    edge_dict = {}
697    for etype_idx, etype_name in enumerate(edge_type_names):
698        mask = edge_types == etype_idx
699        if mask.any():
700            src = edge_index[mask, 0]
701            dst = edge_index[mask, 1]
702            # For heterogeneous, we need to specify (src_type, edge_type, dst_type)
703            # Using simplified convention: (node_type, edge_type, node_type)
704            edge_dict[(node_type_names[0] if node_type_names else 'node',
705                      etype_name,
706                      node_type_names[0] if node_type_names else 'node')] = (src, dst)
707
708    # Create heterogeneous graph
709    g = dgl.heterograph(edge_dict) if edge_dict else dgl.graph(([], []))
710    g.metadata = metadata
711
712    return g
713
714
715def print_summary(data_dir: str = "."):
716    """Print summary of the graph data."""
717    data_dir = Path(data_dir)
718
719    with open(data_dir / "metadata.json") as f:
720        metadata = json.load(f)
721
722    print(f"Graph: {metadata['name']}")
723    print(f"Format: DGL ({metadata['edge_format']} edge format)")
724    print(f"Nodes: {metadata['num_nodes']}")
725    print(f"Edges: {metadata['num_edges']}")
726    print(f"Node feature dim: {metadata['node_feature_dim']}")
727    print(f"Edge feature dim: {metadata['edge_feature_dim']}")
728    print(f"Directed: {metadata['is_directed']}")
729    print(f"Heterogeneous: {metadata['is_heterogeneous']}")
730
731    if metadata['node_types']:
732        print(f"Node types: {metadata['node_types']}")
733    if metadata['edge_types']:
734        print(f"Edge types: {metadata['edge_types']}")
735
736    if metadata['statistics']:
737        print("\nStatistics:")
738        for key, value in metadata['statistics'].items():
739            print(f"  {key}: {value:.4f}")
740
741    if HAS_DGL:
742        print("\nLoading graph...")
743        g = load_graph(data_dir)
744        if hasattr(g, 'num_nodes'):
745            print(f"DGL graph loaded: {g.num_nodes()} nodes, {g.num_edges()} edges")
746            if 'label' in g.ndata:
747                print(f"Anomalous nodes: {g.ndata['label'].sum().item()}")
748
749
750if __name__ == "__main__":
751    import sys
752    data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
753    print_summary(data_dir)
754"#;
755
756        let path = output_dir.join("load_graph.py");
757        let mut file = File::create(path)?;
758        file.write_all(script.as_bytes())?;
759
760        Ok(())
761    }
762
763    /// Writes a helper script for saving/loading DGL graphs as pickle.
764    fn write_pickle_script(&self, output_dir: &Path) -> std::io::Result<()> {
765        let script = r#"#!/usr/bin/env python3
766"""
767DGL Graph Pickle Helper
768
769Utility to save and load DGL graphs as pickle files for faster subsequent loading.
770"""
771
772import pickle
773from pathlib import Path
774
775try:
776    import dgl
777    HAS_DGL = True
778except ImportError:
779    HAS_DGL = False
780
781
782def save_dgl_graph(graph, output_path: str):
783    """Save a DGL graph to a pickle file.
784
785    Args:
786        graph: DGL graph to save.
787        output_path: Path to save the pickle file.
788    """
789    output_path = Path(output_path)
790
791    # Save graph data
792    graph_data = {
793        'num_nodes': graph.num_nodes(),
794        'edges': graph.edges(),
795        'ndata': {k: v.numpy() if hasattr(v, 'numpy') else v
796                  for k, v in graph.ndata.items()},
797        'edata': {k: v.numpy() if hasattr(v, 'numpy') else v
798                  for k, v in graph.edata.items()},
799        'metadata': getattr(graph, 'metadata', {}),
800    }
801
802    with open(output_path, 'wb') as f:
803        pickle.dump(graph_data, f, protocol=pickle.HIGHEST_PROTOCOL)
804
805    print(f"Saved graph to {output_path}")
806
807
808def load_dgl_graph(input_path: str) -> "dgl.DGLGraph":
809    """Load a DGL graph from a pickle file.
810
811    Args:
812        input_path: Path to the pickle file.
813
814    Returns:
815        DGL graph.
816    """
817    if not HAS_DGL:
818        raise ImportError("DGL is required to load graphs")
819
820    import torch
821
822    input_path = Path(input_path)
823
824    with open(input_path, 'rb') as f:
825        graph_data = pickle.load(f)
826
827    # Recreate graph
828    src, dst = graph_data['edges']
829    g = dgl.graph((src, dst), num_nodes=graph_data['num_nodes'])
830
831    # Restore node data
832    for k, v in graph_data['ndata'].items():
833        g.ndata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v
834
835    # Restore edge data
836    for k, v in graph_data['edata'].items():
837        g.edata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v
838
839    # Restore metadata
840    g.metadata = graph_data.get('metadata', {})
841
842    return g
843
844
845def convert_to_pickle(data_dir: str, output_path: str = None):
846    """Convert exported graph data to pickle format for faster loading.
847
848    Args:
849        data_dir: Directory containing the exported graph data.
850        output_path: Path for output pickle file. Defaults to data_dir/graph.pkl.
851    """
852    from load_graph import load_graph
853
854    data_dir = Path(data_dir)
855    output_path = Path(output_path) if output_path else data_dir / "graph.pkl"
856
857    print(f"Loading graph from {data_dir}...")
858    g = load_graph(str(data_dir))
859
860    if isinstance(g, dict):
861        print("Error: DGL not available, cannot convert to pickle")
862        return
863
864    save_dgl_graph(g, str(output_path))
865    print(f"Graph saved to {output_path}")
866
867
868if __name__ == "__main__":
869    import sys
870
871    if len(sys.argv) < 2:
872        print("Usage:")
873        print("  python pickle_helper.py convert <data_dir> [output_path]")
874        print("  python pickle_helper.py load <pickle_path>")
875        sys.exit(1)
876
877    command = sys.argv[1]
878
879    if command == "convert":
880        data_dir = sys.argv[2] if len(sys.argv) > 2 else "."
881        output_path = sys.argv[3] if len(sys.argv) > 3 else None
882        convert_to_pickle(data_dir, output_path)
883    elif command == "load":
884        pickle_path = sys.argv[2]
885        g = load_dgl_graph(pickle_path)
886        print(f"Loaded graph: {g.num_nodes()} nodes, {g.num_edges()} edges")
887    else:
888        print(f"Unknown command: {command}")
889"#;
890
891        let path = output_dir.join("pickle_helper.py");
892        let mut file = File::create(path)?;
893        file.write_all(script.as_bytes())?;
894
895        Ok(())
896    }
897}
898
/// Minimal xorshift64 pseudo-random generator (not cryptographically secure).
///
/// The exporter uses it to shuffle nodes for the seeded train/val/test split. The
/// state must never be zero (xorshift would stay at zero forever), so a zero seed is
/// replaced by 1.
struct SimpleRng {
    state: u64,
}

impl SimpleRng {
    /// Creates a generator from `seed`, mapping a seed of 0 to 1.
    fn new(seed: u64) -> Self {
        SimpleRng { state: seed.max(1) }
    }

    /// Advances the state with the (13, 7, 17) xorshift triple and returns the new state.
    fn next(&mut self) -> u64 {
        let s = self.state;
        let s = s ^ (s << 13);
        let s = s ^ (s >> 7);
        let s = s ^ (s << 17);
        self.state = s;
        s
    }
}
920
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{EdgeType, GraphEdge, GraphNode, GraphType, NodeType};
    use tempfile::tempdir;

    /// Builds the small fixture graph that every test exports.
    ///
    /// Nodes:
    /// - account "1000" (Cash);
    /// - account "2000" (AP);
    /// - company "ACME", the only node marked anomalous ("fraud").
    ///
    /// Edges:
    /// - Cash -> AP (`Transaction`, weight 1000.0);
    /// - AP -> ACME (`Ownership`, weight 100.0).
    ///
    /// Statistics are computed before the graph is returned.
    fn create_test_graph() -> Graph {
        let mut g = Graph::new("test_dgl", GraphType::Transaction);

        let cash = g.add_node(
            GraphNode::new(0, NodeType::Account, "1000".to_string(), "Cash".to_string())
                .with_feature(0.5),
        );
        let payables = g.add_node(
            GraphNode::new(0, NodeType::Account, "2000".to_string(), "AP".to_string())
                .with_feature(0.8),
        );
        let acme = g.add_node(
            GraphNode::new(
                0,
                NodeType::Company,
                "ACME".to_string(),
                "ACME Corp".to_string(),
            )
            .with_feature(0.3)
            .as_anomaly("fraud"),
        );

        g.add_edge(
            GraphEdge::new(0, cash, payables, EdgeType::Transaction)
                .with_weight(1000.0)
                .with_feature(6.9),
        );
        g.add_edge(
            GraphEdge::new(0, payables, acme, EdgeType::Ownership)
                .with_weight(100.0)
                .with_feature(4.6),
        );

        g.compute_statistics();
        g
    }

    /// Panics unless every file named in `names` exists directly under `dir`.
    fn assert_files_exist(dir: &Path, names: &[&str]) {
        for name in names {
            assert!(dir.join(name).exists(), "expected {} to be exported", name);
        }
    }

    #[test]
    fn test_dgl_export_basic() {
        let graph = create_test_graph();
        let out = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let metadata = exporter.export(&graph, out.path()).unwrap();

        assert_eq!(metadata.num_nodes, 3);
        assert_eq!(metadata.num_edges, 2);
        assert_eq!(metadata.edge_format, "COO");
        assert_files_exist(
            out.path(),
            &[
                "edge_index.npy",
                "node_features.npy",
                "node_labels.npy",
                "metadata.json",
                "load_graph.py",
                "pickle_helper.py",
            ],
        );
    }

    #[test]
    fn test_dgl_export_heterogeneous() {
        let graph = create_test_graph();
        let out = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig {
            heterogeneous: true,
            ..Default::default()
        });
        let metadata = exporter.export(&graph, out.path()).unwrap();

        assert!(metadata.is_heterogeneous);
        assert_files_exist(
            out.path(),
            &["node_type_indices.npy", "edge_type_indices.npy"],
        );
    }

    #[test]
    fn test_dgl_export_masks() {
        let graph = create_test_graph();
        let out = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let metadata = exporter.export(&graph, out.path()).unwrap();

        // Each mask must be both listed in the metadata and present on disk.
        for mask in ["train_mask.npy", "val_mask.npy", "test_mask.npy"] {
            assert!(metadata.files.contains(&mask.to_string()));
            assert!(out.path().join(mask).exists());
        }
    }

    #[test]
    fn test_dgl_coo_format() {
        let graph = create_test_graph();
        let out = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        exporter.export(&graph, out.path()).unwrap();

        // The edge array's shape is not inspected here. The COO
        // ([num_edges, 2]) layout is only confirmed through the
        // `edge_format` tag persisted in metadata.json.
        assert!(out.path().join("edge_index.npy").exists());

        let reader = File::open(out.path().join("metadata.json")).unwrap();
        let metadata: DGLMetadata = serde_json::from_reader(reader).unwrap();
        assert_eq!(metadata.edge_format, "COO");
    }

    #[test]
    fn test_dgl_export_no_masks() {
        let graph = create_test_graph();
        let out = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig {
            export_masks: false,
            ..Default::default()
        });
        let metadata = exporter.export(&graph, out.path()).unwrap();

        let mask = "train_mask.npy";
        assert!(!metadata.files.contains(&mask.to_string()));
        assert!(!out.path().join(mask).exists());
    }

    #[test]
    fn test_dgl_export_minimal() {
        let graph = create_test_graph();
        let out = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig {
            export_node_features: false,
            export_edge_features: false,
            export_node_labels: false,
            export_edge_labels: false,
            export_masks: false,
            include_pickle_script: false,
            ..Default::default()
        });
        let metadata = exporter.export(&graph, out.path()).unwrap();

        // With every optional output disabled, `metadata.files` is expected
        // to hold a single entry (edge_index.npy). The loader script and
        // metadata.json are still written, and the pickle helper is not.
        assert_eq!(metadata.files.len(), 1);
        assert_files_exist(
            out.path(),
            &["edge_index.npy", "load_graph.py", "metadata.json"],
        );
        assert!(!out.path().join("pickle_helper.py").exists());
    }

    #[test]
    fn test_dgl_statistics() {
        let graph = create_test_graph();
        let out = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let metadata = exporter.export(&graph, out.path()).unwrap();

        let stats = &metadata.statistics;
        for key in ["density", "anomalous_node_ratio", "anomalous_edge_ratio"] {
            assert!(stats.contains_key(key), "missing statistic {}", key);
        }

        // Exactly one of the three fixture nodes (ACME) is anomalous.
        let node_ratio = stats.get("anomalous_node_ratio").unwrap();
        assert!((*node_ratio - 1.0 / 3.0).abs() < 0.01);
    }
}