1use std::collections::HashMap;
25use std::fs::{self, File};
26use std::io::Write;
27use std::path::Path;
28
29use serde::{Deserialize, Serialize};
30
31use crate::exporters::common::{CommonExportConfig, CommonGraphMetadata};
32use crate::exporters::npy_writer;
33use crate::models::Graph;
34
/// Configuration for exporting a graph in a DGL-compatible directory layout.
#[derive(Debug, Clone)]
pub struct DGLExportConfig {
    /// Settings shared by all exporters (feature/label/mask toggles, split ratios, seed).
    pub common: CommonExportConfig,
    /// When true, also export per-node and per-edge type index arrays for heterogeneous loading.
    pub heterogeneous: bool,
    /// When true, also emit the auxiliary `pickle_helper.py` convenience script.
    pub include_pickle_script: bool,
}
45
46impl Default for DGLExportConfig {
47 fn default() -> Self {
48 Self {
49 common: CommonExportConfig::default(),
50 heterogeneous: false,
51 include_pickle_script: true,
52 }
53 }
54}
55
/// Metadata written to `metadata.json` next to the exported `.npy` arrays.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DGLMetadata {
    /// Fields shared across exporters, flattened into the top-level JSON object.
    #[serde(flatten)]
    pub common: CommonGraphMetadata,
    /// Whether node/edge type index arrays were exported alongside the graph.
    pub is_heterogeneous: bool,
    /// Edge array layout; this exporter always writes "COO".
    pub edge_format: String,
}
67
/// Exports a [`Graph`] to a directory layout consumable by DGL (Deep Graph Library).
pub struct DGLExporter {
    // Export options captured at construction time.
    config: DGLExportConfig,
}
72
73impl DGLExporter {
74 pub fn new(config: DGLExportConfig) -> Self {
76 Self { config }
77 }
78
    /// Exports `graph` into `output_dir`: `.npy` arrays, `metadata.json`, and
    /// Python helper scripts. Returns the metadata that was written.
    ///
    /// The edge index is always exported; every other array is opt-in via
    /// `config.common` / `config.heterogeneous`.
    ///
    /// # Errors
    /// Returns any I/O error from directory creation, array writes, or
    /// metadata/script file writes.
    pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<DGLMetadata> {
        fs::create_dir_all(output_dir)?;

        // Track which files were written and summary statistics for metadata.
        let mut files = Vec::new();
        let mut statistics = HashMap::new();

        // Edge index is mandatory for any DGL graph.
        self.export_edge_index(graph, output_dir)?;
        files.push("edge_index.npy".to_string());

        if self.config.common.export_node_features {
            let dim = self.export_node_features(graph, output_dir)?;
            files.push("node_features.npy".to_string());
            statistics.insert("node_feature_dim".to_string(), dim as f64);
        }

        if self.config.common.export_edge_features {
            let dim = self.export_edge_features(graph, output_dir)?;
            files.push("edge_features.npy".to_string());
            statistics.insert("edge_feature_dim".to_string(), dim as f64);
        }

        if self.config.common.export_node_labels {
            self.export_node_labels(graph, output_dir)?;
            files.push("node_labels.npy".to_string());
        }

        if self.config.common.export_edge_labels {
            self.export_edge_labels(graph, output_dir)?;
            files.push("edge_labels.npy".to_string());
        }

        if self.config.common.export_masks {
            // export_masks writes all three split files in one call.
            self.export_masks(graph, output_dir)?;
            files.push("train_mask.npy".to_string());
            files.push("val_mask.npy".to_string());
            files.push("test_mask.npy".to_string());
        }

        if self.config.heterogeneous {
            self.export_node_types(graph, output_dir)?;
            files.push("node_type_indices.npy".to_string());
            self.export_edge_types(graph, output_dir)?;
            files.push("edge_type_indices.npy".to_string());
        }

        // Per-type element counts for metadata.
        // NOTE(review): these maps record counts only, in HashMap (arbitrary)
        // order — they do not record the type-index assignment used by
        // export_node_types/export_edge_types; confirm consumers do not rely
        // on JSON key order to reconstruct that mapping.
        let node_types: HashMap<String, usize> = graph
            .nodes_by_type
            .iter()
            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
            .collect();

        let edge_types: HashMap<String, usize> = graph
            .edges_by_type
            .iter()
            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
            .collect();

        statistics.insert("density".to_string(), graph.metadata.density);
        // max(1) guards against division by zero on empty graphs.
        statistics.insert(
            "anomalous_node_ratio".to_string(),
            graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
        );
        statistics.insert(
            "anomalous_edge_ratio".to_string(),
            graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
        );

        let metadata = DGLMetadata {
            common: CommonGraphMetadata {
                name: graph.name.clone(),
                num_nodes: graph.node_count(),
                num_edges: graph.edge_count(),
                node_feature_dim: graph.metadata.node_feature_dim,
                edge_feature_dim: graph.metadata.edge_feature_dim,
                // Binary classes: labels are 0/1 anomaly flags (see export_*_labels).
                num_node_classes: 2, num_edge_classes: 2,
                node_types,
                edge_types,
                is_directed: true,
                files,
                statistics,
            },
            is_heterogeneous: self.config.heterogeneous,
            edge_format: "COO".to_string(),
        };

        // Serialize metadata; serde_json errors convert to io::Error via `?`.
        let metadata_path = output_dir.join("metadata.json");
        let file = File::create(metadata_path)?;
        serde_json::to_writer_pretty(file, &metadata)?;

        // Always emit the loader script; the pickle helper is opt-in.
        self.write_loader_script(output_dir)?;

        if self.config.include_pickle_script {
            self.write_pickle_script(output_dir)?;
        }

        Ok(metadata)
    }
191
192 fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
194 let (sources, targets) = graph.edge_index();
195
196 let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
198 node_ids.sort();
199 let id_to_idx: HashMap<_, _> = node_ids
200 .iter()
201 .enumerate()
202 .map(|(i, &id)| (id, i))
203 .collect();
204
205 let num_edges = sources.len();
207 let mut coo_data: Vec<Vec<i64>> = Vec::with_capacity(num_edges);
208 let mut skipped_edges = 0usize;
209
210 for i in 0..num_edges {
211 match (id_to_idx.get(&sources[i]), id_to_idx.get(&targets[i])) {
212 (Some(&s), Some(&d)) => {
213 coo_data.push(vec![s as i64, d as i64]);
214 }
215 _ => {
216 skipped_edges += 1;
217 }
218 }
219 }
220 if skipped_edges > 0 {
221 tracing::warn!(
222 "DGL export: skipped {} edges with missing node IDs",
223 skipped_edges
224 );
225 }
226
227 let path = output_dir.join("edge_index.npy");
229 npy_writer::write_npy_2d_i64(&path, &coo_data)?;
230
231 Ok(())
232 }
233
234 fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
236 let features = graph.node_features();
237 let dim = features.first().map(|f| f.len()).unwrap_or(0);
238
239 let path = output_dir.join("node_features.npy");
240 npy_writer::write_npy_2d_f64(&path, &features)?;
241
242 Ok(dim)
243 }
244
245 fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
247 let features = graph.edge_features();
248 let dim = features.first().map(|f| f.len()).unwrap_or(0);
249
250 let path = output_dir.join("edge_features.npy");
251 npy_writer::write_npy_2d_f64(&path, &features)?;
252
253 Ok(dim)
254 }
255
256 fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
258 let labels: Vec<i64> = graph
259 .node_anomaly_mask()
260 .iter()
261 .map(|&b| if b { 1 } else { 0 })
262 .collect();
263
264 let path = output_dir.join("node_labels.npy");
265 npy_writer::write_npy_1d_i64(&path, &labels)?;
266
267 Ok(())
268 }
269
270 fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
272 let labels: Vec<i64> = graph
273 .edge_anomaly_mask()
274 .iter()
275 .map(|&b| if b { 1 } else { 0 })
276 .collect();
277
278 let path = output_dir.join("edge_labels.npy");
279 npy_writer::write_npy_1d_i64(&path, &labels)?;
280
281 Ok(())
282 }
283
    /// Delegates to the shared mask writer, which produces `train_mask.npy`,
    /// `val_mask.npy`, and `test_mask.npy` over `graph.node_count()` nodes,
    /// split by the configured train/val ratios using the configured seed.
    fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
        npy_writer::export_masks(
            output_dir,
            graph.node_count(),
            self.config.common.seed,
            self.config.common.train_ratio,
            self.config.common.val_ratio,
        )
    }
294
295 fn export_node_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
297 let type_to_idx: HashMap<_, _> = graph
299 .nodes_by_type
300 .keys()
301 .enumerate()
302 .map(|(i, t)| (t.clone(), i as i64))
303 .collect();
304
305 let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
307 node_ids.sort();
308
309 let type_indices: Vec<i64> = node_ids
311 .iter()
312 .map(|id| {
313 let node = graph.nodes.get(id).expect("node ID from keys()");
314 *type_to_idx.get(&node.node_type).unwrap_or_else(|| {
315 tracing::warn!(
316 "Unknown node type '{:?}', defaulting to index 0",
317 node.node_type
318 );
319 &0
320 })
321 })
322 .collect();
323
324 let path = output_dir.join("node_type_indices.npy");
325 npy_writer::write_npy_1d_i64(&path, &type_indices)?;
326
327 Ok(())
328 }
329
330 fn export_edge_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
332 let type_to_idx: HashMap<_, _> = graph
334 .edges_by_type
335 .keys()
336 .enumerate()
337 .map(|(i, t)| (t.clone(), i as i64))
338 .collect();
339
340 let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
342 edge_ids.sort();
343
344 let type_indices: Vec<i64> = edge_ids
346 .iter()
347 .map(|id| {
348 let edge = graph.edges.get(id).expect("edge ID from keys()");
349 *type_to_idx.get(&edge.edge_type).unwrap_or_else(|| {
350 tracing::warn!(
351 "Unknown edge type '{:?}', defaulting to index 0",
352 edge.edge_type
353 );
354 &0
355 })
356 })
357 .collect();
358
359 let path = output_dir.join("edge_type_indices.npy");
360 npy_writer::write_npy_1d_i64(&path, &type_indices)?;
361
362 Ok(())
363 }
364
    /// Writes `load_graph.py`, a self-contained Python loader for the exported
    /// data. The script degrades gracefully: without `dgl` it returns a plain
    /// dict of arrays; without `torch` it attaches raw numpy arrays.
    ///
    /// NOTE(review): the heterogeneous loader reconstructs type names via
    /// `list(metadata["..."].keys())`, assuming JSON key order matches the
    /// exporter's type-index assignment — confirm this holds, since the
    /// metadata maps are built from `HashMap`s whose serialization order is
    /// not guaranteed.
    fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
        // Script body is emitted verbatim; it must stay self-contained
        // (stdlib + numpy, with dgl/torch optional).
        let script = r#"#!/usr/bin/env python3
"""
DGL (Deep Graph Library) Data Loader

Auto-generated loader for graph data exported from synth-graph.
Supports both homogeneous and heterogeneous graph loading.
"""

import json
import numpy as np
from pathlib import Path

try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False
    print("Warning: torch not installed. Install with: pip install torch")

try:
    import dgl
    HAS_DGL = True
except ImportError:
    HAS_DGL = False
    print("Warning: dgl not installed. Install with: pip install dgl")


def load_graph(data_dir: str = ".") -> "dgl.DGLGraph":
    """Load graph data into a DGL graph object.

    Args:
        data_dir: Directory containing the exported graph data.

    Returns:
        DGL graph with node features, edge features, and labels attached.
    """
    data_dir = Path(data_dir)

    # Load metadata
    with open(data_dir / "metadata.json") as f:
        metadata = json.load(f)

    # Load edge index (COO format: [num_edges, 2])
    edge_index = np.load(data_dir / "edge_index.npy")
    src = edge_index[:, 0]
    dst = edge_index[:, 1]

    num_nodes = metadata["num_nodes"]

    if not HAS_DGL:
        # Return dict if DGL not available
        result = {
            "src": src,
            "dst": dst,
            "num_nodes": num_nodes,
            "metadata": metadata,
        }

        # Load optional arrays
        if (data_dir / "node_features.npy").exists():
            result["node_features"] = np.load(data_dir / "node_features.npy")
        if (data_dir / "edge_features.npy").exists():
            result["edge_features"] = np.load(data_dir / "edge_features.npy")
        if (data_dir / "node_labels.npy").exists():
            result["node_labels"] = np.load(data_dir / "node_labels.npy")
        if (data_dir / "edge_labels.npy").exists():
            result["edge_labels"] = np.load(data_dir / "edge_labels.npy")
        if (data_dir / "train_mask.npy").exists():
            result["train_mask"] = np.load(data_dir / "train_mask.npy")
            result["val_mask"] = np.load(data_dir / "val_mask.npy")
            result["test_mask"] = np.load(data_dir / "test_mask.npy")

        return result

    # Create DGL graph
    g = dgl.graph((src, dst), num_nodes=num_nodes)

    # Load and attach node features
    node_features_path = data_dir / "node_features.npy"
    if node_features_path.exists():
        node_features = np.load(node_features_path)
        if HAS_TORCH:
            g.ndata['feat'] = torch.from_numpy(node_features).float()
        else:
            g.ndata['feat'] = node_features

    # Load and attach edge features
    edge_features_path = data_dir / "edge_features.npy"
    if edge_features_path.exists():
        edge_features = np.load(edge_features_path)
        if HAS_TORCH:
            g.edata['feat'] = torch.from_numpy(edge_features).float()
        else:
            g.edata['feat'] = edge_features

    # Load and attach node labels
    node_labels_path = data_dir / "node_labels.npy"
    if node_labels_path.exists():
        node_labels = np.load(node_labels_path)
        if HAS_TORCH:
            g.ndata['label'] = torch.from_numpy(node_labels).long()
        else:
            g.ndata['label'] = node_labels

    # Load and attach edge labels
    edge_labels_path = data_dir / "edge_labels.npy"
    if edge_labels_path.exists():
        edge_labels = np.load(edge_labels_path)
        if HAS_TORCH:
            g.edata['label'] = torch.from_numpy(edge_labels).long()
        else:
            g.edata['label'] = edge_labels

    # Load and attach masks
    if (data_dir / "train_mask.npy").exists():
        train_mask = np.load(data_dir / "train_mask.npy")
        val_mask = np.load(data_dir / "val_mask.npy")
        test_mask = np.load(data_dir / "test_mask.npy")

        if HAS_TORCH:
            g.ndata['train_mask'] = torch.from_numpy(train_mask).bool()
            g.ndata['val_mask'] = torch.from_numpy(val_mask).bool()
            g.ndata['test_mask'] = torch.from_numpy(test_mask).bool()
        else:
            g.ndata['train_mask'] = train_mask
            g.ndata['val_mask'] = val_mask
            g.ndata['test_mask'] = test_mask

    # Store metadata as graph attribute
    g.metadata = metadata

    return g


def load_heterogeneous_graph(data_dir: str = ".") -> "dgl.DGLHeteroGraph":
    """Load graph data into a DGL heterogeneous graph.

    This function handles graphs with multiple node and edge types.

    Args:
        data_dir: Directory containing the exported graph data.

    Returns:
        DGL heterogeneous graph.
    """
    data_dir = Path(data_dir)

    # Load metadata
    with open(data_dir / "metadata.json") as f:
        metadata = json.load(f)

    if not metadata.get("is_heterogeneous", False):
        print("Warning: Graph was not exported as heterogeneous. Using homogeneous loader.")
        return load_graph(data_dir)

    if not HAS_DGL:
        raise ImportError("DGL is required for heterogeneous graph loading")

    # Load edge index and type indices
    edge_index = np.load(data_dir / "edge_index.npy")
    edge_types = np.load(data_dir / "edge_type_indices.npy")
    node_types = np.load(data_dir / "node_type_indices.npy")

    # Get type names from metadata
    node_type_names = list(metadata["node_types"].keys())
    edge_type_names = list(metadata["edge_types"].keys())

    # Build edge dict for heterogeneous graph
    edge_dict = {}
    for etype_idx, etype_name in enumerate(edge_type_names):
        mask = edge_types == etype_idx
        if mask.any():
            src = edge_index[mask, 0]
            dst = edge_index[mask, 1]
            # For heterogeneous, we need to specify (src_type, edge_type, dst_type)
            # Using simplified convention: (node_type, edge_type, node_type)
            edge_dict[(node_type_names[0] if node_type_names else 'node',
                       etype_name,
                       node_type_names[0] if node_type_names else 'node')] = (src, dst)

    # Create heterogeneous graph
    g = dgl.heterograph(edge_dict) if edge_dict else dgl.graph(([], []))
    g.metadata = metadata

    return g


def print_summary(data_dir: str = "."):
    """Print summary of the graph data."""
    data_dir = Path(data_dir)

    with open(data_dir / "metadata.json") as f:
        metadata = json.load(f)

    print(f"Graph: {metadata['name']}")
    print(f"Format: DGL ({metadata['edge_format']} edge format)")
    print(f"Nodes: {metadata['num_nodes']}")
    print(f"Edges: {metadata['num_edges']}")
    print(f"Node feature dim: {metadata['node_feature_dim']}")
    print(f"Edge feature dim: {metadata['edge_feature_dim']}")
    print(f"Directed: {metadata['is_directed']}")
    print(f"Heterogeneous: {metadata['is_heterogeneous']}")

    if metadata['node_types']:
        print(f"Node types: {metadata['node_types']}")
    if metadata['edge_types']:
        print(f"Edge types: {metadata['edge_types']}")

    if metadata['statistics']:
        print("\nStatistics:")
        for key, value in metadata['statistics'].items():
            print(f"  {key}: {value:.4f}")

    if HAS_DGL:
        print("\nLoading graph...")
        g = load_graph(data_dir)
        if hasattr(g, 'num_nodes'):
            print(f"DGL graph loaded: {g.num_nodes()} nodes, {g.num_edges()} edges")
            if 'label' in g.ndata:
                print(f"Anomalous nodes: {g.ndata['label'].sum().item()}")


if __name__ == "__main__":
    import sys
    data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
    print_summary(data_dir)
"#;

        // Write the script next to the exported arrays.
        let path = output_dir.join("load_graph.py");
        let mut file = File::create(path)?;
        file.write_all(script.as_bytes())?;

        Ok(())
    }
601
    /// Writes `pickle_helper.py`, a utility for round-tripping a loaded DGL
    /// graph through a pickle file for faster subsequent loading. Emitted only
    /// when `config.include_pickle_script` is set (see `export`).
    fn write_pickle_script(&self, output_dir: &Path) -> std::io::Result<()> {
        // Script body is emitted verbatim; it imports load_graph.py at runtime,
        // so it only works from inside the export directory.
        let script = r#"#!/usr/bin/env python3
"""
DGL Graph Pickle Helper

Utility to save and load DGL graphs as pickle files for faster subsequent loading.
"""

import pickle
from pathlib import Path

try:
    import dgl
    HAS_DGL = True
except ImportError:
    HAS_DGL = False


def save_dgl_graph(graph, output_path: str):
    """Save a DGL graph to a pickle file.

    Args:
        graph: DGL graph to save.
        output_path: Path to save the pickle file.
    """
    output_path = Path(output_path)

    # Save graph data
    graph_data = {
        'num_nodes': graph.num_nodes(),
        'edges': graph.edges(),
        'ndata': {k: v.numpy() if hasattr(v, 'numpy') else v
                  for k, v in graph.ndata.items()},
        'edata': {k: v.numpy() if hasattr(v, 'numpy') else v
                  for k, v in graph.edata.items()},
        'metadata': getattr(graph, 'metadata', {}),
    }

    with open(output_path, 'wb') as f:
        pickle.dump(graph_data, f, protocol=pickle.HIGHEST_PROTOCOL)

    print(f"Saved graph to {output_path}")


def load_dgl_graph(input_path: str) -> "dgl.DGLGraph":
    """Load a DGL graph from a pickle file.

    Args:
        input_path: Path to the pickle file.

    Returns:
        DGL graph.
    """
    if not HAS_DGL:
        raise ImportError("DGL is required to load graphs")

    import torch

    input_path = Path(input_path)

    with open(input_path, 'rb') as f:
        graph_data = pickle.load(f)

    # Recreate graph
    src, dst = graph_data['edges']
    g = dgl.graph((src, dst), num_nodes=graph_data['num_nodes'])

    # Restore node data
    for k, v in graph_data['ndata'].items():
        g.ndata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v

    # Restore edge data
    for k, v in graph_data['edata'].items():
        g.edata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v

    # Restore metadata
    g.metadata = graph_data.get('metadata', {})

    return g


def convert_to_pickle(data_dir: str, output_path: str = None):
    """Convert exported graph data to pickle format for faster loading.

    Args:
        data_dir: Directory containing the exported graph data.
        output_path: Path for output pickle file. Defaults to data_dir/graph.pkl.
    """
    from load_graph import load_graph

    data_dir = Path(data_dir)
    output_path = Path(output_path) if output_path else data_dir / "graph.pkl"

    print(f"Loading graph from {data_dir}...")
    g = load_graph(str(data_dir))

    if isinstance(g, dict):
        print("Error: DGL not available, cannot convert to pickle")
        return

    save_dgl_graph(g, str(output_path))
    print(f"Graph saved to {output_path}")


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage:")
        print("  python pickle_helper.py convert <data_dir> [output_path]")
        print("  python pickle_helper.py load <pickle_path>")
        sys.exit(1)

    command = sys.argv[1]

    if command == "convert":
        data_dir = sys.argv[2] if len(sys.argv) > 2 else "."
        output_path = sys.argv[3] if len(sys.argv) > 3 else None
        convert_to_pickle(data_dir, output_path)
    elif command == "load":
        pickle_path = sys.argv[2]
        g = load_dgl_graph(pickle_path)
        print(f"Loaded graph: {g.num_nodes()} nodes, {g.num_edges()} edges")
    else:
        print(f"Unknown command: {command}")
"#;

        // Write the script next to the exported arrays.
        let path = output_dir.join("pickle_helper.py");
        let mut file = File::create(path)?;
        file.write_all(script.as_bytes())?;

        Ok(())
    }
736}
737
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    //! Integration-style tests: each exports the shared test fixture into a
    //! temp directory and inspects the returned metadata and emitted files.
    use super::*;
    use crate::test_helpers::create_test_graph_with_company;
    use tempfile::tempdir;

    // Default config: all core files (arrays, metadata, both scripts) appear.
    #[test]
    fn test_dgl_export_basic() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        // Fixture graph: 3 nodes, 2 edges.
        assert_eq!(metadata.common.num_nodes, 3);
        assert_eq!(metadata.common.num_edges, 2);
        assert_eq!(metadata.edge_format, "COO");
        assert!(dir.path().join("edge_index.npy").exists());
        assert!(dir.path().join("node_features.npy").exists());
        assert!(dir.path().join("node_labels.npy").exists());
        assert!(dir.path().join("metadata.json").exists());
        assert!(dir.path().join("load_graph.py").exists());
        assert!(dir.path().join("pickle_helper.py").exists());
    }

    // heterogeneous=true: type index arrays are written and flagged in metadata.
    #[test]
    fn test_dgl_export_heterogeneous() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let config = DGLExportConfig {
            heterogeneous: true,
            ..Default::default()
        };
        let exporter = DGLExporter::new(config);
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert!(metadata.is_heterogeneous);
        assert!(dir.path().join("node_type_indices.npy").exists());
        assert!(dir.path().join("edge_type_indices.npy").exists());
    }

    // Masks are on by default: files exist and are listed in metadata.files.
    #[test]
    fn test_dgl_export_masks() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert!(metadata
            .common
            .files
            .contains(&"train_mask.npy".to_string()));
        assert!(metadata.common.files.contains(&"val_mask.npy".to_string()));
        assert!(metadata.common.files.contains(&"test_mask.npy".to_string()));
        assert!(dir.path().join("train_mask.npy").exists());
        assert!(dir.path().join("val_mask.npy").exists());
        assert!(dir.path().join("test_mask.npy").exists());
    }

    // metadata.json round-trips through serde and records the COO edge format.
    #[test]
    fn test_dgl_coo_format() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        exporter.export(&graph, dir.path()).unwrap();

        let edge_path = dir.path().join("edge_index.npy");
        assert!(edge_path.exists());

        let metadata_path = dir.path().join("metadata.json");
        let metadata: DGLMetadata =
            serde_json::from_reader(File::open(metadata_path).unwrap()).unwrap();
        assert_eq!(metadata.edge_format, "COO");
    }

    // export_masks=false: no mask files on disk, none listed in metadata.
    #[test]
    fn test_dgl_export_no_masks() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let config = DGLExportConfig {
            common: CommonExportConfig {
                export_masks: false,
                ..Default::default()
            },
            ..Default::default()
        };
        let exporter = DGLExporter::new(config);
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert!(!metadata
            .common
            .files
            .contains(&"train_mask.npy".to_string()));
        assert!(!dir.path().join("train_mask.npy").exists());
    }

    // Everything optional disabled: only edge_index.npy is listed; the loader
    // script and metadata are still written, but the pickle helper is not.
    #[test]
    fn test_dgl_export_minimal() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let config = DGLExportConfig {
            common: CommonExportConfig {
                export_node_features: false,
                export_edge_features: false,
                export_node_labels: false,
                export_edge_labels: false,
                export_masks: false,
                ..Default::default()
            },
            include_pickle_script: false,
            ..Default::default()
        };
        let exporter = DGLExporter::new(config);
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert_eq!(metadata.common.files.len(), 1); assert!(dir.path().join("edge_index.npy").exists());
        assert!(dir.path().join("load_graph.py").exists()); assert!(dir.path().join("metadata.json").exists());
        assert!(!dir.path().join("pickle_helper.py").exists());
    }

    // Statistics map carries density plus the two anomaly ratios.
    #[test]
    fn test_dgl_statistics() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert!(metadata.common.statistics.contains_key("density"));
        assert!(metadata
            .common
            .statistics
            .contains_key("anomalous_node_ratio"));
        assert!(metadata
            .common
            .statistics
            .contains_key("anomalous_edge_ratio"));

        // Fixture has 1 anomalous node out of 3.
        let node_ratio = metadata
            .common
            .statistics
            .get("anomalous_node_ratio")
            .unwrap();
        assert!((*node_ratio - 1.0 / 3.0).abs() < 0.01);
    }
}