Skip to main content

datasynth_graph/exporters/
dgl.rs

1//! Deep Graph Library (DGL) exporter.
2//!
3//! Exports graph data in formats compatible with DGL:
4//! - NumPy arrays (.npy) for node/edge features and labels
5//! - COO format edge index [num_edges, 2] (differs from PyG's [2, num_edges])
6//! - JSON metadata for graph information
7//!
8//! The exported data can be loaded in Python with:
9//! ```python
10//! import numpy as np
11//! import torch
12//! import dgl
13//!
14//! node_features = torch.from_numpy(np.load('node_features.npy'))
15//! edge_index = np.load('edge_index.npy')  # [num_edges, 2] COO format
16//! src, dst = edge_index[:, 0], edge_index[:, 1]
17//!
18//! g = dgl.graph((src, dst))
19//! g.ndata['feat'] = node_features
20//! ```
21//!
//! For heterogeneous exports, per-node and per-edge type indices are additionally written to
//! `node_type_indices.npy` and `edge_type_indices.npy`.
23
24use std::collections::HashMap;
25use std::fs::{self, File};
26use std::io::Write;
27use std::path::Path;
28
29use serde::{Deserialize, Serialize};
30
31use crate::exporters::common::{CommonExportConfig, CommonGraphMetadata};
32use crate::exporters::npy_writer;
33use crate::models::Graph;
34
35/// Configuration for DGL export.
36#[derive(Debug, Clone)]
37pub struct DGLExportConfig {
38    /// Common export settings (features, labels, masks, splits, seed).
39    pub common: CommonExportConfig,
40    /// Export as heterogeneous graph (separate files per type).
41    pub heterogeneous: bool,
42    /// Include Python pickle helper script.
43    pub include_pickle_script: bool,
44}
45
46impl Default for DGLExportConfig {
47    fn default() -> Self {
48        Self {
49            common: CommonExportConfig::default(),
50            heterogeneous: false,
51            include_pickle_script: true,
52        }
53    }
54}
55
56/// Metadata about the exported DGL data.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct DGLMetadata {
59    /// Common graph metadata fields.
60    #[serde(flatten)]
61    pub common: CommonGraphMetadata,
62    /// Whether export is heterogeneous.
63    pub is_heterogeneous: bool,
64    /// Edge index format (COO).
65    pub edge_format: String,
66}
67
68/// DGL graph exporter.
69pub struct DGLExporter {
70    config: DGLExportConfig,
71}
72
73impl DGLExporter {
74    /// Creates a new DGL exporter.
75    pub fn new(config: DGLExportConfig) -> Self {
76        Self { config }
77    }
78
79    /// Exports a graph to DGL format.
80    pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<DGLMetadata> {
81        fs::create_dir_all(output_dir)?;
82
83        let mut files = Vec::new();
84        let mut statistics = HashMap::new();
85
86        // Export edge index in COO format [num_edges, 2]
87        self.export_edge_index(graph, output_dir)?;
88        files.push("edge_index.npy".to_string());
89
90        // Export node features
91        if self.config.common.export_node_features {
92            let dim = self.export_node_features(graph, output_dir)?;
93            files.push("node_features.npy".to_string());
94            statistics.insert("node_feature_dim".to_string(), dim as f64);
95        }
96
97        // Export edge features
98        if self.config.common.export_edge_features {
99            let dim = self.export_edge_features(graph, output_dir)?;
100            files.push("edge_features.npy".to_string());
101            statistics.insert("edge_feature_dim".to_string(), dim as f64);
102        }
103
104        // Export node labels
105        if self.config.common.export_node_labels {
106            self.export_node_labels(graph, output_dir)?;
107            files.push("node_labels.npy".to_string());
108        }
109
110        // Export edge labels
111        if self.config.common.export_edge_labels {
112            self.export_edge_labels(graph, output_dir)?;
113            files.push("edge_labels.npy".to_string());
114        }
115
116        // Export masks
117        if self.config.common.export_masks {
118            self.export_masks(graph, output_dir)?;
119            files.push("train_mask.npy".to_string());
120            files.push("val_mask.npy".to_string());
121            files.push("test_mask.npy".to_string());
122        }
123
124        // Export node type indices (for heterogeneous support)
125        if self.config.heterogeneous {
126            self.export_node_types(graph, output_dir)?;
127            files.push("node_type_indices.npy".to_string());
128            self.export_edge_types(graph, output_dir)?;
129            files.push("edge_type_indices.npy".to_string());
130        }
131
132        // Compute node/edge type mappings with counts
133        let node_types: HashMap<String, usize> = graph
134            .nodes_by_type
135            .iter()
136            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
137            .collect();
138
139        let edge_types: HashMap<String, usize> = graph
140            .edges_by_type
141            .iter()
142            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
143            .collect();
144
145        // Compute statistics
146        statistics.insert("density".to_string(), graph.metadata.density);
147        statistics.insert(
148            "anomalous_node_ratio".to_string(),
149            graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
150        );
151        statistics.insert(
152            "anomalous_edge_ratio".to_string(),
153            graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
154        );
155
156        // Create metadata
157        let metadata = DGLMetadata {
158            common: CommonGraphMetadata {
159                name: graph.name.clone(),
160                num_nodes: graph.node_count(),
161                num_edges: graph.edge_count(),
162                node_feature_dim: graph.metadata.node_feature_dim,
163                edge_feature_dim: graph.metadata.edge_feature_dim,
164                num_node_classes: 2, // Normal/Anomaly
165                num_edge_classes: 2,
166                node_types,
167                edge_types,
168                is_directed: true,
169                files,
170                statistics,
171            },
172            is_heterogeneous: self.config.heterogeneous,
173            edge_format: "COO".to_string(),
174        };
175
176        // Write metadata
177        let metadata_path = output_dir.join("metadata.json");
178        let file = File::create(metadata_path)?;
179        serde_json::to_writer_pretty(file, &metadata)?;
180
181        // Write Python loader script
182        self.write_loader_script(output_dir)?;
183
184        // Write pickle helper script if configured
185        if self.config.include_pickle_script {
186            self.write_pickle_script(output_dir)?;
187        }
188
189        Ok(metadata)
190    }
191
192    /// Exports edge index as COO format [num_edges, 2] array.
193    fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
194        let (sources, targets) = graph.edge_index();
195
196        // Create node ID to index mapping
197        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
198        node_ids.sort();
199        let id_to_idx: HashMap<_, _> = node_ids
200            .iter()
201            .enumerate()
202            .map(|(i, &id)| (id, i))
203            .collect();
204
205        // Create COO format: [num_edges, 2] where each row is (src, dst)
206        let num_edges = sources.len();
207        let mut coo_data: Vec<Vec<i64>> = Vec::with_capacity(num_edges);
208        let mut skipped_edges = 0usize;
209
210        for i in 0..num_edges {
211            match (id_to_idx.get(&sources[i]), id_to_idx.get(&targets[i])) {
212                (Some(&s), Some(&d)) => {
213                    coo_data.push(vec![s as i64, d as i64]);
214                }
215                _ => {
216                    skipped_edges += 1;
217                }
218            }
219        }
220        if skipped_edges > 0 {
221            tracing::warn!(
222                "DGL export: skipped {} edges with missing node IDs",
223                skipped_edges
224            );
225        }
226
227        // Write as NPY format [num_edges, 2]
228        let path = output_dir.join("edge_index.npy");
229        npy_writer::write_npy_2d_i64(&path, &coo_data)?;
230
231        Ok(())
232    }
233
234    /// Exports node features.
235    fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
236        let features = graph.node_features();
237        let dim = features.first().map(|f| f.len()).unwrap_or(0);
238
239        let path = output_dir.join("node_features.npy");
240        npy_writer::write_npy_2d_f64(&path, &features)?;
241
242        Ok(dim)
243    }
244
245    /// Exports edge features.
246    fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
247        let features = graph.edge_features();
248        let dim = features.first().map(|f| f.len()).unwrap_or(0);
249
250        let path = output_dir.join("edge_features.npy");
251        npy_writer::write_npy_2d_f64(&path, &features)?;
252
253        Ok(dim)
254    }
255
256    /// Exports node labels (anomaly flags).
257    fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
258        let labels: Vec<i64> = graph
259            .node_anomaly_mask()
260            .iter()
261            .map(|&b| if b { 1 } else { 0 })
262            .collect();
263
264        let path = output_dir.join("node_labels.npy");
265        npy_writer::write_npy_1d_i64(&path, &labels)?;
266
267        Ok(())
268    }
269
270    /// Exports edge labels (anomaly flags).
271    fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
272        let labels: Vec<i64> = graph
273            .edge_anomaly_mask()
274            .iter()
275            .map(|&b| if b { 1 } else { 0 })
276            .collect();
277
278        let path = output_dir.join("edge_labels.npy");
279        npy_writer::write_npy_1d_i64(&path, &labels)?;
280
281        Ok(())
282    }
283
284    /// Exports train/val/test masks.
285    fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
286        npy_writer::export_masks(
287            output_dir,
288            graph.node_count(),
289            self.config.common.seed,
290            self.config.common.train_ratio,
291            self.config.common.val_ratio,
292        )
293    }
294
295    /// Exports node type indices for heterogeneous graphs.
296    fn export_node_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
297        // Create type mapping
298        let type_to_idx: HashMap<_, _> = graph
299            .nodes_by_type
300            .keys()
301            .enumerate()
302            .map(|(i, t)| (t.clone(), i as i64))
303            .collect();
304
305        // Get node IDs in sorted order for consistent indexing
306        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
307        node_ids.sort();
308
309        // Map each node to its type index
310        let type_indices: Vec<i64> = node_ids
311            .iter()
312            .map(|id| {
313                let node = graph.nodes.get(id).expect("node ID from keys()");
314                *type_to_idx.get(&node.node_type).unwrap_or_else(|| {
315                    tracing::warn!(
316                        "Unknown node type '{:?}', defaulting to index 0",
317                        node.node_type
318                    );
319                    &0
320                })
321            })
322            .collect();
323
324        let path = output_dir.join("node_type_indices.npy");
325        npy_writer::write_npy_1d_i64(&path, &type_indices)?;
326
327        Ok(())
328    }
329
330    /// Exports edge type indices for heterogeneous graphs.
331    fn export_edge_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
332        // Create type mapping
333        let type_to_idx: HashMap<_, _> = graph
334            .edges_by_type
335            .keys()
336            .enumerate()
337            .map(|(i, t)| (t.clone(), i as i64))
338            .collect();
339
340        // Get edge IDs in sorted order for consistent indexing
341        let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
342        edge_ids.sort();
343
344        // Map each edge to its type index
345        let type_indices: Vec<i64> = edge_ids
346            .iter()
347            .map(|id| {
348                let edge = graph.edges.get(id).expect("edge ID from keys()");
349                *type_to_idx.get(&edge.edge_type).unwrap_or_else(|| {
350                    tracing::warn!(
351                        "Unknown edge type '{:?}', defaulting to index 0",
352                        edge.edge_type
353                    );
354                    &0
355                })
356            })
357            .collect();
358
359        let path = output_dir.join("edge_type_indices.npy");
360        npy_writer::write_npy_1d_i64(&path, &type_indices)?;
361
362        Ok(())
363    }
364
365    /// Writes a Python loader script for DGL.
366    fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
367        let script = r#"#!/usr/bin/env python3
368"""
369DGL (Deep Graph Library) Data Loader
370
371Auto-generated loader for graph data exported from synth-graph.
372Supports both homogeneous and heterogeneous graph loading.
373"""
374
375import json
376import numpy as np
377from pathlib import Path
378
379try:
380    import torch
381    HAS_TORCH = True
382except ImportError:
383    HAS_TORCH = False
384    print("Warning: torch not installed. Install with: pip install torch")
385
386try:
387    import dgl
388    HAS_DGL = True
389except ImportError:
390    HAS_DGL = False
391    print("Warning: dgl not installed. Install with: pip install dgl")
392
393
394def load_graph(data_dir: str = ".") -> "dgl.DGLGraph":
395    """Load graph data into a DGL graph object.
396
397    Args:
398        data_dir: Directory containing the exported graph data.
399
400    Returns:
401        DGL graph with node features, edge features, and labels attached.
402    """
403    data_dir = Path(data_dir)
404
405    # Load metadata
406    with open(data_dir / "metadata.json") as f:
407        metadata = json.load(f)
408
409    # Load edge index (COO format: [num_edges, 2])
410    edge_index = np.load(data_dir / "edge_index.npy")
411    src = edge_index[:, 0]
412    dst = edge_index[:, 1]
413
414    num_nodes = metadata["num_nodes"]
415
416    if not HAS_DGL:
417        # Return dict if DGL not available
418        result = {
419            "src": src,
420            "dst": dst,
421            "num_nodes": num_nodes,
422            "metadata": metadata,
423        }
424
425        # Load optional arrays
426        if (data_dir / "node_features.npy").exists():
427            result["node_features"] = np.load(data_dir / "node_features.npy")
428        if (data_dir / "edge_features.npy").exists():
429            result["edge_features"] = np.load(data_dir / "edge_features.npy")
430        if (data_dir / "node_labels.npy").exists():
431            result["node_labels"] = np.load(data_dir / "node_labels.npy")
432        if (data_dir / "edge_labels.npy").exists():
433            result["edge_labels"] = np.load(data_dir / "edge_labels.npy")
434        if (data_dir / "train_mask.npy").exists():
435            result["train_mask"] = np.load(data_dir / "train_mask.npy")
436            result["val_mask"] = np.load(data_dir / "val_mask.npy")
437            result["test_mask"] = np.load(data_dir / "test_mask.npy")
438
439        return result
440
441    # Create DGL graph
442    g = dgl.graph((src, dst), num_nodes=num_nodes)
443
444    # Load and attach node features
445    node_features_path = data_dir / "node_features.npy"
446    if node_features_path.exists():
447        node_features = np.load(node_features_path)
448        if HAS_TORCH:
449            g.ndata['feat'] = torch.from_numpy(node_features).float()
450        else:
451            g.ndata['feat'] = node_features
452
453    # Load and attach edge features
454    edge_features_path = data_dir / "edge_features.npy"
455    if edge_features_path.exists():
456        edge_features = np.load(edge_features_path)
457        if HAS_TORCH:
458            g.edata['feat'] = torch.from_numpy(edge_features).float()
459        else:
460            g.edata['feat'] = edge_features
461
462    # Load and attach node labels
463    node_labels_path = data_dir / "node_labels.npy"
464    if node_labels_path.exists():
465        node_labels = np.load(node_labels_path)
466        if HAS_TORCH:
467            g.ndata['label'] = torch.from_numpy(node_labels).long()
468        else:
469            g.ndata['label'] = node_labels
470
471    # Load and attach edge labels
472    edge_labels_path = data_dir / "edge_labels.npy"
473    if edge_labels_path.exists():
474        edge_labels = np.load(edge_labels_path)
475        if HAS_TORCH:
476            g.edata['label'] = torch.from_numpy(edge_labels).long()
477        else:
478            g.edata['label'] = edge_labels
479
480    # Load and attach masks
481    if (data_dir / "train_mask.npy").exists():
482        train_mask = np.load(data_dir / "train_mask.npy")
483        val_mask = np.load(data_dir / "val_mask.npy")
484        test_mask = np.load(data_dir / "test_mask.npy")
485
486        if HAS_TORCH:
487            g.ndata['train_mask'] = torch.from_numpy(train_mask).bool()
488            g.ndata['val_mask'] = torch.from_numpy(val_mask).bool()
489            g.ndata['test_mask'] = torch.from_numpy(test_mask).bool()
490        else:
491            g.ndata['train_mask'] = train_mask
492            g.ndata['val_mask'] = val_mask
493            g.ndata['test_mask'] = test_mask
494
495    # Store metadata as graph attribute
496    g.metadata = metadata
497
498    return g
499
500
501def load_heterogeneous_graph(data_dir: str = ".") -> "dgl.DGLHeteroGraph":
502    """Load graph data into a DGL heterogeneous graph.
503
504    This function handles graphs with multiple node and edge types.
505
506    Args:
507        data_dir: Directory containing the exported graph data.
508
509    Returns:
510        DGL heterogeneous graph.
511    """
512    data_dir = Path(data_dir)
513
514    # Load metadata
515    with open(data_dir / "metadata.json") as f:
516        metadata = json.load(f)
517
518    if not metadata.get("is_heterogeneous", False):
519        print("Warning: Graph was not exported as heterogeneous. Using homogeneous loader.")
520        return load_graph(data_dir)
521
522    if not HAS_DGL:
523        raise ImportError("DGL is required for heterogeneous graph loading")
524
525    # Load edge index and type indices
526    edge_index = np.load(data_dir / "edge_index.npy")
527    edge_types = np.load(data_dir / "edge_type_indices.npy")
528    node_types = np.load(data_dir / "node_type_indices.npy")
529
530    # Get type names from metadata
531    node_type_names = list(metadata["node_types"].keys())
532    edge_type_names = list(metadata["edge_types"].keys())
533
534    # Build edge dict for heterogeneous graph
535    edge_dict = {}
536    for etype_idx, etype_name in enumerate(edge_type_names):
537        mask = edge_types == etype_idx
538        if mask.any():
539            src = edge_index[mask, 0]
540            dst = edge_index[mask, 1]
541            # For heterogeneous, we need to specify (src_type, edge_type, dst_type)
542            # Using simplified convention: (node_type, edge_type, node_type)
543            edge_dict[(node_type_names[0] if node_type_names else 'node',
544                      etype_name,
545                      node_type_names[0] if node_type_names else 'node')] = (src, dst)
546
547    # Create heterogeneous graph
548    g = dgl.heterograph(edge_dict) if edge_dict else dgl.graph(([], []))
549    g.metadata = metadata
550
551    return g
552
553
554def print_summary(data_dir: str = "."):
555    """Print summary of the graph data."""
556    data_dir = Path(data_dir)
557
558    with open(data_dir / "metadata.json") as f:
559        metadata = json.load(f)
560
561    print(f"Graph: {metadata['name']}")
562    print(f"Format: DGL ({metadata['edge_format']} edge format)")
563    print(f"Nodes: {metadata['num_nodes']}")
564    print(f"Edges: {metadata['num_edges']}")
565    print(f"Node feature dim: {metadata['node_feature_dim']}")
566    print(f"Edge feature dim: {metadata['edge_feature_dim']}")
567    print(f"Directed: {metadata['is_directed']}")
568    print(f"Heterogeneous: {metadata['is_heterogeneous']}")
569
570    if metadata['node_types']:
571        print(f"Node types: {metadata['node_types']}")
572    if metadata['edge_types']:
573        print(f"Edge types: {metadata['edge_types']}")
574
575    if metadata['statistics']:
576        print("\nStatistics:")
577        for key, value in metadata['statistics'].items():
578            print(f"  {key}: {value:.4f}")
579
580    if HAS_DGL:
581        print("\nLoading graph...")
582        g = load_graph(data_dir)
583        if hasattr(g, 'num_nodes'):
584            print(f"DGL graph loaded: {g.num_nodes()} nodes, {g.num_edges()} edges")
585            if 'label' in g.ndata:
586                print(f"Anomalous nodes: {g.ndata['label'].sum().item()}")
587
588
589if __name__ == "__main__":
590    import sys
591    data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
592    print_summary(data_dir)
593"#;
594
595        let path = output_dir.join("load_graph.py");
596        let mut file = File::create(path)?;
597        file.write_all(script.as_bytes())?;
598
599        Ok(())
600    }
601
602    /// Writes a helper script for saving/loading DGL graphs as pickle.
603    fn write_pickle_script(&self, output_dir: &Path) -> std::io::Result<()> {
604        let script = r#"#!/usr/bin/env python3
605"""
606DGL Graph Pickle Helper
607
608Utility to save and load DGL graphs as pickle files for faster subsequent loading.
609"""
610
611import pickle
612from pathlib import Path
613
614try:
615    import dgl
616    HAS_DGL = True
617except ImportError:
618    HAS_DGL = False
619
620
621def save_dgl_graph(graph, output_path: str):
622    """Save a DGL graph to a pickle file.
623
624    Args:
625        graph: DGL graph to save.
626        output_path: Path to save the pickle file.
627    """
628    output_path = Path(output_path)
629
630    # Save graph data
631    graph_data = {
632        'num_nodes': graph.num_nodes(),
633        'edges': graph.edges(),
634        'ndata': {k: v.numpy() if hasattr(v, 'numpy') else v
635                  for k, v in graph.ndata.items()},
636        'edata': {k: v.numpy() if hasattr(v, 'numpy') else v
637                  for k, v in graph.edata.items()},
638        'metadata': getattr(graph, 'metadata', {}),
639    }
640
641    with open(output_path, 'wb') as f:
642        pickle.dump(graph_data, f, protocol=pickle.HIGHEST_PROTOCOL)
643
644    print(f"Saved graph to {output_path}")
645
646
647def load_dgl_graph(input_path: str) -> "dgl.DGLGraph":
648    """Load a DGL graph from a pickle file.
649
650    Args:
651        input_path: Path to the pickle file.
652
653    Returns:
654        DGL graph.
655    """
656    if not HAS_DGL:
657        raise ImportError("DGL is required to load graphs")
658
659    import torch
660
661    input_path = Path(input_path)
662
663    with open(input_path, 'rb') as f:
664        graph_data = pickle.load(f)
665
666    # Recreate graph
667    src, dst = graph_data['edges']
668    g = dgl.graph((src, dst), num_nodes=graph_data['num_nodes'])
669
670    # Restore node data
671    for k, v in graph_data['ndata'].items():
672        g.ndata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v
673
674    # Restore edge data
675    for k, v in graph_data['edata'].items():
676        g.edata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v
677
678    # Restore metadata
679    g.metadata = graph_data.get('metadata', {})
680
681    return g
682
683
684def convert_to_pickle(data_dir: str, output_path: str = None):
685    """Convert exported graph data to pickle format for faster loading.
686
687    Args:
688        data_dir: Directory containing the exported graph data.
689        output_path: Path for output pickle file. Defaults to data_dir/graph.pkl.
690    """
691    from load_graph import load_graph
692
693    data_dir = Path(data_dir)
694    output_path = Path(output_path) if output_path else data_dir / "graph.pkl"
695
696    print(f"Loading graph from {data_dir}...")
697    g = load_graph(str(data_dir))
698
699    if isinstance(g, dict):
700        print("Error: DGL not available, cannot convert to pickle")
701        return
702
703    save_dgl_graph(g, str(output_path))
704    print(f"Graph saved to {output_path}")
705
706
707if __name__ == "__main__":
708    import sys
709
710    if len(sys.argv) < 2:
711        print("Usage:")
712        print("  python pickle_helper.py convert <data_dir> [output_path]")
713        print("  python pickle_helper.py load <pickle_path>")
714        sys.exit(1)
715
716    command = sys.argv[1]
717
718    if command == "convert":
719        data_dir = sys.argv[2] if len(sys.argv) > 2 else "."
720        output_path = sys.argv[3] if len(sys.argv) > 3 else None
721        convert_to_pickle(data_dir, output_path)
722    elif command == "load":
723        pickle_path = sys.argv[2]
724        g = load_dgl_graph(pickle_path)
725        print(f"Loaded graph: {g.num_nodes()} nodes, {g.num_edges()} edges")
726    else:
727        print(f"Unknown command: {command}")
728"#;
729
730        let path = output_dir.join("pickle_helper.py");
731        let mut file = File::create(path)?;
732        file.write_all(script.as_bytes())?;
733
734        Ok(())
735    }
736}
737
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::test_helpers::create_test_graph_with_company;
    use tempfile::tempdir;

    #[test]
    fn test_dgl_export_basic() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert_eq!(metadata.common.num_nodes, 3);
        assert_eq!(metadata.common.num_edges, 2);
        assert_eq!(metadata.edge_format, "COO");
        assert!(dir.path().join("edge_index.npy").exists());
        assert!(dir.path().join("node_features.npy").exists());
        assert!(dir.path().join("node_labels.npy").exists());
        assert!(dir.path().join("metadata.json").exists());
        assert!(dir.path().join("load_graph.py").exists());
        assert!(dir.path().join("pickle_helper.py").exists());
    }

    #[test]
    fn test_dgl_export_heterogeneous() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let config = DGLExportConfig {
            heterogeneous: true,
            ..Default::default()
        };
        let exporter = DGLExporter::new(config);
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert!(metadata.is_heterogeneous);
        assert!(dir.path().join("node_type_indices.npy").exists());
        assert!(dir.path().join("edge_type_indices.npy").exists());
    }

    #[test]
    fn test_dgl_export_masks() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert!(metadata
            .common
            .files
            .contains(&"train_mask.npy".to_string()));
        assert!(metadata.common.files.contains(&"val_mask.npy".to_string()));
        assert!(metadata.common.files.contains(&"test_mask.npy".to_string()));
        assert!(dir.path().join("train_mask.npy").exists());
        assert!(dir.path().join("val_mask.npy").exists());
        assert!(dir.path().join("test_mask.npy").exists());
    }

    #[test]
    fn test_dgl_coo_format() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let exported = exporter.export(&graph, dir.path()).unwrap();

        // COO format must be [num_edges, 2] (PyG uses [2, num_edges]).
        let edge_path = dir.path().join("edge_index.npy");
        let bytes = std::fs::read(&edge_path).unwrap();
        assert!(bytes.starts_with(b"\x93NUMPY"), "edge_index.npy is not an NPY file");

        // Normalize the header dict so the check does not depend on the
        // writer's quoting or spacing, e.g. `'shape': (2, 2)` -> `shape:(2,2)`.
        let normalized: String = String::from_utf8_lossy(&bytes)
            .chars()
            .filter(|c| !c.is_whitespace() && *c != '\'' && *c != '"')
            .collect();
        let expected_shape = format!("shape:({},2)", exported.common.num_edges);
        assert!(
            normalized.contains(&expected_shape),
            "edge_index.npy header does not declare {}",
            expected_shape
        );

        // The metadata confirms format
        let metadata_path = dir.path().join("metadata.json");
        let metadata: DGLMetadata =
            serde_json::from_reader(File::open(metadata_path).unwrap()).unwrap();
        assert_eq!(metadata.edge_format, "COO");
    }

    #[test]
    fn test_dgl_export_no_masks() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let config = DGLExportConfig {
            common: CommonExportConfig {
                export_masks: false,
                ..Default::default()
            },
            ..Default::default()
        };
        let exporter = DGLExporter::new(config);
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert!(!metadata
            .common
            .files
            .contains(&"train_mask.npy".to_string()));
        assert!(!dir.path().join("train_mask.npy").exists());
    }

    #[test]
    fn test_dgl_export_minimal() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let config = DGLExportConfig {
            common: CommonExportConfig {
                export_node_features: false,
                export_edge_features: false,
                export_node_labels: false,
                export_edge_labels: false,
                export_masks: false,
                ..Default::default()
            },
            include_pickle_script: false,
            ..Default::default()
        };
        let exporter = DGLExporter::new(config);
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        // Only edge_index and loader script should exist
        assert_eq!(metadata.common.files.len(), 1); // Only edge_index.npy
        assert!(dir.path().join("edge_index.npy").exists());
        assert!(dir.path().join("load_graph.py").exists()); // Loader always generated
        assert!(dir.path().join("metadata.json").exists());
        assert!(!dir.path().join("pickle_helper.py").exists());
    }

    #[test]
    fn test_dgl_statistics() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        // Should have density and anomaly ratios
        assert!(metadata.common.statistics.contains_key("density"));
        assert!(metadata
            .common
            .statistics
            .contains_key("anomalous_node_ratio"));
        assert!(metadata
            .common
            .statistics
            .contains_key("anomalous_edge_ratio"));

        // One of three nodes is anomalous
        let node_ratio = metadata
            .common
            .statistics
            .get("anomalous_node_ratio")
            .unwrap();
        assert!((*node_ratio - 1.0 / 3.0).abs() < 0.01);
    }
}