Skip to main content

datasynth_graph/exporters/
dgl.rs

1//! Deep Graph Library (DGL) exporter.
2//!
3//! Exports graph data in formats compatible with DGL:
4//! - NumPy arrays (.npy) for node/edge features and labels
5//! - COO format edge index [num_edges, 2] (differs from PyG's [2, num_edges])
6//! - JSON metadata for graph information
7//!
8//! The exported data can be loaded in Python with:
9//! ```python
10//! import numpy as np
11//! import torch
12//! import dgl
13//!
14//! node_features = torch.from_numpy(np.load('node_features.npy'))
15//! edge_index = np.load('edge_index.npy')  # [num_edges, 2] COO format
16//! src, dst = edge_index[:, 0], edge_index[:, 1]
17//!
18//! g = dgl.graph((src, dst))
19//! g.ndata['feat'] = node_features
20//! ```
21//!
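//! For heterogeneous graphs, DGL uses separate arrays per node/edge type. This exporter
//! records the type of every node and edge in `node_type_indices.npy` and
//! `edge_type_indices.npy` when heterogeneous export is enabled.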
23
24use std::collections::HashMap;
25use std::fs::{self, File};
26use std::io::Write;
27use std::path::Path;
28
29use serde::{Deserialize, Serialize};
30
31use crate::exporters::common::{CommonExportConfig, CommonGraphMetadata};
32use crate::exporters::npy_writer;
33use crate::models::Graph;
34
35/// Configuration for DGL export.
36#[derive(Debug, Clone)]
37pub struct DGLExportConfig {
38    /// Common export settings (features, labels, masks, splits, seed).
39    pub common: CommonExportConfig,
40    /// Export as heterogeneous graph (separate files per type).
41    pub heterogeneous: bool,
42    /// Include Python pickle helper script.
43    pub include_pickle_script: bool,
44}
45
46impl Default for DGLExportConfig {
47    fn default() -> Self {
48        Self {
49            common: CommonExportConfig::default(),
50            heterogeneous: false,
51            include_pickle_script: true,
52        }
53    }
54}
55
56/// Metadata about the exported DGL data.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct DGLMetadata {
59    /// Common graph metadata fields.
60    #[serde(flatten)]
61    pub common: CommonGraphMetadata,
62    /// Whether export is heterogeneous.
63    pub is_heterogeneous: bool,
64    /// Edge index format (COO).
65    pub edge_format: String,
66}
67
68/// DGL graph exporter.
69pub struct DGLExporter {
70    config: DGLExportConfig,
71}
72
73impl DGLExporter {
74    /// Creates a new DGL exporter.
75    pub fn new(config: DGLExportConfig) -> Self {
76        Self { config }
77    }
78
79    /// Exports a graph to DGL format.
80    pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<DGLMetadata> {
81        fs::create_dir_all(output_dir)?;
82
83        let mut files = Vec::new();
84        let mut statistics = HashMap::new();
85
86        // Export edge index in COO format [num_edges, 2]
87        self.export_edge_index(graph, output_dir)?;
88        files.push("edge_index.npy".to_string());
89
90        // Export node features
91        if self.config.common.export_node_features {
92            let dim = self.export_node_features(graph, output_dir)?;
93            files.push("node_features.npy".to_string());
94            statistics.insert("node_feature_dim".to_string(), dim as f64);
95        }
96
97        // Export edge features
98        if self.config.common.export_edge_features {
99            let dim = self.export_edge_features(graph, output_dir)?;
100            files.push("edge_features.npy".to_string());
101            statistics.insert("edge_feature_dim".to_string(), dim as f64);
102        }
103
104        // Export node labels
105        if self.config.common.export_node_labels {
106            self.export_node_labels(graph, output_dir)?;
107            files.push("node_labels.npy".to_string());
108        }
109
110        // Export edge labels
111        if self.config.common.export_edge_labels {
112            self.export_edge_labels(graph, output_dir)?;
113            files.push("edge_labels.npy".to_string());
114        }
115
116        // Export masks
117        if self.config.common.export_masks {
118            self.export_masks(graph, output_dir)?;
119            files.push("train_mask.npy".to_string());
120            files.push("val_mask.npy".to_string());
121            files.push("test_mask.npy".to_string());
122        }
123
124        // Export node type indices (for heterogeneous support)
125        if self.config.heterogeneous {
126            self.export_node_types(graph, output_dir)?;
127            files.push("node_type_indices.npy".to_string());
128            self.export_edge_types(graph, output_dir)?;
129            files.push("edge_type_indices.npy".to_string());
130        }
131
132        // Compute node/edge type mappings with counts
133        let node_types: HashMap<String, usize> = graph
134            .nodes_by_type
135            .iter()
136            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
137            .collect();
138
139        let edge_types: HashMap<String, usize> = graph
140            .edges_by_type
141            .iter()
142            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
143            .collect();
144
145        // Compute statistics
146        statistics.insert("density".to_string(), graph.metadata.density);
147        statistics.insert(
148            "anomalous_node_ratio".to_string(),
149            graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
150        );
151        statistics.insert(
152            "anomalous_edge_ratio".to_string(),
153            graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
154        );
155
156        // Create metadata
157        let metadata = DGLMetadata {
158            common: CommonGraphMetadata {
159                name: graph.name.clone(),
160                num_nodes: graph.node_count(),
161                num_edges: graph.edge_count(),
162                node_feature_dim: graph.metadata.node_feature_dim,
163                edge_feature_dim: graph.metadata.edge_feature_dim,
164                num_node_classes: 2, // Normal/Anomaly
165                num_edge_classes: 2,
166                node_types,
167                edge_types,
168                is_directed: true,
169                files,
170                statistics,
171            },
172            is_heterogeneous: self.config.heterogeneous,
173            edge_format: "COO".to_string(),
174        };
175
176        // Write metadata
177        let metadata_path = output_dir.join("metadata.json");
178        let file = File::create(metadata_path)?;
179        serde_json::to_writer_pretty(file, &metadata)?;
180
181        // Write Python loader script
182        self.write_loader_script(output_dir)?;
183
184        // Write pickle helper script if configured
185        if self.config.include_pickle_script {
186            self.write_pickle_script(output_dir)?;
187        }
188
189        Ok(metadata)
190    }
191
192    /// Exports edge index as COO format [num_edges, 2] array.
193    fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
194        let (sources, targets) = graph.edge_index();
195
196        // Create node ID to index mapping
197        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
198        node_ids.sort();
199        let id_to_idx: HashMap<_, _> = node_ids
200            .iter()
201            .enumerate()
202            .map(|(i, &id)| (id, i))
203            .collect();
204
205        // Create COO format: [num_edges, 2] where each row is (src, dst)
206        let num_edges = sources.len();
207        let coo_data: Vec<Vec<i64>> = (0..num_edges)
208            .map(|i| {
209                let src = *id_to_idx.get(&sources[i]).unwrap_or(&0) as i64;
210                let dst = *id_to_idx.get(&targets[i]).unwrap_or(&0) as i64;
211                vec![src, dst]
212            })
213            .collect();
214
215        // Write as NPY format [num_edges, 2]
216        let path = output_dir.join("edge_index.npy");
217        npy_writer::write_npy_2d_i64(&path, &coo_data)?;
218
219        Ok(())
220    }
221
222    /// Exports node features.
223    fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
224        let features = graph.node_features();
225        let dim = features.first().map(|f| f.len()).unwrap_or(0);
226
227        let path = output_dir.join("node_features.npy");
228        npy_writer::write_npy_2d_f64(&path, &features)?;
229
230        Ok(dim)
231    }
232
233    /// Exports edge features.
234    fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
235        let features = graph.edge_features();
236        let dim = features.first().map(|f| f.len()).unwrap_or(0);
237
238        let path = output_dir.join("edge_features.npy");
239        npy_writer::write_npy_2d_f64(&path, &features)?;
240
241        Ok(dim)
242    }
243
244    /// Exports node labels (anomaly flags).
245    fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
246        let labels: Vec<i64> = graph
247            .node_anomaly_mask()
248            .iter()
249            .map(|&b| if b { 1 } else { 0 })
250            .collect();
251
252        let path = output_dir.join("node_labels.npy");
253        npy_writer::write_npy_1d_i64(&path, &labels)?;
254
255        Ok(())
256    }
257
258    /// Exports edge labels (anomaly flags).
259    fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
260        let labels: Vec<i64> = graph
261            .edge_anomaly_mask()
262            .iter()
263            .map(|&b| if b { 1 } else { 0 })
264            .collect();
265
266        let path = output_dir.join("edge_labels.npy");
267        npy_writer::write_npy_1d_i64(&path, &labels)?;
268
269        Ok(())
270    }
271
272    /// Exports train/val/test masks.
273    fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
274        npy_writer::export_masks(
275            output_dir,
276            graph.node_count(),
277            self.config.common.seed,
278            self.config.common.train_ratio,
279            self.config.common.val_ratio,
280        )
281    }
282
283    /// Exports node type indices for heterogeneous graphs.
284    fn export_node_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
285        // Create type mapping
286        let type_to_idx: HashMap<_, _> = graph
287            .nodes_by_type
288            .keys()
289            .enumerate()
290            .map(|(i, t)| (t.clone(), i as i64))
291            .collect();
292
293        // Get node IDs in sorted order for consistent indexing
294        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
295        node_ids.sort();
296
297        // Map each node to its type index
298        let type_indices: Vec<i64> = node_ids
299            .iter()
300            .map(|id| {
301                let node = graph.nodes.get(id).expect("node ID from keys()");
302                *type_to_idx.get(&node.node_type).unwrap_or(&0)
303            })
304            .collect();
305
306        let path = output_dir.join("node_type_indices.npy");
307        npy_writer::write_npy_1d_i64(&path, &type_indices)?;
308
309        Ok(())
310    }
311
312    /// Exports edge type indices for heterogeneous graphs.
313    fn export_edge_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
314        // Create type mapping
315        let type_to_idx: HashMap<_, _> = graph
316            .edges_by_type
317            .keys()
318            .enumerate()
319            .map(|(i, t)| (t.clone(), i as i64))
320            .collect();
321
322        // Get edge IDs in sorted order for consistent indexing
323        let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
324        edge_ids.sort();
325
326        // Map each edge to its type index
327        let type_indices: Vec<i64> = edge_ids
328            .iter()
329            .map(|id| {
330                let edge = graph.edges.get(id).expect("edge ID from keys()");
331                *type_to_idx.get(&edge.edge_type).unwrap_or(&0)
332            })
333            .collect();
334
335        let path = output_dir.join("edge_type_indices.npy");
336        npy_writer::write_npy_1d_i64(&path, &type_indices)?;
337
338        Ok(())
339    }
340
341    /// Writes a Python loader script for DGL.
342    fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
343        let script = r#"#!/usr/bin/env python3
344"""
345DGL (Deep Graph Library) Data Loader
346
347Auto-generated loader for graph data exported from synth-graph.
348Supports both homogeneous and heterogeneous graph loading.
349"""
350
351import json
352import numpy as np
353from pathlib import Path
354
355try:
356    import torch
357    HAS_TORCH = True
358except ImportError:
359    HAS_TORCH = False
360    print("Warning: torch not installed. Install with: pip install torch")
361
362try:
363    import dgl
364    HAS_DGL = True
365except ImportError:
366    HAS_DGL = False
367    print("Warning: dgl not installed. Install with: pip install dgl")
368
369
370def load_graph(data_dir: str = ".") -> "dgl.DGLGraph":
371    """Load graph data into a DGL graph object.
372
373    Args:
374        data_dir: Directory containing the exported graph data.
375
376    Returns:
377        DGL graph with node features, edge features, and labels attached.
378    """
379    data_dir = Path(data_dir)
380
381    # Load metadata
382    with open(data_dir / "metadata.json") as f:
383        metadata = json.load(f)
384
385    # Load edge index (COO format: [num_edges, 2])
386    edge_index = np.load(data_dir / "edge_index.npy")
387    src = edge_index[:, 0]
388    dst = edge_index[:, 1]
389
390    num_nodes = metadata["num_nodes"]
391
392    if not HAS_DGL:
393        # Return dict if DGL not available
394        result = {
395            "src": src,
396            "dst": dst,
397            "num_nodes": num_nodes,
398            "metadata": metadata,
399        }
400
401        # Load optional arrays
402        if (data_dir / "node_features.npy").exists():
403            result["node_features"] = np.load(data_dir / "node_features.npy")
404        if (data_dir / "edge_features.npy").exists():
405            result["edge_features"] = np.load(data_dir / "edge_features.npy")
406        if (data_dir / "node_labels.npy").exists():
407            result["node_labels"] = np.load(data_dir / "node_labels.npy")
408        if (data_dir / "edge_labels.npy").exists():
409            result["edge_labels"] = np.load(data_dir / "edge_labels.npy")
410        if (data_dir / "train_mask.npy").exists():
411            result["train_mask"] = np.load(data_dir / "train_mask.npy")
412            result["val_mask"] = np.load(data_dir / "val_mask.npy")
413            result["test_mask"] = np.load(data_dir / "test_mask.npy")
414
415        return result
416
417    # Create DGL graph
418    g = dgl.graph((src, dst), num_nodes=num_nodes)
419
420    # Load and attach node features
421    node_features_path = data_dir / "node_features.npy"
422    if node_features_path.exists():
423        node_features = np.load(node_features_path)
424        if HAS_TORCH:
425            g.ndata['feat'] = torch.from_numpy(node_features).float()
426        else:
427            g.ndata['feat'] = node_features
428
429    # Load and attach edge features
430    edge_features_path = data_dir / "edge_features.npy"
431    if edge_features_path.exists():
432        edge_features = np.load(edge_features_path)
433        if HAS_TORCH:
434            g.edata['feat'] = torch.from_numpy(edge_features).float()
435        else:
436            g.edata['feat'] = edge_features
437
438    # Load and attach node labels
439    node_labels_path = data_dir / "node_labels.npy"
440    if node_labels_path.exists():
441        node_labels = np.load(node_labels_path)
442        if HAS_TORCH:
443            g.ndata['label'] = torch.from_numpy(node_labels).long()
444        else:
445            g.ndata['label'] = node_labels
446
447    # Load and attach edge labels
448    edge_labels_path = data_dir / "edge_labels.npy"
449    if edge_labels_path.exists():
450        edge_labels = np.load(edge_labels_path)
451        if HAS_TORCH:
452            g.edata['label'] = torch.from_numpy(edge_labels).long()
453        else:
454            g.edata['label'] = edge_labels
455
456    # Load and attach masks
457    if (data_dir / "train_mask.npy").exists():
458        train_mask = np.load(data_dir / "train_mask.npy")
459        val_mask = np.load(data_dir / "val_mask.npy")
460        test_mask = np.load(data_dir / "test_mask.npy")
461
462        if HAS_TORCH:
463            g.ndata['train_mask'] = torch.from_numpy(train_mask).bool()
464            g.ndata['val_mask'] = torch.from_numpy(val_mask).bool()
465            g.ndata['test_mask'] = torch.from_numpy(test_mask).bool()
466        else:
467            g.ndata['train_mask'] = train_mask
468            g.ndata['val_mask'] = val_mask
469            g.ndata['test_mask'] = test_mask
470
471    # Store metadata as graph attribute
472    g.metadata = metadata
473
474    return g
475
476
477def load_heterogeneous_graph(data_dir: str = ".") -> "dgl.DGLHeteroGraph":
478    """Load graph data into a DGL heterogeneous graph.
479
480    This function handles graphs with multiple node and edge types.
481
482    Args:
483        data_dir: Directory containing the exported graph data.
484
485    Returns:
486        DGL heterogeneous graph.
487    """
488    data_dir = Path(data_dir)
489
490    # Load metadata
491    with open(data_dir / "metadata.json") as f:
492        metadata = json.load(f)
493
494    if not metadata.get("is_heterogeneous", False):
495        print("Warning: Graph was not exported as heterogeneous. Using homogeneous loader.")
496        return load_graph(data_dir)
497
498    if not HAS_DGL:
499        raise ImportError("DGL is required for heterogeneous graph loading")
500
501    # Load edge index and type indices
502    edge_index = np.load(data_dir / "edge_index.npy")
503    edge_types = np.load(data_dir / "edge_type_indices.npy")
504    node_types = np.load(data_dir / "node_type_indices.npy")
505
506    # Get type names from metadata
507    node_type_names = list(metadata["node_types"].keys())
508    edge_type_names = list(metadata["edge_types"].keys())
509
510    # Build edge dict for heterogeneous graph
511    edge_dict = {}
512    for etype_idx, etype_name in enumerate(edge_type_names):
513        mask = edge_types == etype_idx
514        if mask.any():
515            src = edge_index[mask, 0]
516            dst = edge_index[mask, 1]
517            # For heterogeneous, we need to specify (src_type, edge_type, dst_type)
518            # Using simplified convention: (node_type, edge_type, node_type)
519            edge_dict[(node_type_names[0] if node_type_names else 'node',
520                      etype_name,
521                      node_type_names[0] if node_type_names else 'node')] = (src, dst)
522
523    # Create heterogeneous graph
524    g = dgl.heterograph(edge_dict) if edge_dict else dgl.graph(([], []))
525    g.metadata = metadata
526
527    return g
528
529
530def print_summary(data_dir: str = "."):
531    """Print summary of the graph data."""
532    data_dir = Path(data_dir)
533
534    with open(data_dir / "metadata.json") as f:
535        metadata = json.load(f)
536
537    print(f"Graph: {metadata['name']}")
538    print(f"Format: DGL ({metadata['edge_format']} edge format)")
539    print(f"Nodes: {metadata['num_nodes']}")
540    print(f"Edges: {metadata['num_edges']}")
541    print(f"Node feature dim: {metadata['node_feature_dim']}")
542    print(f"Edge feature dim: {metadata['edge_feature_dim']}")
543    print(f"Directed: {metadata['is_directed']}")
544    print(f"Heterogeneous: {metadata['is_heterogeneous']}")
545
546    if metadata['node_types']:
547        print(f"Node types: {metadata['node_types']}")
548    if metadata['edge_types']:
549        print(f"Edge types: {metadata['edge_types']}")
550
551    if metadata['statistics']:
552        print("\nStatistics:")
553        for key, value in metadata['statistics'].items():
554            print(f"  {key}: {value:.4f}")
555
556    if HAS_DGL:
557        print("\nLoading graph...")
558        g = load_graph(data_dir)
559        if hasattr(g, 'num_nodes'):
560            print(f"DGL graph loaded: {g.num_nodes()} nodes, {g.num_edges()} edges")
561            if 'label' in g.ndata:
562                print(f"Anomalous nodes: {g.ndata['label'].sum().item()}")
563
564
565if __name__ == "__main__":
566    import sys
567    data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
568    print_summary(data_dir)
569"#;
570
571        let path = output_dir.join("load_graph.py");
572        let mut file = File::create(path)?;
573        file.write_all(script.as_bytes())?;
574
575        Ok(())
576    }
577
578    /// Writes a helper script for saving/loading DGL graphs as pickle.
579    fn write_pickle_script(&self, output_dir: &Path) -> std::io::Result<()> {
580        let script = r#"#!/usr/bin/env python3
581"""
582DGL Graph Pickle Helper
583
584Utility to save and load DGL graphs as pickle files for faster subsequent loading.
585"""
586
587import pickle
588from pathlib import Path
589
590try:
591    import dgl
592    HAS_DGL = True
593except ImportError:
594    HAS_DGL = False
595
596
597def save_dgl_graph(graph, output_path: str):
598    """Save a DGL graph to a pickle file.
599
600    Args:
601        graph: DGL graph to save.
602        output_path: Path to save the pickle file.
603    """
604    output_path = Path(output_path)
605
606    # Save graph data
607    graph_data = {
608        'num_nodes': graph.num_nodes(),
609        'edges': graph.edges(),
610        'ndata': {k: v.numpy() if hasattr(v, 'numpy') else v
611                  for k, v in graph.ndata.items()},
612        'edata': {k: v.numpy() if hasattr(v, 'numpy') else v
613                  for k, v in graph.edata.items()},
614        'metadata': getattr(graph, 'metadata', {}),
615    }
616
617    with open(output_path, 'wb') as f:
618        pickle.dump(graph_data, f, protocol=pickle.HIGHEST_PROTOCOL)
619
620    print(f"Saved graph to {output_path}")
621
622
623def load_dgl_graph(input_path: str) -> "dgl.DGLGraph":
624    """Load a DGL graph from a pickle file.
625
626    Args:
627        input_path: Path to the pickle file.
628
629    Returns:
630        DGL graph.
631    """
632    if not HAS_DGL:
633        raise ImportError("DGL is required to load graphs")
634
635    import torch
636
637    input_path = Path(input_path)
638
639    with open(input_path, 'rb') as f:
640        graph_data = pickle.load(f)
641
642    # Recreate graph
643    src, dst = graph_data['edges']
644    g = dgl.graph((src, dst), num_nodes=graph_data['num_nodes'])
645
646    # Restore node data
647    for k, v in graph_data['ndata'].items():
648        g.ndata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v
649
650    # Restore edge data
651    for k, v in graph_data['edata'].items():
652        g.edata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v
653
654    # Restore metadata
655    g.metadata = graph_data.get('metadata', {})
656
657    return g
658
659
660def convert_to_pickle(data_dir: str, output_path: str = None):
661    """Convert exported graph data to pickle format for faster loading.
662
663    Args:
664        data_dir: Directory containing the exported graph data.
665        output_path: Path for output pickle file. Defaults to data_dir/graph.pkl.
666    """
667    from load_graph import load_graph
668
669    data_dir = Path(data_dir)
670    output_path = Path(output_path) if output_path else data_dir / "graph.pkl"
671
672    print(f"Loading graph from {data_dir}...")
673    g = load_graph(str(data_dir))
674
675    if isinstance(g, dict):
676        print("Error: DGL not available, cannot convert to pickle")
677        return
678
679    save_dgl_graph(g, str(output_path))
680    print(f"Graph saved to {output_path}")
681
682
683if __name__ == "__main__":
684    import sys
685
686    if len(sys.argv) < 2:
687        print("Usage:")
688        print("  python pickle_helper.py convert <data_dir> [output_path]")
689        print("  python pickle_helper.py load <pickle_path>")
690        sys.exit(1)
691
692    command = sys.argv[1]
693
694    if command == "convert":
695        data_dir = sys.argv[2] if len(sys.argv) > 2 else "."
696        output_path = sys.argv[3] if len(sys.argv) > 3 else None
697        convert_to_pickle(data_dir, output_path)
698    elif command == "load":
699        pickle_path = sys.argv[2]
700        g = load_dgl_graph(pickle_path)
701        print(f"Loaded graph: {g.num_nodes()} nodes, {g.num_edges()} edges")
702    else:
703        print(f"Unknown command: {command}")
704"#;
705
706        let path = output_dir.join("pickle_helper.py");
707        let mut file = File::create(path)?;
708        file.write_all(script.as_bytes())?;
709
710        Ok(())
711    }
712}
713
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::test_helpers::create_test_graph_with_company;
    use tempfile::tempdir;

    /// Exports the shared fixture graph with `config` into a fresh temporary
    /// directory. The directory handle is returned so it outlives the assertions.
    fn run_export(config: DGLExportConfig) -> (tempfile::TempDir, DGLMetadata) {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();
        let metadata = DGLExporter::new(config)
            .export(&graph, dir.path())
            .unwrap();
        (dir, metadata)
    }

    /// True when `name` is listed in the metadata's exported files.
    fn lists_file(metadata: &DGLMetadata, name: &str) -> bool {
        metadata.common.files.contains(&name.to_string())
    }

    #[test]
    fn test_dgl_export_basic() {
        let (dir, metadata) = run_export(DGLExportConfig::default());
        let out = dir.path();

        assert_eq!(metadata.common.num_nodes, 3);
        assert_eq!(metadata.common.num_edges, 2);
        assert_eq!(metadata.edge_format, "COO");
        assert!(out.join("edge_index.npy").exists());
        assert!(out.join("node_features.npy").exists());
        assert!(out.join("node_labels.npy").exists());
        assert!(out.join("metadata.json").exists());
        assert!(out.join("load_graph.py").exists());
        assert!(out.join("pickle_helper.py").exists());
    }

    #[test]
    fn test_dgl_export_heterogeneous() {
        let (dir, metadata) = run_export(DGLExportConfig {
            heterogeneous: true,
            ..Default::default()
        });

        assert!(metadata.is_heterogeneous);
        assert!(dir.path().join("node_type_indices.npy").exists());
        assert!(dir.path().join("edge_type_indices.npy").exists());
    }

    #[test]
    fn test_dgl_export_masks() {
        let (dir, metadata) = run_export(DGLExportConfig::default());

        assert!(lists_file(&metadata, "train_mask.npy"));
        assert!(lists_file(&metadata, "val_mask.npy"));
        assert!(lists_file(&metadata, "test_mask.npy"));
        assert!(dir.path().join("train_mask.npy").exists());
        assert!(dir.path().join("val_mask.npy").exists());
        assert!(dir.path().join("test_mask.npy").exists());
    }

    #[test]
    fn test_dgl_coo_format() {
        let (dir, _) = run_export(DGLExportConfig::default());

        // The edge index file must exist; its layout is [num_edges, 2].
        assert!(dir.path().join("edge_index.npy").exists());

        // The metadata on disk records the format.
        let reader = File::open(dir.path().join("metadata.json")).unwrap();
        let metadata: DGLMetadata = serde_json::from_reader(reader).unwrap();
        assert_eq!(metadata.edge_format, "COO");
    }

    #[test]
    fn test_dgl_export_no_masks() {
        let common = CommonExportConfig {
            export_masks: false,
            ..Default::default()
        };
        let (dir, metadata) = run_export(DGLExportConfig {
            common,
            ..Default::default()
        });

        assert!(!lists_file(&metadata, "train_mask.npy"));
        assert!(!dir.path().join("train_mask.npy").exists());
    }

    #[test]
    fn test_dgl_export_minimal() {
        let common = CommonExportConfig {
            export_node_features: false,
            export_edge_features: false,
            export_node_labels: false,
            export_edge_labels: false,
            export_masks: false,
            ..Default::default()
        };
        let (dir, metadata) = run_export(DGLExportConfig {
            common,
            include_pickle_script: false,
            ..Default::default()
        });
        let out = dir.path();

        // Only edge_index.npy is listed; the loader and metadata are always written.
        assert_eq!(metadata.common.files.len(), 1);
        assert!(out.join("edge_index.npy").exists());
        assert!(out.join("load_graph.py").exists());
        assert!(out.join("metadata.json").exists());
        assert!(!out.join("pickle_helper.py").exists());
    }

    #[test]
    fn test_dgl_statistics() {
        let (_dir, metadata) = run_export(DGLExportConfig::default());
        let stats = &metadata.common.statistics;

        // Density and both anomaly ratios are always reported.
        assert!(stats.contains_key("density"));
        assert!(stats.contains_key("anomalous_node_ratio"));
        assert!(stats.contains_key("anomalous_edge_ratio"));

        // One of three nodes is anomalous.
        let node_ratio = stats.get("anomalous_node_ratio").unwrap();
        assert!((*node_ratio - 1.0 / 3.0).abs() < 0.01);
    }
}
869            .get("anomalous_node_ratio")
870            .unwrap();
871        assert!((*node_ratio - 1.0 / 3.0).abs() < 0.01);
872    }
873}