1use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet};
5
6use super::edges::{EdgeId, EdgeType, GraphEdge};
7use super::nodes::{GraphNode, NodeId, NodeType};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Graph {
12 pub name: String,
14 pub graph_type: GraphType,
16 pub nodes: HashMap<NodeId, GraphNode>,
18 pub edges: HashMap<EdgeId, GraphEdge>,
20 pub adjacency: HashMap<NodeId, Vec<EdgeId>>,
22 pub reverse_adjacency: HashMap<NodeId, Vec<EdgeId>>,
24 pub nodes_by_type: HashMap<NodeType, Vec<NodeId>>,
26 pub edges_by_type: HashMap<EdgeType, Vec<EdgeId>>,
28 pub metadata: GraphMetadata,
30 next_node_id: NodeId,
32 next_edge_id: EdgeId,
34}
35
36impl Graph {
37 pub fn new(name: &str, graph_type: GraphType) -> Self {
39 Self {
40 name: name.to_string(),
41 graph_type,
42 nodes: HashMap::new(),
43 edges: HashMap::new(),
44 adjacency: HashMap::new(),
45 reverse_adjacency: HashMap::new(),
46 nodes_by_type: HashMap::new(),
47 edges_by_type: HashMap::new(),
48 metadata: GraphMetadata::default(),
49 next_node_id: 1,
50 next_edge_id: 1,
51 }
52 }
53
54 pub fn add_node(&mut self, mut node: GraphNode) -> NodeId {
56 let id = self.next_node_id;
57 self.next_node_id += 1;
58 node.id = id;
59
60 self.nodes_by_type
62 .entry(node.node_type.clone())
63 .or_default()
64 .push(id);
65
66 self.adjacency.insert(id, Vec::new());
68 self.reverse_adjacency.insert(id, Vec::new());
69
70 self.nodes.insert(id, node);
71 id
72 }
73
74 pub fn add_edge(&mut self, mut edge: GraphEdge) -> EdgeId {
76 let id = self.next_edge_id;
77 self.next_edge_id += 1;
78 edge.id = id;
79
80 self.adjacency.entry(edge.source).or_default().push(id);
82 self.reverse_adjacency
83 .entry(edge.target)
84 .or_default()
85 .push(id);
86
87 self.edges_by_type
89 .entry(edge.edge_type.clone())
90 .or_default()
91 .push(id);
92
93 self.edges.insert(id, edge);
94 id
95 }
96
97 pub fn get_node(&self, id: NodeId) -> Option<&GraphNode> {
99 self.nodes.get(&id)
100 }
101
102 pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut GraphNode> {
104 self.nodes.get_mut(&id)
105 }
106
107 pub fn get_edge(&self, id: EdgeId) -> Option<&GraphEdge> {
109 self.edges.get(&id)
110 }
111
112 pub fn get_edge_mut(&mut self, id: EdgeId) -> Option<&mut GraphEdge> {
114 self.edges.get_mut(&id)
115 }
116
117 pub fn nodes_of_type(&self, node_type: &NodeType) -> Vec<&GraphNode> {
119 self.nodes_by_type
120 .get(node_type)
121 .map(|ids| ids.iter().filter_map(|id| self.nodes.get(id)).collect())
122 .unwrap_or_default()
123 }
124
125 pub fn edges_of_type(&self, edge_type: &EdgeType) -> Vec<&GraphEdge> {
127 self.edges_by_type
128 .get(edge_type)
129 .map(|ids| ids.iter().filter_map(|id| self.edges.get(id)).collect())
130 .unwrap_or_default()
131 }
132
133 pub fn outgoing_edges(&self, node_id: NodeId) -> Vec<&GraphEdge> {
135 self.adjacency
136 .get(&node_id)
137 .map(|ids| ids.iter().filter_map(|id| self.edges.get(id)).collect())
138 .unwrap_or_default()
139 }
140
141 pub fn incoming_edges(&self, node_id: NodeId) -> Vec<&GraphEdge> {
143 self.reverse_adjacency
144 .get(&node_id)
145 .map(|ids| ids.iter().filter_map(|id| self.edges.get(id)).collect())
146 .unwrap_or_default()
147 }
148
149 pub fn neighbors(&self, node_id: NodeId) -> Vec<NodeId> {
151 let mut neighbors = HashSet::new();
152
153 if let Some(edges) = self.adjacency.get(&node_id) {
155 for edge_id in edges {
156 if let Some(edge) = self.edges.get(edge_id) {
157 neighbors.insert(edge.target);
158 }
159 }
160 }
161
162 if let Some(edges) = self.reverse_adjacency.get(&node_id) {
164 for edge_id in edges {
165 if let Some(edge) = self.edges.get(edge_id) {
166 neighbors.insert(edge.source);
167 }
168 }
169 }
170
171 neighbors.into_iter().collect()
172 }
173
174 pub fn node_count(&self) -> usize {
176 self.nodes.len()
177 }
178
179 pub fn edge_count(&self) -> usize {
181 self.edges.len()
182 }
183
184 pub fn out_degree(&self, node_id: NodeId) -> usize {
186 self.adjacency.get(&node_id).map(|e| e.len()).unwrap_or(0)
187 }
188
189 pub fn in_degree(&self, node_id: NodeId) -> usize {
191 self.reverse_adjacency
192 .get(&node_id)
193 .map(|e| e.len())
194 .unwrap_or(0)
195 }
196
197 pub fn degree(&self, node_id: NodeId) -> usize {
199 self.out_degree(node_id) + self.in_degree(node_id)
200 }
201
202 pub fn anomalous_nodes(&self) -> Vec<&GraphNode> {
204 self.nodes.values().filter(|n| n.is_anomaly).collect()
205 }
206
207 pub fn anomalous_edges(&self) -> Vec<&GraphEdge> {
209 self.edges.values().filter(|e| e.is_anomaly).collect()
210 }
211
212 pub fn compute_statistics(&mut self) {
214 self.metadata.node_count = self.nodes.len();
215 self.metadata.edge_count = self.edges.len();
216
217 self.metadata.node_type_counts = self
219 .nodes_by_type
220 .iter()
221 .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
222 .collect();
223
224 self.metadata.edge_type_counts = self
225 .edges_by_type
226 .iter()
227 .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
228 .collect();
229
230 self.metadata.anomalous_node_count = self.anomalous_nodes().len();
232 self.metadata.anomalous_edge_count = self.anomalous_edges().len();
233
234 if self.metadata.node_count > 1 {
236 let max_edges = self.metadata.node_count * (self.metadata.node_count - 1);
237 self.metadata.density = self.metadata.edge_count as f64 / max_edges as f64;
238 }
239
240 if let Some(node) = self.nodes.values().next() {
242 self.metadata.node_feature_dim = node.features.len();
243 }
244 if let Some(edge) = self.edges.values().next() {
245 self.metadata.edge_feature_dim = edge.features.len();
246 }
247 }
248
249 pub fn edge_index(&self) -> (Vec<NodeId>, Vec<NodeId>) {
251 let mut sources = Vec::with_capacity(self.edges.len());
252 let mut targets = Vec::with_capacity(self.edges.len());
253
254 for edge in self.edges.values() {
255 sources.push(edge.source);
256 targets.push(edge.target);
257 }
258
259 (sources, targets)
260 }
261
262 pub fn node_features(&self) -> Vec<Vec<f64>> {
264 let mut node_ids: Vec<_> = self.nodes.keys().copied().collect();
265 node_ids.sort();
266
267 node_ids
268 .iter()
269 .filter_map(|id| self.nodes.get(id))
270 .map(|n| n.features.clone())
271 .collect()
272 }
273
274 pub fn edge_features(&self) -> Vec<Vec<f64>> {
276 let mut edge_ids: Vec<_> = self.edges.keys().copied().collect();
277 edge_ids.sort();
278
279 edge_ids
280 .iter()
281 .filter_map(|id| self.edges.get(id))
282 .map(|e| e.features.clone())
283 .collect()
284 }
285
286 pub fn node_labels(&self) -> Vec<Vec<String>> {
288 let mut node_ids: Vec<_> = self.nodes.keys().copied().collect();
289 node_ids.sort();
290
291 node_ids
292 .iter()
293 .filter_map(|id| self.nodes.get(id))
294 .map(|n| n.labels.clone())
295 .collect()
296 }
297
298 pub fn edge_labels(&self) -> Vec<Vec<String>> {
300 let mut edge_ids: Vec<_> = self.edges.keys().copied().collect();
301 edge_ids.sort();
302
303 edge_ids
304 .iter()
305 .filter_map(|id| self.edges.get(id))
306 .map(|e| e.labels.clone())
307 .collect()
308 }
309
310 pub fn node_anomaly_mask(&self) -> Vec<bool> {
312 let mut node_ids: Vec<_> = self.nodes.keys().copied().collect();
313 node_ids.sort();
314
315 node_ids
316 .iter()
317 .filter_map(|id| self.nodes.get(id))
318 .map(|n| n.is_anomaly)
319 .collect()
320 }
321
322 pub fn edge_anomaly_mask(&self) -> Vec<bool> {
324 let mut edge_ids: Vec<_> = self.edges.keys().copied().collect();
325 edge_ids.sort();
326
327 edge_ids
328 .iter()
329 .filter_map(|id| self.edges.get(id))
330 .map(|e| e.is_anomaly)
331 .collect()
332 }
333}
334
335#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
337pub enum GraphType {
338 Transaction,
340 Approval,
342 EntityRelationship,
344 Heterogeneous,
346 Custom(String),
348}
349
350#[derive(Debug, Clone, Default, Serialize, Deserialize)]
352pub struct GraphMetadata {
353 pub node_count: usize,
355 pub edge_count: usize,
357 pub node_type_counts: HashMap<String, usize>,
359 pub edge_type_counts: HashMap<String, usize>,
361 pub anomalous_node_count: usize,
363 pub anomalous_edge_count: usize,
365 pub density: f64,
367 pub node_feature_dim: usize,
369 pub edge_feature_dim: usize,
371 pub properties: HashMap<String, String>,
373}
374
375#[derive(Debug, Clone, Serialize, Deserialize)]
377pub struct HeterogeneousGraph {
378 pub name: String,
380 pub relations: HashMap<(String, String, String), Graph>,
382 pub all_nodes: HashMap<String, Vec<NodeId>>,
384 pub metadata: GraphMetadata,
386}
387
388impl HeterogeneousGraph {
389 pub fn new(name: &str) -> Self {
391 Self {
392 name: name.to_string(),
393 relations: HashMap::new(),
394 all_nodes: HashMap::new(),
395 metadata: GraphMetadata::default(),
396 }
397 }
398
399 pub fn add_relation(
401 &mut self,
402 source_type: &str,
403 edge_type: &str,
404 target_type: &str,
405 graph: Graph,
406 ) {
407 let key = (
408 source_type.to_string(),
409 edge_type.to_string(),
410 target_type.to_string(),
411 );
412 self.relations.insert(key, graph);
413 }
414
415 pub fn get_relation(
417 &self,
418 source_type: &str,
419 edge_type: &str,
420 target_type: &str,
421 ) -> Option<&Graph> {
422 let key = (
423 source_type.to_string(),
424 edge_type.to_string(),
425 target_type.to_string(),
426 );
427 self.relations.get(&key)
428 }
429
430 pub fn relation_types(&self) -> Vec<(String, String, String)> {
432 self.relations.keys().cloned().collect()
433 }
434
435 pub fn compute_statistics(&mut self) {
437 let mut total_nodes = 0;
438 let mut total_edges = 0;
439
440 for graph in self.relations.values() {
441 total_nodes += graph.node_count();
442 total_edges += graph.edge_count();
443 }
444
445 self.metadata.node_count = total_nodes;
446 self.metadata.edge_count = total_edges;
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an account node with a placeholder id; `Graph::add_node`
    /// assigns the real id.
    fn account(code: &str, name: &str) -> GraphNode {
        GraphNode::new(0, NodeType::Account, code.to_string(), name.to_string())
    }

    /// Builds a transaction edge with a placeholder id; `Graph::add_edge`
    /// assigns the real id.
    fn transaction(source: NodeId, target: NodeId) -> GraphEdge {
        GraphEdge::new(0, source, target, EdgeType::Transaction)
    }

    /// Adding two nodes and one edge is reflected in the counts.
    #[test]
    fn test_graph_creation() {
        let mut graph = Graph::new("test", GraphType::Transaction);

        let id1 = graph.add_node(account("1000", "Cash"));
        let id2 = graph.add_node(account("2000", "AP"));
        graph.add_edge(transaction(id1, id2));

        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.edge_count(), 1);
    }

    /// Degrees and neighbor sets follow the inserted edges.
    #[test]
    fn test_adjacency() {
        let mut graph = Graph::new("test", GraphType::Transaction);

        let n1 = graph.add_node(account("1", "A"));
        let n2 = graph.add_node(account("2", "B"));
        let n3 = graph.add_node(account("3", "C"));

        for (source, target) in [(n1, n2), (n1, n3), (n2, n3)] {
            graph.add_edge(transaction(source, target));
        }

        assert_eq!(graph.out_degree(n1), 2);
        assert_eq!(graph.in_degree(n3), 2);
        assert_eq!(graph.neighbors(n1).len(), 2);
    }

    /// A single edge yields one matching source/target pair.
    #[test]
    fn test_edge_index() {
        let mut graph = Graph::new("test", GraphType::Transaction);

        let n1 = graph.add_node(account("1", "A"));
        let n2 = graph.add_node(account("2", "B"));
        graph.add_edge(transaction(n1, n2));

        let (sources, targets) = graph.edge_index();
        assert_eq!(sources.len(), 1);
        assert_eq!(targets.len(), 1);
        assert_eq!(sources[0], n1);
        assert_eq!(targets[0], n2);
    }
}