use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{BufWriter, Write};
use std::path::Path;

use serde::{Deserialize, Serialize};

use crate::models::Graph;

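/// Options controlling which arrays the exporter writes and how the
/// train/validation/test node masks are generated.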
#[derive(Debug, Clone)]
pub struct DGLExportConfig {
    pub export_node_features: bool,
    pub export_edge_features: bool,
    pub export_node_labels: bool,
    pub export_edge_labels: bool,
    pub export_masks: bool,
    pub train_ratio: f64,
    pub val_ratio: f64,
    pub seed: u64,
    pub heterogeneous: bool,
    pub include_pickle_script: bool,
}

impl Default for DGLExportConfig {
    fn default() -> Self {
        Self {
            export_node_features: true,
            export_edge_features: true,
            export_node_labels: true,
            export_edge_labels: true,
            export_masks: true,
            train_ratio: 0.7,
            val_ratio: 0.15,
            seed: 42,
            heterogeneous: false,
            include_pickle_script: true,
        }
    }
}

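/// Dataset description written to `metadata.json` alongside the exported
/// `.npy` arrays; the generated loader scripts read it to reassemble the graph.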
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DGLMetadata {
    pub name: String,
    pub num_nodes: usize,
    pub num_edges: usize,
    pub node_feature_dim: usize,
    pub edge_feature_dim: usize,
    pub num_node_classes: usize,
    pub num_edge_classes: usize,
    pub node_types: HashMap<String, usize>,
    pub edge_types: HashMap<String, usize>,
    pub is_directed: bool,
    pub is_heterogeneous: bool,
    pub edge_format: String,
    pub files: Vec<String>,
    pub statistics: HashMap<String, f64>,
}

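/// Exports a [`Graph`] as a directory of NumPy `.npy` arrays (COO edge index,
/// optional features, labels, and split masks) plus `metadata.json` and
/// Python loader scripts that DGL can consume.
///
/// A minimal usage sketch, not compiled here; the `graph` value is assumed to
/// be built elsewhere in the crate:
///
/// ```ignore
/// let exporter = DGLExporter::new(DGLExportConfig::default());
/// let metadata = exporter.export(&graph, std::path::Path::new("out/dgl"))?;
/// println!("wrote {} files", metadata.files.len());
/// ```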
pub struct DGLExporter {
    config: DGLExportConfig,
}

impl DGLExporter {
    pub fn new(config: DGLExportConfig) -> Self {
        Self { config }
    }

    pub fn export(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<DGLMetadata> {
        fs::create_dir_all(output_dir)?;

        let mut files = Vec::new();
        let mut statistics = HashMap::new();

        self.export_edge_index(graph, output_dir)?;
        files.push("edge_index.npy".to_string());

        if self.config.export_node_features {
            let dim = self.export_node_features(graph, output_dir)?;
            files.push("node_features.npy".to_string());
            statistics.insert("node_feature_dim".to_string(), dim as f64);
        }

        if self.config.export_edge_features {
            let dim = self.export_edge_features(graph, output_dir)?;
            files.push("edge_features.npy".to_string());
            statistics.insert("edge_feature_dim".to_string(), dim as f64);
        }

        if self.config.export_node_labels {
            self.export_node_labels(graph, output_dir)?;
            files.push("node_labels.npy".to_string());
        }

        if self.config.export_edge_labels {
            self.export_edge_labels(graph, output_dir)?;
            files.push("edge_labels.npy".to_string());
        }

        if self.config.export_masks {
            self.export_masks(graph, output_dir)?;
            files.push("train_mask.npy".to_string());
            files.push("val_mask.npy".to_string());
            files.push("test_mask.npy".to_string());
        }

        if self.config.heterogeneous {
            self.export_node_types(graph, output_dir)?;
            files.push("node_type_indices.npy".to_string());
            self.export_edge_types(graph, output_dir)?;
            files.push("edge_type_indices.npy".to_string());
        }

        let node_types: HashMap<String, usize> = graph
            .nodes_by_type
            .iter()
            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
            .collect();

        let edge_types: HashMap<String, usize> = graph
            .edges_by_type
            .iter()
            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
            .collect();

        statistics.insert("density".to_string(), graph.metadata.density);
        statistics.insert(
            "anomalous_node_ratio".to_string(),
            graph.metadata.anomalous_node_count as f64 / graph.node_count().max(1) as f64,
        );
        statistics.insert(
            "anomalous_edge_ratio".to_string(),
            graph.metadata.anomalous_edge_count as f64 / graph.edge_count().max(1) as f64,
        );

        let metadata = DGLMetadata {
            name: graph.name.clone(),
            num_nodes: graph.node_count(),
            num_edges: graph.edge_count(),
            node_feature_dim: graph.metadata.node_feature_dim,
            edge_feature_dim: graph.metadata.edge_feature_dim,
            num_node_classes: 2,
            num_edge_classes: 2,
            node_types,
            edge_types,
            is_directed: true,
            is_heterogeneous: self.config.heterogeneous,
            edge_format: "COO".to_string(),
            files,
            statistics,
        };

        let metadata_path = output_dir.join("metadata.json");
        let file = File::create(metadata_path)?;
        serde_json::to_writer_pretty(file, &metadata)?;

        self.write_loader_script(output_dir)?;

        if self.config.include_pickle_script {
            self.write_pickle_script(output_dir)?;
        }

        Ok(metadata)
    }

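    /// Writes the edge list as an `[num_edges, 2]` int64 COO array, remapping
    /// node ids to contiguous indices in ascending-id order.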
    fn export_edge_index(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
        let (sources, targets) = graph.edge_index();

        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
        node_ids.sort();
        let id_to_idx: HashMap<_, _> = node_ids
            .iter()
            .enumerate()
            .map(|(i, &id)| (id, i))
            .collect();

        let num_edges = sources.len();
        let coo_data: Vec<Vec<i64>> = (0..num_edges)
            .map(|i| {
                let src = *id_to_idx.get(&sources[i]).unwrap_or(&0) as i64;
                let dst = *id_to_idx.get(&targets[i]).unwrap_or(&0) as i64;
                vec![src, dst]
            })
            .collect();

        let path = output_dir.join("edge_index.npy");
        self.write_npy_2d_i64(&path, &coo_data)?;

        Ok(())
    }

    fn export_node_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
        let features = graph.node_features();
        let dim = features.first().map(|f| f.len()).unwrap_or(0);

        let path = output_dir.join("node_features.npy");
        self.write_npy_2d_f64(&path, &features)?;

        Ok(dim)
    }

    fn export_edge_features(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<usize> {
        let features = graph.edge_features();
        let dim = features.first().map(|f| f.len()).unwrap_or(0);

        let path = output_dir.join("edge_features.npy");
        self.write_npy_2d_f64(&path, &features)?;

        Ok(dim)
    }

    fn export_node_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
        let labels: Vec<i64> = graph
            .node_anomaly_mask()
            .iter()
            .map(|&b| if b { 1 } else { 0 })
            .collect();

        let path = output_dir.join("node_labels.npy");
        self.write_npy_1d_i64(&path, &labels)?;

        Ok(())
    }

    fn export_edge_labels(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
        let labels: Vec<i64> = graph
            .edge_anomaly_mask()
            .iter()
            .map(|&b| if b { 1 } else { 0 })
            .collect();

        let path = output_dir.join("edge_labels.npy");
        self.write_npy_1d_i64(&path, &labels)?;

        Ok(())
    }

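    /// Builds boolean train/val/test masks over the nodes: indices are
    /// shuffled with a seeded Fisher-Yates pass, the first `train_ratio`
    /// fraction goes to train, the next `val_ratio` to validation, and the
    /// remainder to test.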
    fn export_masks(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
        let n = graph.node_count();
        let mut rng = SimpleRng::new(self.config.seed);

        let train_size = (n as f64 * self.config.train_ratio) as usize;
        let val_size = (n as f64 * self.config.val_ratio) as usize;

        let mut indices: Vec<usize> = (0..n).collect();
        for i in (1..n).rev() {
            let j = (rng.next() % (i as u64 + 1)) as usize;
            indices.swap(i, j);
        }

        let mut train_mask = vec![false; n];
        let mut val_mask = vec![false; n];
        let mut test_mask = vec![false; n];

        for (i, &idx) in indices.iter().enumerate() {
            if i < train_size {
                train_mask[idx] = true;
            } else if i < train_size + val_size {
                val_mask[idx] = true;
            } else {
                test_mask[idx] = true;
            }
        }

        self.write_npy_1d_bool(&output_dir.join("train_mask.npy"), &train_mask)?;
        self.write_npy_1d_bool(&output_dir.join("val_mask.npy"), &val_mask)?;
        self.write_npy_1d_bool(&output_dir.join("test_mask.npy"), &test_mask)?;

        Ok(())
    }

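    /// Assigns each node type an integer index (its position in
    /// `nodes_by_type`) and writes one index per node in ascending node-id
    /// order; `export_edge_types` below does the same for edges.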
    fn export_node_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
        let type_to_idx: HashMap<_, _> = graph
            .nodes_by_type
            .keys()
            .enumerate()
            .map(|(i, t)| (t.clone(), i as i64))
            .collect();

        let mut node_ids: Vec<_> = graph.nodes.keys().copied().collect();
        node_ids.sort();

        let type_indices: Vec<i64> = node_ids
            .iter()
            .map(|id| {
                let node = graph.nodes.get(id).unwrap();
                *type_to_idx.get(&node.node_type).unwrap_or(&0)
            })
            .collect();

        let path = output_dir.join("node_type_indices.npy");
        self.write_npy_1d_i64(&path, &type_indices)?;

        Ok(())
    }

    fn export_edge_types(&self, graph: &Graph, output_dir: &Path) -> std::io::Result<()> {
        let type_to_idx: HashMap<_, _> = graph
            .edges_by_type
            .keys()
            .enumerate()
            .map(|(i, t)| (t.clone(), i as i64))
            .collect();

        let mut edge_ids: Vec<_> = graph.edges.keys().copied().collect();
        edge_ids.sort();

        let type_indices: Vec<i64> = edge_ids
            .iter()
            .map(|id| {
                let edge = graph.edges.get(id).unwrap();
                *type_to_idx.get(&edge.edge_type).unwrap_or(&0)
            })
            .collect();

        let path = output_dir.join("edge_type_indices.npy");
        self.write_npy_1d_i64(&path, &type_indices)?;

        Ok(())
    }

    fn write_npy_1d_i64(&self, path: &Path, data: &[i64]) -> std::io::Result<()> {
        let file = File::create(path)?;
        let mut writer = BufWriter::new(file);

        let shape = format!("({},)", data.len());
        self.write_npy_header(&mut writer, "<i8", &shape)?;

        for &val in data {
            writer.write_all(&val.to_le_bytes())?;
        }

        Ok(())
    }

    fn write_npy_1d_bool(&self, path: &Path, data: &[bool]) -> std::io::Result<()> {
        let file = File::create(path)?;
        let mut writer = BufWriter::new(file);

        let shape = format!("({},)", data.len());
        self.write_npy_header(&mut writer, "|b1", &shape)?;

        for &val in data {
            writer.write_all(&[if val { 1u8 } else { 0u8 }])?;
        }

        Ok(())
    }

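    /// Writes a 2-D int64 array in row-major (C) order; rows shorter than the
    /// first row are zero-padded so every row has the same width.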
    fn write_npy_2d_i64(&self, path: &Path, data: &[Vec<i64>]) -> std::io::Result<()> {
        let file = File::create(path)?;
        let mut writer = BufWriter::new(file);

        let rows = data.len();
        let cols = data.first().map(|r| r.len()).unwrap_or(0);

        let shape = format!("({}, {})", rows, cols);
        self.write_npy_header(&mut writer, "<i8", &shape)?;

        for row in data {
            for &val in row {
                writer.write_all(&val.to_le_bytes())?;
            }
            for _ in row.len()..cols {
                writer.write_all(&0_i64.to_le_bytes())?;
            }
        }

        Ok(())
    }

    fn write_npy_2d_f64(&self, path: &Path, data: &[Vec<f64>]) -> std::io::Result<()> {
        let file = File::create(path)?;
        let mut writer = BufWriter::new(file);

        let rows = data.len();
        let cols = data.first().map(|r| r.len()).unwrap_or(0);

        let shape = format!("({}, {})", rows, cols);
        self.write_npy_header(&mut writer, "<f8", &shape)?;

        for row in data {
            for &val in row {
                writer.write_all(&val.to_le_bytes())?;
            }
            for _ in row.len()..cols {
                writer.write_all(&0.0_f64.to_le_bytes())?;
            }
        }

        Ok(())
    }

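    /// Writes a minimal NumPy `.npy` v1.0 header: the `\x93NUMPY` magic,
    /// version bytes `0x01 0x00`, a little-endian `u16` header length, and the
    /// Python-dict header padded with spaces and terminated by `\n` so the
    /// data section starts on a 64-byte boundary.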
    fn write_npy_header<W: Write>(
        &self,
        writer: &mut W,
        dtype: &str,
        shape: &str,
    ) -> std::io::Result<()> {
        // Magic string "\x93NUMPY" followed by format version 1.0.
        writer.write_all(&[0x93])?;
        writer.write_all(b"NUMPY")?;
        writer.write_all(&[0x01, 0x00])?;

        let header = format!(
            "{{'descr': '{}', 'fortran_order': False, 'shape': {} }}",
            dtype, shape
        );

        let header_len = header.len();
        // 6-byte magic + 2-byte version + 2-byte length field, plus the header
        // text and its trailing newline.
        let total_len = 10 + header_len + 1;
        let padding = (64 - (total_len % 64)) % 64;
        let padded_len = header_len + 1 + padding;

        writer.write_all(&(padded_len as u16).to_le_bytes())?;
        writer.write_all(header.as_bytes())?;
        for _ in 0..padding {
            writer.write_all(b" ")?;
        }
        writer.write_all(b"\n")?;

        Ok(())
    }

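    /// Writes `load_graph.py`, a Python helper that rebuilds a `dgl.DGLGraph`
    /// from the exported arrays (or returns a plain dict when DGL is not
    /// installed) and can print a dataset summary from the command line.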
    fn write_loader_script(&self, output_dir: &Path) -> std::io::Result<()> {
        let script = r#"#!/usr/bin/env python3
"""
DGL (Deep Graph Library) Data Loader

Auto-generated loader for graph data exported from synth-graph.
Supports both homogeneous and heterogeneous graph loading.
"""

import json
import numpy as np
from pathlib import Path

try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False
    print("Warning: torch not installed. Install with: pip install torch")

try:
    import dgl
    HAS_DGL = True
except ImportError:
    HAS_DGL = False
    print("Warning: dgl not installed. Install with: pip install dgl")


def load_graph(data_dir: str = ".") -> "dgl.DGLGraph":
    """Load graph data into a DGL graph object.

    Args:
        data_dir: Directory containing the exported graph data.

    Returns:
        DGL graph with node features, edge features, and labels attached.
    """
    data_dir = Path(data_dir)

    # Load metadata
    with open(data_dir / "metadata.json") as f:
        metadata = json.load(f)

    # Load edge index (COO format: [num_edges, 2])
    edge_index = np.load(data_dir / "edge_index.npy")
    src = edge_index[:, 0]
    dst = edge_index[:, 1]

    num_nodes = metadata["num_nodes"]

    if not HAS_DGL:
        # Return dict if DGL not available
        result = {
            "src": src,
            "dst": dst,
            "num_nodes": num_nodes,
            "metadata": metadata,
        }

        # Load optional arrays
        if (data_dir / "node_features.npy").exists():
            result["node_features"] = np.load(data_dir / "node_features.npy")
        if (data_dir / "edge_features.npy").exists():
            result["edge_features"] = np.load(data_dir / "edge_features.npy")
        if (data_dir / "node_labels.npy").exists():
            result["node_labels"] = np.load(data_dir / "node_labels.npy")
        if (data_dir / "edge_labels.npy").exists():
            result["edge_labels"] = np.load(data_dir / "edge_labels.npy")
        if (data_dir / "train_mask.npy").exists():
            result["train_mask"] = np.load(data_dir / "train_mask.npy")
            result["val_mask"] = np.load(data_dir / "val_mask.npy")
            result["test_mask"] = np.load(data_dir / "test_mask.npy")

        return result

    # Create DGL graph
    g = dgl.graph((src, dst), num_nodes=num_nodes)

    # Load and attach node features
    node_features_path = data_dir / "node_features.npy"
    if node_features_path.exists():
        node_features = np.load(node_features_path)
        if HAS_TORCH:
            g.ndata['feat'] = torch.from_numpy(node_features).float()
        else:
            g.ndata['feat'] = node_features

    # Load and attach edge features
    edge_features_path = data_dir / "edge_features.npy"
    if edge_features_path.exists():
        edge_features = np.load(edge_features_path)
        if HAS_TORCH:
            g.edata['feat'] = torch.from_numpy(edge_features).float()
        else:
            g.edata['feat'] = edge_features

    # Load and attach node labels
    node_labels_path = data_dir / "node_labels.npy"
    if node_labels_path.exists():
        node_labels = np.load(node_labels_path)
        if HAS_TORCH:
            g.ndata['label'] = torch.from_numpy(node_labels).long()
        else:
            g.ndata['label'] = node_labels

    # Load and attach edge labels
    edge_labels_path = data_dir / "edge_labels.npy"
    if edge_labels_path.exists():
        edge_labels = np.load(edge_labels_path)
        if HAS_TORCH:
            g.edata['label'] = torch.from_numpy(edge_labels).long()
        else:
            g.edata['label'] = edge_labels

    # Load and attach masks
    if (data_dir / "train_mask.npy").exists():
        train_mask = np.load(data_dir / "train_mask.npy")
        val_mask = np.load(data_dir / "val_mask.npy")
        test_mask = np.load(data_dir / "test_mask.npy")

        if HAS_TORCH:
            g.ndata['train_mask'] = torch.from_numpy(train_mask).bool()
            g.ndata['val_mask'] = torch.from_numpy(val_mask).bool()
            g.ndata['test_mask'] = torch.from_numpy(test_mask).bool()
        else:
            g.ndata['train_mask'] = train_mask
            g.ndata['val_mask'] = val_mask
            g.ndata['test_mask'] = test_mask

    # Store metadata as graph attribute
    g.metadata = metadata

    return g


def load_heterogeneous_graph(data_dir: str = ".") -> "dgl.DGLHeteroGraph":
    """Load graph data into a DGL heterogeneous graph.

    This function handles graphs with multiple node and edge types.

    Args:
        data_dir: Directory containing the exported graph data.

    Returns:
        DGL heterogeneous graph.
    """
    data_dir = Path(data_dir)

    # Load metadata
    with open(data_dir / "metadata.json") as f:
        metadata = json.load(f)

    if not metadata.get("is_heterogeneous", False):
        print("Warning: Graph was not exported as heterogeneous. Using homogeneous loader.")
        return load_graph(data_dir)

    if not HAS_DGL:
        raise ImportError("DGL is required for heterogeneous graph loading")

    # Load edge index and type indices
    edge_index = np.load(data_dir / "edge_index.npy")
    edge_types = np.load(data_dir / "edge_type_indices.npy")
    node_types = np.load(data_dir / "node_type_indices.npy")

    # Get type names from metadata
    node_type_names = list(metadata["node_types"].keys())
    edge_type_names = list(metadata["edge_types"].keys())

    # Build edge dict for heterogeneous graph
    edge_dict = {}
    for etype_idx, etype_name in enumerate(edge_type_names):
        mask = edge_types == etype_idx
        if mask.any():
            src = edge_index[mask, 0]
            dst = edge_index[mask, 1]
            # For heterogeneous, we need to specify (src_type, edge_type, dst_type)
            # Using simplified convention: (node_type, edge_type, node_type)
            edge_dict[(node_type_names[0] if node_type_names else 'node',
                       etype_name,
                       node_type_names[0] if node_type_names else 'node')] = (src, dst)

    # Create heterogeneous graph
    g = dgl.heterograph(edge_dict) if edge_dict else dgl.graph(([], []))
    g.metadata = metadata

    return g


def print_summary(data_dir: str = "."):
    """Print summary of the graph data."""
    data_dir = Path(data_dir)

    with open(data_dir / "metadata.json") as f:
        metadata = json.load(f)

    print(f"Graph: {metadata['name']}")
    print(f"Format: DGL ({metadata['edge_format']} edge format)")
    print(f"Nodes: {metadata['num_nodes']}")
    print(f"Edges: {metadata['num_edges']}")
    print(f"Node feature dim: {metadata['node_feature_dim']}")
    print(f"Edge feature dim: {metadata['edge_feature_dim']}")
    print(f"Directed: {metadata['is_directed']}")
    print(f"Heterogeneous: {metadata['is_heterogeneous']}")

    if metadata['node_types']:
        print(f"Node types: {metadata['node_types']}")
    if metadata['edge_types']:
        print(f"Edge types: {metadata['edge_types']}")

    if metadata['statistics']:
        print("\nStatistics:")
        for key, value in metadata['statistics'].items():
            print(f"  {key}: {value:.4f}")

    if HAS_DGL:
        print("\nLoading graph...")
        g = load_graph(data_dir)
        if hasattr(g, 'num_nodes'):
            print(f"DGL graph loaded: {g.num_nodes()} nodes, {g.num_edges()} edges")
        if 'label' in g.ndata:
            print(f"Anomalous nodes: {g.ndata['label'].sum().item()}")


if __name__ == "__main__":
    import sys
    data_dir = sys.argv[1] if len(sys.argv) > 1 else "."
    print_summary(data_dir)
"#;

        let path = output_dir.join("load_graph.py");
        let mut file = File::create(path)?;
        file.write_all(script.as_bytes())?;

        Ok(())
    }

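    /// Writes `pickle_helper.py`, a Python helper for round-tripping the
    /// loaded DGL graph through a pickle file for faster subsequent loads.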
    fn write_pickle_script(&self, output_dir: &Path) -> std::io::Result<()> {
        let script = r#"#!/usr/bin/env python3
"""
DGL Graph Pickle Helper

Utility to save and load DGL graphs as pickle files for faster subsequent loading.
"""

import pickle
from pathlib import Path

try:
    import dgl
    HAS_DGL = True
except ImportError:
    HAS_DGL = False


def save_dgl_graph(graph, output_path: str):
    """Save a DGL graph to a pickle file.

    Args:
        graph: DGL graph to save.
        output_path: Path to save the pickle file.
    """
    output_path = Path(output_path)

    # Save graph data
    graph_data = {
        'num_nodes': graph.num_nodes(),
        'edges': graph.edges(),
        'ndata': {k: v.numpy() if hasattr(v, 'numpy') else v
                  for k, v in graph.ndata.items()},
        'edata': {k: v.numpy() if hasattr(v, 'numpy') else v
                  for k, v in graph.edata.items()},
        'metadata': getattr(graph, 'metadata', {}),
    }

    with open(output_path, 'wb') as f:
        pickle.dump(graph_data, f, protocol=pickle.HIGHEST_PROTOCOL)

    print(f"Saved graph to {output_path}")


def load_dgl_graph(input_path: str) -> "dgl.DGLGraph":
    """Load a DGL graph from a pickle file.

    Args:
        input_path: Path to the pickle file.

    Returns:
        DGL graph.
    """
    if not HAS_DGL:
        raise ImportError("DGL is required to load graphs")

    import torch

    input_path = Path(input_path)

    with open(input_path, 'rb') as f:
        graph_data = pickle.load(f)

    # Recreate graph
    src, dst = graph_data['edges']
    g = dgl.graph((src, dst), num_nodes=graph_data['num_nodes'])

    # Restore node data
    for k, v in graph_data['ndata'].items():
        g.ndata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v

    # Restore edge data
    for k, v in graph_data['edata'].items():
        g.edata[k] = torch.from_numpy(v) if hasattr(v, 'dtype') else v

    # Restore metadata
    g.metadata = graph_data.get('metadata', {})

    return g


def convert_to_pickle(data_dir: str, output_path: str = None):
    """Convert exported graph data to pickle format for faster loading.

    Args:
        data_dir: Directory containing the exported graph data.
        output_path: Path for output pickle file. Defaults to data_dir/graph.pkl.
    """
    from load_graph import load_graph

    data_dir = Path(data_dir)
    output_path = Path(output_path) if output_path else data_dir / "graph.pkl"

    print(f"Loading graph from {data_dir}...")
    g = load_graph(str(data_dir))

    if isinstance(g, dict):
        print("Error: DGL not available, cannot convert to pickle")
        return

    save_dgl_graph(g, str(output_path))
    print(f"Graph saved to {output_path}")


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage:")
        print("  python pickle_helper.py convert <data_dir> [output_path]")
        print("  python pickle_helper.py load <pickle_path>")
        sys.exit(1)

    command = sys.argv[1]

    if command == "convert":
        data_dir = sys.argv[2] if len(sys.argv) > 2 else "."
        output_path = sys.argv[3] if len(sys.argv) > 3 else None
        convert_to_pickle(data_dir, output_path)
    elif command == "load":
        pickle_path = sys.argv[2]
        g = load_dgl_graph(pickle_path)
        print(f"Loaded graph: {g.num_nodes()} nodes, {g.num_edges()} edges")
    else:
        print(f"Unknown command: {command}")
"#;

        let path = output_dir.join("pickle_helper.py");
        let mut file = File::create(path)?;
        file.write_all(script.as_bytes())?;

        Ok(())
    }
}

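/// Small xorshift64 PRNG used for deterministic, dependency-free
/// train/val/test splits.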
struct SimpleRng {
    state: u64,
}

impl SimpleRng {
    fn new(seed: u64) -> Self {
        Self {
            state: if seed == 0 { 1 } else { seed },
        }
    }

    fn next(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{EdgeType, GraphEdge, GraphNode, GraphType, NodeType};
    use tempfile::tempdir;

    fn create_test_graph() -> Graph {
        let mut graph = Graph::new("test_dgl", GraphType::Transaction);

        let n1 = graph.add_node(
            GraphNode::new(0, NodeType::Account, "1000".to_string(), "Cash".to_string())
                .with_feature(0.5),
        );
        let n2 = graph.add_node(
            GraphNode::new(0, NodeType::Account, "2000".to_string(), "AP".to_string())
                .with_feature(0.8),
        );
        let n3 = graph.add_node(
            GraphNode::new(
                0,
                NodeType::Company,
                "ACME".to_string(),
                "ACME Corp".to_string(),
            )
            .with_feature(0.3)
            .as_anomaly("fraud"),
        );

        graph.add_edge(
            GraphEdge::new(0, n1, n2, EdgeType::Transaction)
                .with_weight(1000.0)
                .with_feature(6.9),
        );

        graph.add_edge(
            GraphEdge::new(0, n2, n3, EdgeType::Ownership)
                .with_weight(100.0)
                .with_feature(4.6),
        );

        graph.compute_statistics();
        graph
    }

    #[test]
    fn test_dgl_export_basic() {
        let graph = create_test_graph();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert_eq!(metadata.num_nodes, 3);
        assert_eq!(metadata.num_edges, 2);
        assert_eq!(metadata.edge_format, "COO");
        assert!(dir.path().join("edge_index.npy").exists());
        assert!(dir.path().join("node_features.npy").exists());
        assert!(dir.path().join("node_labels.npy").exists());
        assert!(dir.path().join("metadata.json").exists());
        assert!(dir.path().join("load_graph.py").exists());
        assert!(dir.path().join("pickle_helper.py").exists());
    }

    #[test]
    fn test_dgl_export_heterogeneous() {
        let graph = create_test_graph();
        let dir = tempdir().unwrap();

        let config = DGLExportConfig {
            heterogeneous: true,
            ..Default::default()
        };
        let exporter = DGLExporter::new(config);
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert!(metadata.is_heterogeneous);
        assert!(dir.path().join("node_type_indices.npy").exists());
        assert!(dir.path().join("edge_type_indices.npy").exists());
    }

    #[test]
    fn test_dgl_export_masks() {
        let graph = create_test_graph();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert!(metadata.files.contains(&"train_mask.npy".to_string()));
        assert!(metadata.files.contains(&"val_mask.npy".to_string()));
        assert!(metadata.files.contains(&"test_mask.npy".to_string()));
        assert!(dir.path().join("train_mask.npy").exists());
        assert!(dir.path().join("val_mask.npy").exists());
        assert!(dir.path().join("test_mask.npy").exists());
    }

    #[test]
    fn test_dgl_coo_format() {
        let graph = create_test_graph();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        exporter.export(&graph, dir.path()).unwrap();

        let edge_path = dir.path().join("edge_index.npy");
        assert!(edge_path.exists());

        let metadata_path = dir.path().join("metadata.json");
        let metadata: DGLMetadata =
            serde_json::from_reader(File::open(metadata_path).unwrap()).unwrap();
        assert_eq!(metadata.edge_format, "COO");
    }

    #[test]
    fn test_dgl_export_no_masks() {
        let graph = create_test_graph();
        let dir = tempdir().unwrap();

        let config = DGLExportConfig {
            export_masks: false,
            ..Default::default()
        };
        let exporter = DGLExporter::new(config);
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert!(!metadata.files.contains(&"train_mask.npy".to_string()));
        assert!(!dir.path().join("train_mask.npy").exists());
    }

    #[test]
    fn test_dgl_export_minimal() {
        let graph = create_test_graph();
        let dir = tempdir().unwrap();

        let config = DGLExportConfig {
            export_node_features: false,
            export_edge_features: false,
            export_node_labels: false,
            export_edge_labels: false,
            export_masks: false,
            include_pickle_script: false,
            ..Default::default()
        };
        let exporter = DGLExporter::new(config);
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert_eq!(metadata.files.len(), 1);
        assert!(dir.path().join("edge_index.npy").exists());
        assert!(dir.path().join("load_graph.py").exists());
        assert!(dir.path().join("metadata.json").exists());
        assert!(!dir.path().join("pickle_helper.py").exists());
    }

    #[test]
    fn test_dgl_statistics() {
        let graph = create_test_graph();
        let dir = tempdir().unwrap();

        let exporter = DGLExporter::new(DGLExportConfig::default());
        let metadata = exporter.export(&graph, dir.path()).unwrap();

        assert!(metadata.statistics.contains_key("density"));
        assert!(metadata.statistics.contains_key("anomalous_node_ratio"));
        assert!(metadata.statistics.contains_key("anomalous_edge_ratio"));

        let node_ratio = metadata.statistics.get("anomalous_node_ratio").unwrap();
        assert!((*node_ratio - 1.0 / 3.0).abs() < 0.01);
    }
}
1095}