use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{BufWriter, Write};
use std::path::Path;

use serde::{Deserialize, Serialize};

use crate::exporters::common::{CommonExportConfig, CommonGraphMetadata};
use crate::models::Graph;

/// Configuration options for the DGL exporter.
#[derive(Debug, Clone)]
pub struct DGLExportConfig {
    /// Options shared across all exporters.
    pub common: CommonExportConfig,
    /// Export node/edge type index arrays for heterogeneous graph loading.
    pub heterogeneous: bool,
    /// Also write the `pickle_helper.py` convenience script.
    pub include_pickle_script: bool,
}

impl Default for DGLExportConfig {
    fn default() -> Self {
        Self {
            common: CommonExportConfig::default(),
            heterogeneous: false,
            include_pickle_script: true,
        }
    }
}

/// Metadata describing a DGL export, serialized to `metadata.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DGLMetadata {
    #[serde(flatten)]
    pub common: CommonGraphMetadata,
    /// Whether type index arrays were written for heterogeneous loading.
    pub is_heterogeneous: bool,
    /// Layout of the exported edge index; currently always `"COO"`.
    pub edge_format: String,
}

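/// Exports a [`Graph`] as DGL-compatible artifacts: NumPy arrays, a
/// `metadata.json` file, and Python loader scripts.
///
/// A minimal usage sketch (assuming a `Graph` value named `graph` has already
/// been built elsewhere and `out_dir` is a writable directory path):
///
/// ```ignore
/// let exporter = DGLExporter::new(DGLExportConfig::default());
/// let metadata = exporter.export(&graph, out_dir)?;
/// println!("exported {} nodes", metadata.common.num_nodes);
/// ```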
pub struct DGLExporter {
    config: DGLExportConfig,
}

impl DGLExporter {
    pub fn new(config: DGLExportConfig) -> Self {
        Self { config }
    }

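    /// Exports `graph` into `output_dir`, writing `edge_index.npy`, any
    /// optional feature/label/mask arrays enabled in the config, a
    /// `metadata.json` describing the export, and the Python loader scripts.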
    pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<DGLMetadata> {
        fs::create_dir_all(output_dir)?;

        let mut files = Vec::new();
        let mut statistics = HashMap::new();

        // The edge index is always exported.
        self.export_edge_index(graph, output_dir)?;
        files.push("edge_index.npy".to_string());

        if self.config.common.export_node_features {
            let dim = self.export_node_features(graph, output_dir)?;
            files.push("node_features.npy".to_string());
            statistics.insert("node_feature_dim".to_string(), dim as f64);
        }

        if self.config.common.export_edge_features {
            let dim = self.export_edge_features(graph, output_dir)?;
            files.push("edge_features.npy".to_string());
            statistics.insert("edge_feature_dim".to_string(), dim as f64);
        }

        if self.config.common.export_node_labels {
            self.export_node_labels(graph, output_dir)?;
            files.push("node_labels.npy".to_string());
        }

        if self.config.common.export_edge_labels {
            self.export_edge_labels(graph, output_dir)?;
            files.push("edge_labels.npy".to_string());
        }

        if self.config.common.export_masks {
            self.export_masks(graph, output_dir)?;
            files.push("train_mask.npy".to_string());
            files.push("val_mask.npy".to_string());
            files.push("test_mask.npy".to_string());
        }

        if self.config.heterogeneous {
            self.export_node_types(graph, output_dir)?;
            files.push("node_type_indices.npy".to_string());
            self.export_edge_types(graph, output_dir)?;
            files.push("edge_type_indices.npy".to_string());
        }

        let node_types: HashMap<String, usize> = graph
            .nodes_by_type
            .iter()
            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
            .collect();

        let edge_types: HashMap<String, usize> = graph
            .edges_by_type
            .iter()
            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
            .collect();

        statistics.insert("density".to_string(), graph.metadata.density);
        statistics.insert(
            "anomalous_node_ratio".to_string(),
            graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
        );
        statistics.insert(
            "anomalous_edge_ratio".to_string(),
            graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
        );

        let metadata = DGLMetadata {
            common: CommonGraphMetadata {
                name: graph.name.clone(),
                num_nodes: graph.node_count(),
                num_edges: graph.edge_count(),
                node_feature_dim: graph.metadata.node_feature_dim,
                edge_feature_dim: graph.metadata.edge_feature_dim,
                // Labels are binary anomaly indicators.
                num_node_classes: 2,
                num_edge_classes: 2,
                node_types,
                edge_types,
                is_directed: true,
                files,
                statistics,
            },
            is_heterogeneous: self.config.heterogeneous,
            edge_format: "COO".to_string(),
        };

        let metadata_path = output_dir.join("metadata.json");
        let file = File::create(metadata_path)?;
        serde_json::to_writer_pretty(file, &metadata)?;

        self.write_loader_script(output_dir)?;

        if self.config.include_pickle_script {
            self.write_pickle_script(output_dir)?;
        }

        Ok(metadata)
    }

    /// Writes the edge index as a `[num_edges, 2]` int64 array in COO format.
    fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
        let (sources, targets) = graph.edge_index();

        // Map node IDs to contiguous indices in sorted order.
        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
        node_ids.sort();
        let id_to_idx: HashMap<_, _> = node_ids
            .iter()
            .enumerate()
            .map(|(i, &id)| (id, i))
            .collect();

        let num_edges = sources.len();
        let coo_data: Vec<Vec<i64>> = (0..num_edges)
            .map(|i| {
                let src = *id_to_idx.get(&sources[i]).unwrap_or(&0) as i64;
                let dst = *id_to_idx.get(&targets[i]).unwrap_or(&0) as i64;
                vec![src, dst]
            })
            .collect();

        let path = output_dir.join("edge_index.npy");
        self.write_npy_2d_i64(&path, &coo_data)?;

        Ok(())
    }

    fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
        let features = graph.node_features();
        let dim = features.first().map(|f| f.len()).unwrap_or(0);

        let path = output_dir.join("node_features.npy");
        self.write_npy_2d_f64(&path, &features)?;

        Ok(dim)
    }

    fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
        let features = graph.edge_features();
        let dim = features.first().map(|f| f.len()).unwrap_or(0);

        let path = output_dir.join("edge_features.npy");
        self.write_npy_2d_f64(&path, &features)?;

        Ok(dim)
    }

    fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
        let labels: Vec<i64> = graph
            .node_anomaly_mask()
            .iter()
            .map(|&b| if b { 1 } else { 0 })
            .collect();

        let path = output_dir.join("node_labels.npy");
        self.write_npy_1d_i64(&path, &labels)?;

        Ok(())
    }

    fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
        let labels: Vec<i64> = graph
            .edge_anomaly_mask()
            .iter()
            .map(|&b| if b { 1 } else { 0 })
            .collect();

        let path = output_dir.join("edge_labels.npy");
        self.write_npy_1d_i64(&path, &labels)?;

        Ok(())
    }

    fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
        let n = graph.node_count();
        let mut rng = SimpleRng::new(self.config.common.seed);

        let train_size = (n as f64 * self.config.common.train_ratio) as usize;
        let val_size = (n as f64 * self.config.common.val_ratio) as usize;

        // Fisher-Yates shuffle driven by the deterministic seed.
        let mut indices: Vec<usize> = (0..n).collect();
        for i in (1..n).rev() {
            let j = (rng.next() % (i as u64 + 1)) as usize;
            indices.swap(i, j);
        }

        let mut train_mask = vec![false; n];
        let mut val_mask = vec![false; n];
        let mut test_mask = vec![false; n];

        for (i, &idx) in indices.iter().enumerate() {
            if i < train_size {
                train_mask[idx] = true;
            } else if i < train_size + val_size {
                val_mask[idx] = true;
            } else {
                test_mask[idx] = true;
            }
        }

        self.write_npy_1d_bool(&output_dir.join("train_mask.npy"), &train_mask)?;
        self.write_npy_1d_bool(&output_dir.join("val_mask.npy"), &val_mask)?;
        self.write_npy_1d_bool(&output_dir.join("test_mask.npy"), &test_mask)?;

        Ok(())
    }

    fn export_node_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
        let type_to_idx: HashMap<_, _> = graph
            .nodes_by_type
            .keys()
            .enumerate()
            .map(|(i, t)| (t.clone(), i as i64))
            .collect();

        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
        node_ids.sort();

        let type_indices: Vec<i64> = node_ids
            .iter()
            .map(|id| {
                let node = graph.nodes.get(id).unwrap();
                *type_to_idx.get(&node.node_type).unwrap_or(&0)
            })
            .collect();

        let path = output_dir.join("node_type_indices.npy");
        self.write_npy_1d_i64(&path, &type_indices)?;

        Ok(())
    }

    fn export_edge_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
        let type_to_idx: HashMap<_, _> = graph
            .edges_by_type
            .keys()
            .enumerate()
            .map(|(i, t)| (t.clone(), i as i64))
            .collect();

        let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
        edge_ids.sort();

        let type_indices: Vec<i64> = edge_ids
            .iter()
            .map(|id| {
                let edge = graph.edges.get(id).unwrap();
                *type_to_idx.get(&edge.edge_type).unwrap_or(&0)
            })
            .collect();

        let path = output_dir.join("edge_type_indices.npy");
        self.write_npy_1d_i64(&path, &type_indices)?;

        Ok(())
    }

    fn write_npy_1d_i64(&self, path: &Path, data: &[i64]) -> std::io::Result<()> {
        let file = File::create(path)?;
        let mut writer = BufWriter::new(file);

        let shape = format!("({},)", data.len());
        self.write_npy_header(&mut writer, "<i8", &shape)?;

        for &val in data {
            writer.write_all(&val.to_le_bytes())?;
        }

        Ok(())
    }

    fn write_npy_1d_bool(&self, path: &Path, data: &[bool]) -> std::io::Result<()> {
        let file = File::create(path)?;
        let mut writer = BufWriter::new(file);

        let shape = format!("({},)", data.len());
        self.write_npy_header(&mut writer, "|b1", &shape)?;

        for &val in data {
            writer.write_all(&[if val { 1u8 } else { 0u8 }])?;
        }

        Ok(())
    }

    fn write_npy_2d_i64(&self, path: &Path, data: &[Vec<i64>]) -> std::io::Result<()> {
        let file = File::create(path)?;
        let mut writer = BufWriter::new(file);

        let rows = data.len();
        let cols = data.first().map(|r| r.len()).unwrap_or(0);

        let shape = format!("({}, {})", rows, cols);
        self.write_npy_header(&mut writer, "<i8", &shape)?;

        for row in data {
            for &val in row {
                writer.write_all(&val.to_le_bytes())?;
            }
            for _ in row.len()..cols {
                writer.write_all(&0_i64.to_le_bytes())?;
            }
        }

        Ok(())
    }

    fn write_npy_2d_f64(&self, path: &Path, data: &[Vec<f64>]) -> std::io::Result<()> {
        let file = File::create(path)?;
        let mut writer = BufWriter::new(file);

        let rows = data.len();
        let cols = data.first().map(|r| r.len()).unwrap_or(0);

        let shape = format!("({}, {})", rows, cols);
        self.write_npy_header(&mut writer, "<f8", &shape)?;

        for row in data {
            for &val in row {
                writer.write_all(&val.to_le_bytes())?;
            }
            for _ in row.len()..cols {
                writer.write_all(&0.0_f64.to_le_bytes())?;
            }
        }

        Ok(())
    }

    /// Writes a NumPy `.npy` v1.0 header for the given dtype and shape.
    fn write_npy_header<W: Write>(
        &self,
        writer: &mut W,
        dtype: &str,
        shape: &str,
    ) -> std::io::Result<()> {
        // Magic string and format version 1.0.
        writer.write_all(&[0x93])?;
        writer.write_all(b"NUMPY")?;
        writer.write_all(&[0x01, 0x00])?;

        let header = format!(
            "{{'descr': '{}', 'fortran_order': False, 'shape': {} }}",
            dtype, shape
        );

        // Pad so that magic (6) + version (2) + length field (2) + header
        // + trailing newline is a multiple of 64 bytes.
        let header_len = header.len();
        let total_len = 10 + header_len + 1;
        let padding = (64 - (total_len % 64)) % 64;
        let padded_len = header_len + 1 + padding;

        writer.write_all(&(padded_len as u16).to_le_bytes())?;
        writer.write_all(header.as_bytes())?;
        for _ in 0..padding {
            writer.write_all(b" ")?;
        }
        writer.write_all(b"\n")?;

        Ok(())
    }

    fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
        let script = r#"#!/usr/bin/env python3
"""
DGL (Deep Graph Library) Data Loader

Auto-generated loader for graph data exported from synth-graph.
Supports both homogeneous and heterogeneous graph loading.
"""

import json
import numpy as np
from pathlib import Path

try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False
    print("Warning: torch not installed. Install with: pip install torch")

try:
    import dgl
    HAS_DGL = True
except ImportError:
    HAS_DGL = False
    print("Warning: dgl not installed. Install with: pip install dgl")


def load_graph(data_dir: str = ".") -> "dgl.DGLGraph":
    """Load graph data into a DGL graph object.

    Args:
        data_dir: Directory containing the exported graph data.

    Returns:
        DGL graph with node features, edge features, and labels attached.
    """
    data_dir = Path(data_dir)

    # Load metadata
    with open(data_dir / "metadata.json") as f:
        metadata = json.load(f)

    # Load edge index (COO format: [num_edges, 2])
    edge_index = np.load(data_dir / "edge_index.npy")
    src = edge_index[:, 0]
    dst = edge_index[:, 1]

    num_nodes = metadata["num_nodes"]

    if not HAS_DGL:
        # Return a plain dict if DGL is not available
        result = {
            "src": src,
            "dst": dst,
            "num_nodes": num_nodes,
            "metadata": metadata,
        }

        # Load optional arrays
        if (data_dir / "node_features.npy").exists():
            result["node_features"] = np.load(data_dir / "node_features.npy")
        if (data_dir / "edge_features.npy").exists():
            result["edge_features"] = np.load(data_dir / "edge_features.npy")
        if (data_dir / "node_labels.npy").exists():
            result["node_labels"] = np.load(data_dir / "node_labels.npy")
        if (data_dir / "edge_labels.npy").exists():
            result["edge_labels"] = np.load(data_dir / "edge_labels.npy")
        if (data_dir / "train_mask.npy").exists():
            result["train_mask"] = np.load(data_dir / "train_mask.npy")
            result["val_mask"] = np.load(data_dir / "val_mask.npy")
            result["test_mask"] = np.load(data_dir / "test_mask.npy")

        return result

    # Create DGL graph
    g = dgl.graph((src, dst), num_nodes=num_nodes)

    # Load and attach node features
    node_features_path = data_dir / "node_features.npy"
    if node_features_path.exists():
        node_features = np.load(node_features_path)
        if HAS_TORCH:
            g.ndata['feat'] = torch.from_numpy(node_features).float()
        else:
            g.ndata['feat'] = node_features

    # Load and attach edge features
    edge_features_path = data_dir / "edge_features.npy"
    if edge_features_path.exists():
        edge_features = np.load(edge_features_path)
        if HAS_TORCH:
            g.edata['feat'] = torch.from_numpy(edge_features).float()
        else:
            g.edata['feat'] = edge_features

    # Load and attach node labels
    node_labels_path = data_dir / "node_labels.npy"
    if node_labels_path.exists():
        node_labels = np.load(node_labels_path)
        if HAS_TORCH:
            g.ndata['label'] = torch.from_numpy(node_labels).long()
        else:
            g.ndata['label'] = node_labels

    # Load and attach edge labels
    edge_labels_path = data_dir / "edge_labels.npy"
    if edge_labels_path.exists():
        edge_labels = np.load(edge_labels_path)
        if HAS_TORCH:
            g.edata['label'] = torch.from_numpy(edge_labels).long()
        else:
            g.edata['label'] = edge_labels

    # Load and attach masks
    if (data_dir / "train_mask.npy").exists():
        train_mask = np.load(data_dir / "train_mask.npy")
        val_mask = np.load(data_dir / "val_mask.npy")
        test_mask = np.load(data_dir / "test_mask.npy")

        if HAS_TORCH:
            g.ndata['train_mask'] = torch.from_numpy(train_mask).bool()
            g.ndata['val_mask'] = torch.from_numpy(val_mask).bool()
            g.ndata['test_mask'] = torch.from_numpy(test_mask).bool()
        else:
            g.ndata['train_mask'] = train_mask
            g.ndata['val_mask'] = val_mask
            g.ndata['test_mask'] = test_mask

    # Store metadata as graph attribute
    g.metadata = metadata

    return g


def load_heterogeneous_graph(data_dir: str = ".") -> "dgl.DGLHeteroGraph":
    """Load graph data into a DGL heterogeneous graph.

    This function handles graphs with multiple node and edge types.

    Args:
        data_dir: Directory containing the exported graph data.

    Returns:
        DGL heterogeneous graph.
    """
    data_dir = Path(data_dir)

    # Load metadata
    with open(data_dir / "metadata.json") as f:
        metadata = json.load(f)

    if not metadata.get("is_heterogeneous", False):
        print("Warning: Graph was not exported as heterogeneous. Using homogeneous loader.")
        return load_graph(data_dir)

    if not HAS_DGL:
        raise ImportError("DGL is required for heterogeneous graph loading")

    # Load edge index and type indices
    edge_index = np.load(data_dir / "edge_index.npy")
    edge_types = np.load(data_dir / "edge_type_indices.npy")
    node_types = np.load(data_dir / "node_type_indices.npy")

    # Get type names from metadata
    node_type_names = list(metadata["node_types"].keys())
    edge_type_names = list(metadata["edge_types"].keys())

    # Build edge dict for heterogeneous graph
    edge_dict = {}
    for etype_idx, etype_name in enumerate(edge_type_names):
        mask = edge_types == etype_idx
        if mask.any():
            src = edge_index[mask, 0]
            dst = edge_index[mask, 1]
            # Heterogeneous graphs need (src_type, edge_type, dst_type);
            # this uses the simplified convention (node_type, edge_type, node_type).
            edge_dict[(node_type_names[0] if node_type_names else 'node',
                       etype_name,
                       node_type_names[0] if node_type_names else 'node')] = (src, dst)

    # Create heterogeneous graph
    g = dgl.heterograph(edge_dict) if edge_dict else dgl.graph(([], []))
    g.metadata = metadata

    return g


def print_summary(data_dir: str = "."):
    """Print a summary of the graph data."""
    data_dir = Path(data_dir)

    with open(data_dir / "metadata.json") as f:
        metadata = json.load(f)

    print(f"Graph: {metadata['name']}")
    print(f"Format: DGL ({metadata['edge_format']} edge format)")
    print(f"Nodes: {metadata['num_nodes']}")
    print(f"Edges: {metadata['num_edges']}")
    print(f"Node feature dim: {metadata['node_feature_dim']}")
    print(f"Edge feature dim: {metadata['edge_feature_dim']}")
    print(f"Directed: {metadata['is_directed']}")
    print(f"Heterogeneous: {metadata['is_heterogeneous']}")

    if metadata['node_types']:
        print(f"Node types: {metadata['node_types']}")
    if metadata['edge_types']:
        print(f"Edge types: {metadata['edge_types']}")

    if metadata['statistics']:
        print("\nStatistics:")
        for key, value in metadata['statistics'].items():
            print(f"  {key}: {value:.4f}")

    if HAS_DGL:
        print("\nLoading graph...")
        g = load_graph(data_dir)
        if hasattr(g, 'num_nodes'):
            print(f"DGL graph loaded: {g.num_nodes()} nodes, {g.num_edges()} edges")
            if 'label' in g.ndata:
                print(f"Anomalous nodes: {g.ndata['label'].sum().item()}")


if __name__ == "__main__":
    import sys
    data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
    print_summary(data_dir)
"#;

        let path = output_dir.join("load_graph.py");
        let mut file = File::create(path)?;
        file.write_all(script.as_bytes())?;

        Ok(())
    }

    fn write_pickle_script(&self, output_dir: &Path) -> std::io::Result<()> {
        let script = r#"#!/usr/bin/env python3
"""
DGL Graph Pickle Helper

Utility to save and load DGL graphs as pickle files for faster subsequent loading.
"""

import pickle
from pathlib import Path

try:
    import dgl
    HAS_DGL = True
except ImportError:
    HAS_DGL = False


def save_dgl_graph(graph, output_path: str):
    """Save a DGL graph to a pickle file.

    Args:
        graph: DGL graph to save.
        output_path: Path to save the pickle file.
    """
    output_path = Path(output_path)

    # Save graph data
    graph_data = {
        'num_nodes': graph.num_nodes(),
        'edges': graph.edges(),
        'ndata': {k: v.numpy() if hasattr(v, 'numpy') else v
                  for k, v in graph.ndata.items()},
        'edata': {k: v.numpy() if hasattr(v, 'numpy') else v
                  for k, v in graph.edata.items()},
        'metadata': getattr(graph, 'metadata', {}),
    }

    with open(output_path, 'wb') as f:
        pickle.dump(graph_data, f, protocol=pickle.HIGHEST_PROTOCOL)

    print(f"Saved graph to {output_path}")


def load_dgl_graph(input_path: str) -> "dgl.DGLGraph":
    """Load a DGL graph from a pickle file.

    Args:
        input_path: Path to the pickle file.

    Returns:
        DGL graph.
    """
    if not HAS_DGL:
        raise ImportError("DGL is required to load graphs")

    import torch

    input_path = Path(input_path)

    with open(input_path, 'rb') as f:
        graph_data = pickle.load(f)

    # Recreate graph
    src, dst = graph_data['edges']
    g = dgl.graph((src, dst), num_nodes=graph_data['num_nodes'])

    # Restore node data
    for k, v in graph_data['ndata'].items():
        g.ndata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v

    # Restore edge data
    for k, v in graph_data['edata'].items():
        g.edata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v

    # Restore metadata
    g.metadata = graph_data.get('metadata', {})

    return g


def convert_to_pickle(data_dir: str, output_path: str = None):
    """Convert exported graph data to pickle format for faster loading.

    Args:
        data_dir: Directory containing the exported graph data.
        output_path: Path for the output pickle file. Defaults to data_dir/graph.pkl.
    """
    from load_graph import load_graph

    data_dir = Path(data_dir)
    output_path = Path(output_path) if output_path else data_dir / "graph.pkl"

    print(f"Loading graph from {data_dir}...")
    g = load_graph(str(data_dir))

    if isinstance(g, dict):
        print("Error: DGL not available, cannot convert to pickle")
        return

    save_dgl_graph(g, str(output_path))
    print(f"Graph saved to {output_path}")


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage:")
        print("  python pickle_helper.py convert <data_dir> [output_path]")
        print("  python pickle_helper.py load <pickle_path>")
        sys.exit(1)

    command = sys.argv[1]

    if command == "convert":
        data_dir = sys.argv[2] if len(sys.argv) > 2 else "."
        output_path = sys.argv[3] if len(sys.argv) > 3 else None
        convert_to_pickle(data_dir, output_path)
    elif command == "load":
        pickle_path = sys.argv[2]
        g = load_dgl_graph(pickle_path)
        print(f"Loaded graph: {g.num_nodes()} nodes, {g.num_edges()} edges")
    else:
        print(f"Unknown command: {command}")
"#;

        let path = output_dir.join("pickle_helper.py");
        let mut file = File::create(path)?;
        file.write_all(script.as_bytes())?;

        Ok(())
    }
}

/// Minimal xorshift64 generator used for deterministic train/val/test splits.
struct SimpleRng {
    state: u64,
}

impl SimpleRng {
    fn new(seed: u64) -> Self {
        Self {
            state: if seed == 0 { 1 } else { seed },
        }
    }

    fn next(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::create_test_graph_with_company;
    use tempfile::tempdir;

    #[test]
    fn test_dgl_export_basic() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert_eq!(metadata.common.num_nodes, 3);
        assert_eq!(metadata.common.num_edges, 2);
        assert_eq!(metadata.edge_format, "COO");
        assert!(dir.path().join("edge_index.npy").exists());
        assert!(dir.path().join("node_features.npy").exists());
        assert!(dir.path().join("node_labels.npy").exists());
        assert!(dir.path().join("metadata.json").exists());
        assert!(dir.path().join("load_graph.py").exists());
        assert!(dir.path().join("pickle_helper.py").exists());
    }

    #[test]
    fn test_dgl_export_heterogeneous() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let config = DGLExportConfig {
            heterogeneous: true,
            ..Default::default()
        };
        let exporter = DGLExporter::new(config);
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert!(metadata.is_heterogeneous);
        assert!(dir.path().join("node_type_indices.npy").exists());
        assert!(dir.path().join("edge_type_indices.npy").exists());
    }

    #[test]
    fn test_dgl_export_masks() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert!(metadata
            .common
            .files
            .contains(&"train_mask.npy".to_string()));
        assert!(metadata.common.files.contains(&"val_mask.npy".to_string()));
        assert!(metadata.common.files.contains(&"test_mask.npy".to_string()));
        assert!(dir.path().join("train_mask.npy").exists());
        assert!(dir.path().join("val_mask.npy").exists());
        assert!(dir.path().join("test_mask.npy").exists());
    }

    #[test]
    fn test_dgl_coo_format() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        exporter.export(&graph, dir.path()).unwrap();

        let edge_path = dir.path().join("edge_index.npy");
        assert!(edge_path.exists());

        let metadata_path = dir.path().join("metadata.json");
        let metadata: DGLMetadata =
            serde_json::from_reader(File::open(metadata_path).unwrap()).unwrap();
        assert_eq!(metadata.edge_format, "COO");
    }

    #[test]
    fn test_dgl_export_no_masks() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let config = DGLExportConfig {
            common: CommonExportConfig {
                export_masks: false,
                ..Default::default()
            },
            ..Default::default()
        };
        let exporter = DGLExporter::new(config);
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert!(!metadata
            .common
            .files
            .contains(&"train_mask.npy".to_string()));
        assert!(!dir.path().join("train_mask.npy").exists());
    }

    #[test]
    fn test_dgl_export_minimal() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let config = DGLExportConfig {
            common: CommonExportConfig {
                export_node_features: false,
                export_edge_features: false,
                export_node_labels: false,
                export_edge_labels: false,
                export_masks: false,
                ..Default::default()
            },
            include_pickle_script: false,
            ..Default::default()
        };
        let exporter = DGLExporter::new(config);
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert_eq!(metadata.common.files.len(), 1);
        assert!(dir.path().join("edge_index.npy").exists());
        assert!(dir.path().join("load_graph.py").exists());
        assert!(dir.path().join("metadata.json").exists());
        assert!(!dir.path().join("pickle_helper.py").exists());
    }

    #[test]
    fn test_dgl_statistics() {
        let graph = create_test_graph_with_company();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert!(metadata.common.statistics.contains_key("density"));
        assert!(metadata
            .common
            .statistics
            .contains_key("anomalous_node_ratio"));
        assert!(metadata
            .common
            .statistics
            .contains_key("anomalous_edge_ratio"));

        let node_ratio = metadata
            .common
            .statistics
            .get("anomalous_node_ratio")
            .unwrap();
        assert!((*node_ratio - 1.0 / 3.0).abs() < 0.01);
    }
}