1use std::collections::HashMap;
25use std::fs::{self, File};
26use std::io::Write;
27use std::path::Path;
28
29use serde::{Deserialize, Serialize};
30
31use crate::exporters::common::{CommonExportConfig, CommonGraphMetadata};
32use crate::exporters::npy_writer;
33use crate::models::Graph;
34
/// Configuration for [`DGLExporter`].
#[derive(Debug, Clone)]
pub struct DGLExportConfig {
    /// Options from `CommonExportConfig`: the feature/label/mask export toggles plus the
    /// seed and train/validation ratios that are forwarded to the mask writer.
    pub common: CommonExportConfig,
    /// When `true`, `node_type_indices.npy` and `edge_type_indices.npy` are exported as well.
    pub heterogeneous: bool,
    /// When `true`, `pickle_helper.py` is written next to the exported data.
    pub include_pickle_script: bool,
}
45
46impl Default for DGLExportConfig {
47 fn default() -> Self {
48 Self {
49 common: CommonExportConfig::default(),
50 heterogeneous: false,
51 include_pickle_script: true,
52 }
53 }
54}
55
/// Metadata describing a DGL export; [`DGLExporter::export`] serializes it to `metadata.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DGLMetadata {
    /// Common graph metadata. Because of `#[serde(flatten)]` its fields are serialized at the
    /// top level of the JSON object instead of under a nested `common` key.
    #[serde(flatten)]
    pub common: CommonGraphMetadata,
    /// Whether the export was configured as heterogeneous.
    pub is_heterogeneous: bool,
    /// Label of the edge storage layout; the exporter writes `"COO"`.
    pub edge_format: String,
}
67
/// Exports a `Graph` as NumPy arrays, a JSON metadata file and Python helper scripts for DGL.
pub struct DGLExporter {
    /// Settings that control which artifacts are written.
    config: DGLExportConfig,
}
72
73impl DGLExporter {
74 pub fn new(config: DGLExportConfig) -> Self {
76 Self { config }
77 }
78
79 pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<DGLMetadata> {
81 fs::create_dir_all(output_dir)?;
82
83 let mut files = Vec::new();
84 let mut statistics = HashMap::new();
85
86 self.export_edge_index(graph, output_dir)?;
88 files.push("edge_index.npy".to_string());
89
90 if self.config.common.export_node_features {
92 let dim = self.export_node_features(graph, output_dir)?;
93 files.push("node_features.npy".to_string());
94 statistics.insert("node_feature_dim".to_string(), dim as f64);
95 }
96
97 if self.config.common.export_edge_features {
99 let dim = self.export_edge_features(graph, output_dir)?;
100 files.push("edge_features.npy".to_string());
101 statistics.insert("edge_feature_dim".to_string(), dim as f64);
102 }
103
104 if self.config.common.export_node_labels {
106 self.export_node_labels(graph, output_dir)?;
107 files.push("node_labels.npy".to_string());
108 }
109
110 if self.config.common.export_edge_labels {
112 self.export_edge_labels(graph, output_dir)?;
113 files.push("edge_labels.npy".to_string());
114 }
115
116 if self.config.common.export_masks {
118 self.export_masks(graph, output_dir)?;
119 files.push("train_mask.npy".to_string());
120 files.push("val_mask.npy".to_string());
121 files.push("test_mask.npy".to_string());
122 }
123
124 if self.config.heterogeneous {
126 self.export_node_types(graph, output_dir)?;
127 files.push("node_type_indices.npy".to_string());
128 self.export_edge_types(graph, output_dir)?;
129 files.push("edge_type_indices.npy".to_string());
130 }
131
132 let node_types: HashMap<String, usize> = graph
134 .nodes_by_type
135 .iter()
136 .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
137 .collect();
138
139 let edge_types: HashMap<String, usize> = graph
140 .edges_by_type
141 .iter()
142 .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
143 .collect();
144
145 statistics.insert("density".to_string(), graph.metadata.density);
147 statistics.insert(
148 "anomalous_node_ratio".to_string(),
149 graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
150 );
151 statistics.insert(
152 "anomalous_edge_ratio".to_string(),
153 graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
154 );
155
156 let metadata = DGLMetadata {
158 common: CommonGraphMetadata {
159 name: graph.name.clone(),
160 num_nodes: graph.node_count(),
161 num_edges: graph.edge_count(),
162 node_feature_dim: graph.metadata.node_feature_dim,
163 edge_feature_dim: graph.metadata.edge_feature_dim,
164 num_node_classes: 2, num_edge_classes: 2,
166 node_types,
167 edge_types,
168 is_directed: true,
169 files,
170 statistics,
171 },
172 is_heterogeneous: self.config.heterogeneous,
173 edge_format: "COO".to_string(),
174 };
175
176 let metadata_path = output_dir.join("metadata.json");
178 let file = File::create(metadata_path)?;
179 serde_json::to_writer_pretty(file, &metadata)?;
180
181 self.write_loader_script(output_dir)?;
183
184 if self.config.include_pickle_script {
186 self.write_pickle_script(output_dir)?;
187 }
188
189 Ok(metadata)
190 }
191
192 fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
194 let (sources, targets) = graph.edge_index();
195
196 let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
198 node_ids.sort();
199 let id_to_idx: HashMap<_, _> = node_ids
200 .iter()
201 .enumerate()
202 .map(|(i, &id)| (id, i))
203 .collect();
204
205 let num_edges = sources.len();
207 let coo_data: Vec<Vec<i64>> = (0..num_edges)
208 .map(|i| {
209 let src = *id_to_idx.get(&sources[i]).unwrap_or(&0) as i64;
210 let dst = *id_to_idx.get(&targets[i]).unwrap_or(&0) as i64;
211 vec![src, dst]
212 })
213 .collect();
214
215 let path = output_dir.join("edge_index.npy");
217 npy_writer::write_npy_2d_i64(&path, &coo_data)?;
218
219 Ok(())
220 }
221
222 fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
224 let features = graph.node_features();
225 let dim = features.first().map(|f| f.len()).unwrap_or(0);
226
227 let path = output_dir.join("node_features.npy");
228 npy_writer::write_npy_2d_f64(&path, &features)?;
229
230 Ok(dim)
231 }
232
233 fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
235 let features = graph.edge_features();
236 let dim = features.first().map(|f| f.len()).unwrap_or(0);
237
238 let path = output_dir.join("edge_features.npy");
239 npy_writer::write_npy_2d_f64(&path, &features)?;
240
241 Ok(dim)
242 }
243
244 fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
246 let labels: Vec<i64> = graph
247 .node_anomaly_mask()
248 .iter()
249 .map(|&b| if b { 1 } else { 0 })
250 .collect();
251
252 let path = output_dir.join("node_labels.npy");
253 npy_writer::write_npy_1d_i64(&path, &labels)?;
254
255 Ok(())
256 }
257
258 fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
260 let labels: Vec<i64> = graph
261 .edge_anomaly_mask()
262 .iter()
263 .map(|&b| if b { 1 } else { 0 })
264 .collect();
265
266 let path = output_dir.join("edge_labels.npy");
267 npy_writer::write_npy_1d_i64(&path, &labels)?;
268
269 Ok(())
270 }
271
272 fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
274 npy_writer::export_masks(
275 output_dir,
276 graph.node_count(),
277 self.config.common.seed,
278 self.config.common.train_ratio,
279 self.config.common.val_ratio,
280 )
281 }
282
283 fn export_node_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
285 let type_to_idx: HashMap<_, _> = graph
287 .nodes_by_type
288 .keys()
289 .enumerate()
290 .map(|(i, t)| (t.clone(), i as i64))
291 .collect();
292
293 let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
295 node_ids.sort();
296
297 let type_indices: Vec<i64> = node_ids
299 .iter()
300 .map(|id| {
301 let node = graph.nodes.get(id).expect("node ID from keys()");
302 *type_to_idx.get(&node.node_type).unwrap_or(&0)
303 })
304 .collect();
305
306 let path = output_dir.join("node_type_indices.npy");
307 npy_writer::write_npy_1d_i64(&path, &type_indices)?;
308
309 Ok(())
310 }
311
312 fn export_edge_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
314 let type_to_idx: HashMap<_, _> = graph
316 .edges_by_type
317 .keys()
318 .enumerate()
319 .map(|(i, t)| (t.clone(), i as i64))
320 .collect();
321
322 let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
324 edge_ids.sort();
325
326 let type_indices: Vec<i64> = edge_ids
328 .iter()
329 .map(|id| {
330 let edge = graph.edges.get(id).expect("edge ID from keys()");
331 *type_to_idx.get(&edge.edge_type).unwrap_or(&0)
332 })
333 .collect();
334
335 let path = output_dir.join("edge_type_indices.npy");
336 npy_writer::write_npy_1d_i64(&path, &type_indices)?;
337
338 Ok(())
339 }
340
341 fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
343 let script = r#"#!/usr/bin/env python3
344"""
345DGL (Deep Graph Library) Data Loader
346
347Auto-generated loader for graph data exported from synth-graph.
348Supports both homogeneous and heterogeneous graph loading.
349"""
350
351import json
352import numpy as np
353from pathlib import Path
354
355try:
356 import torch
357 HAS_TORCH = True
358except ImportError:
359 HAS_TORCH = False
360 print("Warning: torch not installed. Install with: pip install torch")
361
362try:
363 import dgl
364 HAS_DGL = True
365except ImportError:
366 HAS_DGL = False
367 print("Warning: dgl not installed. Install with: pip install dgl")
368
369
370def load_graph(data_dir: str = ".") -> "dgl.DGLGraph":
371 """Load graph data into a DGL graph object.
372
373 Args:
374 data_dir: Directory containing the exported graph data.
375
376 Returns:
377 DGL graph with node features, edge features, and labels attached.
378 """
379 data_dir = Path(data_dir)
380
381 # Load metadata
382 with open(data_dir / "metadata.json") as f:
383 metadata = json.load(f)
384
385 # Load edge index (COO format: [num_edges, 2])
386 edge_index = np.load(data_dir / "edge_index.npy")
387 src = edge_index[:, 0]
388 dst = edge_index[:, 1]
389
390 num_nodes = metadata["num_nodes"]
391
392 if not HAS_DGL:
393 # Return dict if DGL not available
394 result = {
395 "src": src,
396 "dst": dst,
397 "num_nodes": num_nodes,
398 "metadata": metadata,
399 }
400
401 # Load optional arrays
402 if (data_dir / "node_features.npy").exists():
403 result["node_features"] = np.load(data_dir / "node_features.npy")
404 if (data_dir / "edge_features.npy").exists():
405 result["edge_features"] = np.load(data_dir / "edge_features.npy")
406 if (data_dir / "node_labels.npy").exists():
407 result["node_labels"] = np.load(data_dir / "node_labels.npy")
408 if (data_dir / "edge_labels.npy").exists():
409 result["edge_labels"] = np.load(data_dir / "edge_labels.npy")
410 if (data_dir / "train_mask.npy").exists():
411 result["train_mask"] = np.load(data_dir / "train_mask.npy")
412 result["val_mask"] = np.load(data_dir / "val_mask.npy")
413 result["test_mask"] = np.load(data_dir / "test_mask.npy")
414
415 return result
416
417 # Create DGL graph
418 g = dgl.graph((src, dst), num_nodes=num_nodes)
419
420 # Load and attach node features
421 node_features_path = data_dir / "node_features.npy"
422 if node_features_path.exists():
423 node_features = np.load(node_features_path)
424 if HAS_TORCH:
425 g.ndata['feat'] = torch.from_numpy(node_features).float()
426 else:
427 g.ndata['feat'] = node_features
428
429 # Load and attach edge features
430 edge_features_path = data_dir / "edge_features.npy"
431 if edge_features_path.exists():
432 edge_features = np.load(edge_features_path)
433 if HAS_TORCH:
434 g.edata['feat'] = torch.from_numpy(edge_features).float()
435 else:
436 g.edata['feat'] = edge_features
437
438 # Load and attach node labels
439 node_labels_path = data_dir / "node_labels.npy"
440 if node_labels_path.exists():
441 node_labels = np.load(node_labels_path)
442 if HAS_TORCH:
443 g.ndata['label'] = torch.from_numpy(node_labels).long()
444 else:
445 g.ndata['label'] = node_labels
446
447 # Load and attach edge labels
448 edge_labels_path = data_dir / "edge_labels.npy"
449 if edge_labels_path.exists():
450 edge_labels = np.load(edge_labels_path)
451 if HAS_TORCH:
452 g.edata['label'] = torch.from_numpy(edge_labels).long()
453 else:
454 g.edata['label'] = edge_labels
455
456 # Load and attach masks
457 if (data_dir / "train_mask.npy").exists():
458 train_mask = np.load(data_dir / "train_mask.npy")
459 val_mask = np.load(data_dir / "val_mask.npy")
460 test_mask = np.load(data_dir / "test_mask.npy")
461
462 if HAS_TORCH:
463 g.ndata['train_mask'] = torch.from_numpy(train_mask).bool()
464 g.ndata['val_mask'] = torch.from_numpy(val_mask).bool()
465 g.ndata['test_mask'] = torch.from_numpy(test_mask).bool()
466 else:
467 g.ndata['train_mask'] = train_mask
468 g.ndata['val_mask'] = val_mask
469 g.ndata['test_mask'] = test_mask
470
471 # Store metadata as graph attribute
472 g.metadata = metadata
473
474 return g
475
476
477def load_heterogeneous_graph(data_dir: str = ".") -> "dgl.DGLHeteroGraph":
478 """Load graph data into a DGL heterogeneous graph.
479
480 This function handles graphs with multiple node and edge types.
481
482 Args:
483 data_dir: Directory containing the exported graph data.
484
485 Returns:
486 DGL heterogeneous graph.
487 """
488 data_dir = Path(data_dir)
489
490 # Load metadata
491 with open(data_dir / "metadata.json") as f:
492 metadata = json.load(f)
493
494 if not metadata.get("is_heterogeneous", False):
495 print("Warning: Graph was not exported as heterogeneous. Using homogeneous loader.")
496 return load_graph(data_dir)
497
498 if not HAS_DGL:
499 raise ImportError("DGL is required for heterogeneous graph loading")
500
501 # Load edge index and type indices
502 edge_index = np.load(data_dir / "edge_index.npy")
503 edge_types = np.load(data_dir / "edge_type_indices.npy")
504 node_types = np.load(data_dir / "node_type_indices.npy")
505
506 # Get type names from metadata
507 node_type_names = list(metadata["node_types"].keys())
508 edge_type_names = list(metadata["edge_types"].keys())
509
510 # Build edge dict for heterogeneous graph
511 edge_dict = {}
512 for etype_idx, etype_name in enumerate(edge_type_names):
513 mask = edge_types == etype_idx
514 if mask.any():
515 src = edge_index[mask, 0]
516 dst = edge_index[mask, 1]
517 # For heterogeneous, we need to specify (src_type, edge_type, dst_type)
518 # Using simplified convention: (node_type, edge_type, node_type)
519 edge_dict[(node_type_names[0] if node_type_names else 'node',
520 etype_name,
521 node_type_names[0] if node_type_names else 'node')] = (src, dst)
522
523 # Create heterogeneous graph
524 g = dgl.heterograph(edge_dict) if edge_dict else dgl.graph(([], []))
525 g.metadata = metadata
526
527 return g
528
529
530def print_summary(data_dir: str = "."):
531 """Print summary of the graph data."""
532 data_dir = Path(data_dir)
533
534 with open(data_dir / "metadata.json") as f:
535 metadata = json.load(f)
536
537 print(f"Graph: {metadata['name']}")
538 print(f"Format: DGL ({metadata['edge_format']} edge format)")
539 print(f"Nodes: {metadata['num_nodes']}")
540 print(f"Edges: {metadata['num_edges']}")
541 print(f"Node feature dim: {metadata['node_feature_dim']}")
542 print(f"Edge feature dim: {metadata['edge_feature_dim']}")
543 print(f"Directed: {metadata['is_directed']}")
544 print(f"Heterogeneous: {metadata['is_heterogeneous']}")
545
546 if metadata['node_types']:
547 print(f"Node types: {metadata['node_types']}")
548 if metadata['edge_types']:
549 print(f"Edge types: {metadata['edge_types']}")
550
551 if metadata['statistics']:
552 print("\nStatistics:")
553 for key, value in metadata['statistics'].items():
554 print(f" {key}: {value:.4f}")
555
556 if HAS_DGL:
557 print("\nLoading graph...")
558 g = load_graph(data_dir)
559 if hasattr(g, 'num_nodes'):
560 print(f"DGL graph loaded: {g.num_nodes()} nodes, {g.num_edges()} edges")
561 if 'label' in g.ndata:
562 print(f"Anomalous nodes: {g.ndata['label'].sum().item()}")
563
564
565if __name__ == "__main__":
566 import sys
567 data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
568 print_summary(data_dir)
569"#;
570
571 let path = output_dir.join("load_graph.py");
572 let mut file = File::create(path)?;
573 file.write_all(script.as_bytes())?;
574
575 Ok(())
576 }
577
578 fn write_pickle_script(&self, output_dir: &Path) -> std::io::Result<()> {
580 let script = r#"#!/usr/bin/env python3
581"""
582DGL Graph Pickle Helper
583
584Utility to save and load DGL graphs as pickle files for faster subsequent loading.
585"""
586
587import pickle
588from pathlib import Path
589
590try:
591 import dgl
592 HAS_DGL = True
593except ImportError:
594 HAS_DGL = False
595
596
597def save_dgl_graph(graph, output_path: str):
598 """Save a DGL graph to a pickle file.
599
600 Args:
601 graph: DGL graph to save.
602 output_path: Path to save the pickle file.
603 """
604 output_path = Path(output_path)
605
606 # Save graph data
607 graph_data = {
608 'num_nodes': graph.num_nodes(),
609 'edges': graph.edges(),
610 'ndata': {k: v.numpy() if hasattr(v, 'numpy') else v
611 for k, v in graph.ndata.items()},
612 'edata': {k: v.numpy() if hasattr(v, 'numpy') else v
613 for k, v in graph.edata.items()},
614 'metadata': getattr(graph, 'metadata', {}),
615 }
616
617 with open(output_path, 'wb') as f:
618 pickle.dump(graph_data, f, protocol=pickle.HIGHEST_PROTOCOL)
619
620 print(f"Saved graph to {output_path}")
621
622
623def load_dgl_graph(input_path: str) -> "dgl.DGLGraph":
624 """Load a DGL graph from a pickle file.
625
626 Args:
627 input_path: Path to the pickle file.
628
629 Returns:
630 DGL graph.
631 """
632 if not HAS_DGL:
633 raise ImportError("DGL is required to load graphs")
634
635 import torch
636
637 input_path = Path(input_path)
638
639 with open(input_path, 'rb') as f:
640 graph_data = pickle.load(f)
641
642 # Recreate graph
643 src, dst = graph_data['edges']
644 g = dgl.graph((src, dst), num_nodes=graph_data['num_nodes'])
645
646 # Restore node data
647 for k, v in graph_data['ndata'].items():
648 g.ndata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v
649
650 # Restore edge data
651 for k, v in graph_data['edata'].items():
652 g.edata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v
653
654 # Restore metadata
655 g.metadata = graph_data.get('metadata', {})
656
657 return g
658
659
660def convert_to_pickle(data_dir: str, output_path: str = None):
661 """Convert exported graph data to pickle format for faster loading.
662
663 Args:
664 data_dir: Directory containing the exported graph data.
665 output_path: Path for output pickle file. Defaults to data_dir/graph.pkl.
666 """
667 from load_graph import load_graph
668
669 data_dir = Path(data_dir)
670 output_path = Path(output_path) if output_path else data_dir / "graph.pkl"
671
672 print(f"Loading graph from {data_dir}...")
673 g = load_graph(str(data_dir))
674
675 if isinstance(g, dict):
676 print("Error: DGL not available, cannot convert to pickle")
677 return
678
679 save_dgl_graph(g, str(output_path))
680 print(f"Graph saved to {output_path}")
681
682
683if __name__ == "__main__":
684 import sys
685
686 if len(sys.argv) < 2:
687 print("Usage:")
688 print(" python pickle_helper.py convert <data_dir> [output_path]")
689 print(" python pickle_helper.py load <pickle_path>")
690 sys.exit(1)
691
692 command = sys.argv[1]
693
694 if command == "convert":
695 data_dir = sys.argv[2] if len(sys.argv) > 2 else "."
696 output_path = sys.argv[3] if len(sys.argv) > 3 else None
697 convert_to_pickle(data_dir, output_path)
698 elif command == "load":
699 pickle_path = sys.argv[2]
700 g = load_dgl_graph(pickle_path)
701 print(f"Loaded graph: {g.num_nodes()} nodes, {g.num_edges()} edges")
702 else:
703 print(f"Unknown command: {command}")
704"#;
705
706 let path = output_dir.join("pickle_helper.py");
707 let mut file = File::create(path)?;
708 file.write_all(script.as_bytes())?;
709
710 Ok(())
711 }
712}
713
#[cfg(test)]
// Unwrapping is acceptable in tests: a failed unwrap is simply a failed test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::test_helpers::create_test_graph_with_company;
    use tempfile::tempdir;

    /// Exports the shared fixture graph into a fresh temporary directory.
    ///
    /// The directory handle is returned so that it outlives the assertions.
    fn run_export(config: DGLExportConfig) -> (tempfile::TempDir, DGLMetadata) {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();
        let metadata = DGLExporter::new(config)
            .export(&graph, dir.path())
            .unwrap();
        (dir, metadata)
    }

    /// Convenience wrapper for `Vec<String>::contains` with a `&str` needle.
    fn lists_file(metadata: &DGLMetadata, name: &str) -> bool {
        metadata.common.files.contains(&name.to_string())
    }

    #[test]
    fn test_dgl_export_basic() {
        let (dir, metadata) = run_export(DGLExportConfig::default());

        assert_eq!(metadata.common.num_nodes, 3);
        assert_eq!(metadata.common.num_edges, 2);
        assert_eq!(metadata.edge_format, "COO");

        let expected_files = [
            "edge_index.npy",
            "node_features.npy",
            "node_labels.npy",
            "metadata.json",
            "load_graph.py",
            "pickle_helper.py",
        ];
        for name in expected_files {
            assert!(dir.path().join(name).exists());
        }
    }

    #[test]
    fn test_dgl_export_heterogeneous() {
        let (dir, metadata) = run_export(DGLExportConfig {
            heterogeneous: true,
            ..Default::default()
        });

        assert!(metadata.is_heterogeneous);
        assert!(dir.path().join("node_type_indices.npy").exists());
        assert!(dir.path().join("edge_type_indices.npy").exists());
    }

    #[test]
    fn test_dgl_export_masks() {
        let (dir, metadata) = run_export(DGLExportConfig::default());

        // The masks must be both listed in the metadata and present on disk.
        assert!(lists_file(&metadata, "train_mask.npy"));
        assert!(lists_file(&metadata, "val_mask.npy"));
        assert!(lists_file(&metadata, "test_mask.npy"));
        assert!(dir.path().join("train_mask.npy").exists());
        assert!(dir.path().join("val_mask.npy").exists());
        assert!(dir.path().join("test_mask.npy").exists());
    }

    #[test]
    fn test_dgl_coo_format() {
        let (dir, _) = run_export(DGLExportConfig::default());

        assert!(dir.path().join("edge_index.npy").exists());

        // Round-trip the metadata file to check what was actually written.
        let metadata_file = File::open(dir.path().join("metadata.json")).unwrap();
        let metadata: DGLMetadata = serde_json::from_reader(metadata_file).unwrap();
        assert_eq!(metadata.edge_format, "COO");
    }

    #[test]
    fn test_dgl_export_no_masks() {
        let common = CommonExportConfig {
            export_masks: false,
            ..Default::default()
        };
        let (dir, metadata) = run_export(DGLExportConfig {
            common,
            ..Default::default()
        });

        assert!(!lists_file(&metadata, "train_mask.npy"));
        assert!(!dir.path().join("train_mask.npy").exists());
    }

    #[test]
    fn test_dgl_export_minimal() {
        let common = CommonExportConfig {
            export_node_features: false,
            export_edge_features: false,
            export_node_labels: false,
            export_edge_labels: false,
            export_masks: false,
            ..Default::default()
        };
        let (dir, metadata) = run_export(DGLExportConfig {
            common,
            include_pickle_script: false,
            ..Default::default()
        });

        // Only the edge index is listed; the loader and metadata are still written.
        assert_eq!(metadata.common.files.len(), 1);
        assert!(dir.path().join("edge_index.npy").exists());
        assert!(dir.path().join("load_graph.py").exists());
        assert!(dir.path().join("metadata.json").exists());
        assert!(!dir.path().join("pickle_helper.py").exists());
    }

    #[test]
    fn test_dgl_statistics() {
        let (_dir, metadata) = run_export(DGLExportConfig::default());
        let stats = &metadata.common.statistics;

        assert!(stats.contains_key("density"));
        assert!(stats.contains_key("anomalous_node_ratio"));
        assert!(stats.contains_key("anomalous_edge_ratio"));

        let node_ratio = stats.get("anomalous_node_ratio").unwrap();
        assert!((*node_ratio - 1.0 / 3.0).abs() < 0.01);
    }
}