use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use super::edges::{EdgeId, EdgeType, GraphEdge};
use super::nodes::{GraphNode, NodeId, NodeType};
/// A typed graph of [`GraphNode`]s connected by directed [`GraphEdge`]s.
///
/// Besides the primary `nodes` / `edges` maps, the struct keeps several
/// secondary indexes (adjacency lists and per-type id lists) that
/// `add_node` / `add_edge` maintain. Code that mutates the public maps
/// directly is responsible for keeping those indexes consistent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Graph {
    /// Name given to `Graph::new`.
    pub name: String,
    /// Kind of graph this is.
    pub graph_type: GraphType,
    /// All nodes, keyed by the id assigned in `add_node`.
    pub nodes: HashMap<NodeId, GraphNode>,
    /// All edges, keyed by the id assigned in `add_edge`.
    pub edges: HashMap<EdgeId, GraphEdge>,
    /// For each node id, the ids of edges whose `source` is that node (outgoing edges).
    pub adjacency: HashMap<NodeId, Vec<EdgeId>>,
    /// For each node id, the ids of edges whose `target` is that node (incoming edges).
    pub reverse_adjacency: HashMap<NodeId, Vec<EdgeId>>,
    /// Node ids grouped by node type, in insertion order.
    pub nodes_by_type: HashMap<NodeType, Vec<NodeId>>,
    /// Edge ids grouped by edge type, in insertion order.
    pub edges_by_type: HashMap<EdgeType, Vec<EdgeId>>,
    /// Summary statistics. Not updated by `add_node` / `add_edge`;
    /// refresh it with `Graph::compute_statistics`.
    pub metadata: GraphMetadata,
    /// Id that the next `add_node` call will assign (starts at 1).
    next_node_id: NodeId,
    /// Id that the next `add_edge` call will assign (starts at 1).
    next_edge_id: EdgeId,
}
impl Graph {
    /// Creates an empty graph named `name` of kind `graph_type`.
    ///
    /// Node and edge ids are handed out sequentially, both starting at 1.
    pub fn new(name: &str, graph_type: GraphType) -> Self {
        Self {
            name: name.to_string(),
            graph_type,
            nodes: HashMap::new(),
            edges: HashMap::new(),
            adjacency: HashMap::new(),
            reverse_adjacency: HashMap::new(),
            nodes_by_type: HashMap::new(),
            edges_by_type: HashMap::new(),
            metadata: GraphMetadata::default(),
            next_node_id: 1,
            next_edge_id: 1,
        }
    }

    /// Inserts `node` and returns the freshly assigned id.
    ///
    /// Any id already stored in `node.id` is overwritten. The node is indexed
    /// by type and gets (possibly empty) outgoing/incoming adjacency lists.
    pub fn add_node(&mut self, mut node: GraphNode) -> NodeId {
        let id = self.next_node_id;
        self.next_node_id += 1;
        node.id = id;
        self.nodes_by_type
            .entry(node.node_type.clone())
            .or_default()
            .push(id);
        // `add_edge` does not validate endpoints, so an edge may already
        // reference this id; keep whatever adjacency it created rather than
        // overwriting it with an empty list.
        self.adjacency.entry(id).or_default();
        self.reverse_adjacency.entry(id).or_default();
        self.nodes.insert(id, node);
        id
    }

    /// Inserts `edge` and returns the freshly assigned id.
    ///
    /// Any id already stored in `edge.id` is overwritten. The endpoints are
    /// not checked against `nodes`; adjacency entries are created on demand
    /// for whatever ids `edge.source` / `edge.target` hold.
    pub fn add_edge(&mut self, mut edge: GraphEdge) -> EdgeId {
        let id = self.next_edge_id;
        self.next_edge_id += 1;
        edge.id = id;
        self.adjacency.entry(edge.source).or_default().push(id);
        self.reverse_adjacency
            .entry(edge.target)
            .or_default()
            .push(id);
        self.edges_by_type
            .entry(edge.edge_type.clone())
            .or_default()
            .push(id);
        self.edges.insert(id, edge);
        id
    }

    /// Looks up a node by id.
    pub fn get_node(&self, id: NodeId) -> Option<&GraphNode> {
        self.nodes.get(&id)
    }

    /// Looks up a node by id for mutation.
    pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut GraphNode> {
        self.nodes.get_mut(&id)
    }

    /// Looks up an edge by id.
    pub fn get_edge(&self, id: EdgeId) -> Option<&GraphEdge> {
        self.edges.get(&id)
    }

    /// Looks up an edge by id for mutation.
    pub fn get_edge_mut(&mut self, id: EdgeId) -> Option<&mut GraphEdge> {
        self.edges.get_mut(&id)
    }

    /// Returns all nodes of `node_type`, in insertion order.
    pub fn nodes_of_type(&self, node_type: &NodeType) -> Vec<&GraphNode> {
        self.nodes_by_type
            .get(node_type)
            .map(|ids| ids.iter().filter_map(|id| self.nodes.get(id)).collect())
            .unwrap_or_default()
    }

    /// Returns all edges of `edge_type`, in insertion order.
    pub fn edges_of_type(&self, edge_type: &EdgeType) -> Vec<&GraphEdge> {
        self.edges_by_type
            .get(edge_type)
            .map(|ids| ids.iter().filter_map(|id| self.edges.get(id)).collect())
            .unwrap_or_default()
    }

    /// Returns the edges whose `source` is `node_id`, in insertion order.
    pub fn outgoing_edges(&self, node_id: NodeId) -> Vec<&GraphEdge> {
        self.adjacency
            .get(&node_id)
            .map(|ids| ids.iter().filter_map(|id| self.edges.get(id)).collect())
            .unwrap_or_default()
    }

    /// Returns the edges whose `target` is `node_id`, in insertion order.
    pub fn incoming_edges(&self, node_id: NodeId) -> Vec<&GraphEdge> {
        self.reverse_adjacency
            .get(&node_id)
            .map(|ids| ids.iter().filter_map(|id| self.edges.get(id)).collect())
            .unwrap_or_default()
    }

    /// Returns the distinct nodes linked to `node_id` by an edge in either
    /// direction, sorted by ascending id.
    ///
    /// A self-loop makes the node its own neighbor.
    pub fn neighbors(&self, node_id: NodeId) -> Vec<NodeId> {
        let targets = self
            .adjacency
            .get(&node_id)
            .into_iter()
            .flatten()
            .filter_map(|edge_id| self.edges.get(edge_id))
            .map(|edge| edge.target);
        let sources = self
            .reverse_adjacency
            .get(&node_id)
            .into_iter()
            .flatten()
            .filter_map(|edge_id| self.edges.get(edge_id))
            .map(|edge| edge.source);
        let unique: HashSet<NodeId> = targets.chain(sources).collect();
        let mut neighbors: Vec<NodeId> = unique.into_iter().collect();
        // HashSet iteration order is arbitrary; sort so callers get a stable result.
        neighbors.sort_unstable();
        neighbors
    }

    /// Number of nodes in the graph.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Number of edges in the graph.
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }

    /// Number of edges leaving `node_id` (0 for unknown ids).
    pub fn out_degree(&self, node_id: NodeId) -> usize {
        self.adjacency.get(&node_id).map(|e| e.len()).unwrap_or(0)
    }

    /// Number of edges entering `node_id` (0 for unknown ids).
    pub fn in_degree(&self, node_id: NodeId) -> usize {
        self.reverse_adjacency
            .get(&node_id)
            .map(|e| e.len())
            .unwrap_or(0)
    }

    /// Total degree: out-degree plus in-degree (a self-loop counts twice).
    pub fn degree(&self, node_id: NodeId) -> usize {
        self.out_degree(node_id) + self.in_degree(node_id)
    }

    /// Returns nodes flagged `is_anomaly`, in arbitrary (HashMap) order.
    pub fn anomalous_nodes(&self) -> Vec<&GraphNode> {
        self.nodes.values().filter(|n| n.is_anomaly).collect()
    }

    /// Returns edges flagged `is_anomaly`, in arbitrary (HashMap) order.
    pub fn anomalous_edges(&self) -> Vec<&GraphEdge> {
        self.edges.values().filter(|e| e.is_anomaly).collect()
    }

    /// Recomputes the statistics in `metadata` from the current nodes and edges.
    ///
    /// Density is `edges / (n * (n - 1))`, i.e. relative to all ordered pairs
    /// of distinct nodes, and is 0 when there are fewer than two nodes. The
    /// feature dimensions are taken from the lowest-id node / edge (0 when
    /// there are none). They assume every feature vector has the same length.
    /// `metadata.properties` is left untouched.
    pub fn compute_statistics(&mut self) {
        let node_count = self.nodes.len();
        let edge_count = self.edges.len();
        self.metadata.node_count = node_count;
        self.metadata.edge_count = edge_count;
        self.metadata.node_type_counts = self
            .nodes_by_type
            .iter()
            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
            .collect();
        self.metadata.edge_type_counts = self
            .edges_by_type
            .iter()
            .map(|(t, ids)| (t.as_str().to_string(), ids.len()))
            .collect();
        self.metadata.anomalous_node_count = self.nodes.values().filter(|n| n.is_anomaly).count();
        self.metadata.anomalous_edge_count = self.edges.values().filter(|e| e.is_anomaly).count();
        // Always assign, so a graph that shrank below two nodes does not keep a
        // stale density. Multiply in f64 so n * (n - 1) cannot overflow usize.
        self.metadata.density = if node_count > 1 {
            edge_count as f64 / (node_count as f64 * (node_count - 1) as f64)
        } else {
            0.0
        };
        // HashMap iteration order is arbitrary; use the lowest id so the
        // reported dimension is deterministic, and report 0 when empty.
        self.metadata.node_feature_dim = self
            .nodes
            .iter()
            .min_by_key(|&(id, _)| *id)
            .map_or(0, |(_, node)| node.features.len());
        self.metadata.edge_feature_dim = self
            .edges
            .iter()
            .min_by_key(|&(id, _)| *id)
            .map_or(0, |(_, edge)| edge.features.len());
    }

    /// Returns the connectivity as parallel `(sources, targets)` vectors.
    ///
    /// There is one entry per edge, in ascending edge-id order. This is the
    /// same row order as `edge_features`, `edge_labels` and
    /// `edge_anomaly_mask`, so position `k` refers to the same edge in all of
    /// them. Entries are raw `NodeId`s, not row positions in `node_features`.
    pub fn edge_index(&self) -> (Vec<NodeId>, Vec<NodeId>) {
        self.edges_in_id_order()
            .into_iter()
            .map(|edge| (edge.source, edge.target))
            .unzip()
    }

    /// Returns each node's feature vector, one row per node in ascending node-id order.
    pub fn node_features(&self) -> Vec<Vec<f64>> {
        self.nodes_in_id_order()
            .into_iter()
            .map(|n| n.features.clone())
            .collect()
    }

    /// Returns each edge's feature vector, one row per edge in ascending edge-id order.
    pub fn edge_features(&self) -> Vec<Vec<f64>> {
        self.edges_in_id_order()
            .into_iter()
            .map(|e| e.features.clone())
            .collect()
    }

    /// Returns each node's labels, one row per node in ascending node-id order.
    pub fn node_labels(&self) -> Vec<Vec<String>> {
        self.nodes_in_id_order()
            .into_iter()
            .map(|n| n.labels.clone())
            .collect()
    }

    /// Returns each edge's labels, one row per edge in ascending edge-id order.
    pub fn edge_labels(&self) -> Vec<Vec<String>> {
        self.edges_in_id_order()
            .into_iter()
            .map(|e| e.labels.clone())
            .collect()
    }

    /// Returns each node's `is_anomaly` flag in ascending node-id order.
    pub fn node_anomaly_mask(&self) -> Vec<bool> {
        self.nodes_in_id_order()
            .into_iter()
            .map(|n| n.is_anomaly)
            .collect()
    }

    /// Returns each edge's `is_anomaly` flag in ascending edge-id order.
    pub fn edge_anomaly_mask(&self) -> Vec<bool> {
        self.edges_in_id_order()
            .into_iter()
            .map(|e| e.is_anomaly)
            .collect()
    }

    /// All nodes sorted by ascending id: the canonical row order for node-level exports.
    fn nodes_in_id_order(&self) -> Vec<&GraphNode> {
        let mut entries: Vec<(&NodeId, &GraphNode)> = self.nodes.iter().collect();
        entries.sort_unstable_by_key(|&(id, _)| *id);
        entries.into_iter().map(|(_, node)| node).collect()
    }

    /// All edges sorted by ascending id: the canonical row order for edge-level exports.
    fn edges_in_id_order(&self) -> Vec<&GraphEdge> {
        let mut entries: Vec<(&EdgeId, &GraphEdge)> = self.edges.iter().collect();
        entries.sort_unstable_by_key(|&(id, _)| *id);
        entries.into_iter().map(|(_, edge)| edge).collect()
    }
}
/// The kind of graph a [`Graph`] represents.
///
/// NOTE(review): the meaning of each built-in variant is not established in
/// this file; presumably it records which source data the graph was built
/// from. Confirm against the graph builders.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum GraphType {
    Transaction,
    Approval,
    EntityRelationship,
    Heterogeneous,
    /// A graph kind not covered by the built-in variants, identified by a free-form name.
    Custom(String),
}
/// Summary statistics describing a graph.
///
/// All fields start at their `Default` values. For a single [`Graph`] they
/// are filled in by `Graph::compute_statistics`; they are not kept in sync
/// automatically as nodes and edges are added.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GraphMetadata {
    /// Number of nodes.
    pub node_count: usize,
    /// Number of edges.
    pub edge_count: usize,
    /// Node count per node type, keyed by the type's `as_str()` name.
    pub node_type_counts: HashMap<String, usize>,
    /// Edge count per edge type, keyed by the type's `as_str()` name.
    pub edge_type_counts: HashMap<String, usize>,
    /// Number of nodes whose `is_anomaly` flag is set.
    pub anomalous_node_count: usize,
    /// Number of edges whose `is_anomaly` flag is set.
    pub anomalous_edge_count: usize,
    /// `edge_count / (node_count * (node_count - 1))`, i.e. edges relative to
    /// all ordered pairs of distinct nodes.
    pub density: f64,
    /// Length of a node feature vector (assumes all nodes share one length).
    pub node_feature_dim: usize,
    /// Length of an edge feature vector (assumes all edges share one length).
    pub edge_feature_dim: usize,
    /// Free-form key/value annotations; not written by any code in this file.
    pub properties: HashMap<String, String>,
}
/// A collection of typed relations, each stored as its own [`Graph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeterogeneousGraph {
    /// Name given to `HeterogeneousGraph::new`.
    pub name: String,
    /// One graph per relation, keyed by `(source_type, edge_type, target_type)`.
    pub relations: HashMap<(String, String, String), Graph>,
    /// NOTE(review): never written by the methods in this file. Presumably it
    /// holds node ids grouped by node-type name; confirm with whoever populates it.
    pub all_nodes: HashMap<String, Vec<NodeId>>,
    /// Aggregate statistics over all relations; refreshed by `compute_statistics`.
    pub metadata: GraphMetadata,
}
impl HeterogeneousGraph {
    /// Creates an empty heterogeneous graph with no relations.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            relations: HashMap::new(),
            all_nodes: HashMap::new(),
            metadata: GraphMetadata::default(),
        }
    }

    /// Registers `graph` as the relation `(source_type, edge_type, target_type)`.
    ///
    /// An existing relation with the same key is replaced.
    pub fn add_relation(
        &mut self,
        source_type: &str,
        edge_type: &str,
        target_type: &str,
        graph: Graph,
    ) {
        let key = (
            source_type.to_string(),
            edge_type.to_string(),
            target_type.to_string(),
        );
        self.relations.insert(key, graph);
    }

    /// Returns the graph for relation `(source_type, edge_type, target_type)`, if registered.
    pub fn get_relation(
        &self,
        source_type: &str,
        edge_type: &str,
        target_type: &str,
    ) -> Option<&Graph> {
        let key = (
            source_type.to_string(),
            edge_type.to_string(),
            target_type.to_string(),
        );
        self.relations.get(&key)
    }

    /// Returns every registered relation key, sorted lexicographically.
    ///
    /// Sorting makes the result independent of HashMap iteration order.
    pub fn relation_types(&self) -> Vec<(String, String, String)> {
        let mut keys: Vec<(String, String, String)> = self.relations.keys().cloned().collect();
        keys.sort_unstable();
        keys
    }

    /// Aggregates statistics across all relation graphs into `metadata`.
    ///
    /// Every figure is a sum over the relation graphs. A node that appears in
    /// several relation graphs is therefore counted once per graph, in both
    /// `node_count` and `anomalous_node_count`. `edge_type_counts` sums the
    /// per-type edge counts of all relations, keyed by the type's `as_str()`
    /// name. Fields not listed here keep their previous values.
    pub fn compute_statistics(&mut self) {
        let mut node_count = 0;
        let mut edge_count = 0;
        let mut anomalous_node_count = 0;
        let mut anomalous_edge_count = 0;
        let mut edge_type_counts: HashMap<String, usize> = HashMap::new();
        for graph in self.relations.values() {
            node_count += graph.node_count();
            edge_count += graph.edge_count();
            anomalous_node_count += graph.anomalous_nodes().len();
            anomalous_edge_count += graph.anomalous_edges().len();
            for (edge_type, ids) in &graph.edges_by_type {
                *edge_type_counts
                    .entry(edge_type.as_str().to_string())
                    .or_default() += ids.len();
            }
        }
        self.metadata.node_count = node_count;
        self.metadata.edge_count = edge_count;
        self.metadata.anomalous_node_count = anomalous_node_count;
        self.metadata.anomalous_edge_count = anomalous_edge_count;
        self.metadata.edge_type_counts = edge_type_counts;
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Adds an `Account` node with the given code and name and returns its id.
    fn add_account(graph: &mut Graph, code: &str, name: &str) -> NodeId {
        graph.add_node(GraphNode::new(
            0,
            NodeType::Account,
            code.to_string(),
            name.to_string(),
        ))
    }

    /// Adds a `Transaction` edge from `from` to `to`.
    fn link(graph: &mut Graph, from: NodeId, to: NodeId) {
        graph.add_edge(GraphEdge::new(0, from, to, EdgeType::Transaction));
    }

    #[test]
    fn test_graph_creation() {
        let mut graph = Graph::new("test", GraphType::Transaction);
        let cash = add_account(&mut graph, "1000", "Cash");
        let payables = add_account(&mut graph, "2000", "AP");
        link(&mut graph, cash, payables);

        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_adjacency() {
        let mut graph = Graph::new("test", GraphType::Transaction);
        let a = add_account(&mut graph, "1", "A");
        let b = add_account(&mut graph, "2", "B");
        let c = add_account(&mut graph, "3", "C");
        for &(from, to) in &[(a, b), (a, c), (b, c)] {
            link(&mut graph, from, to);
        }

        assert_eq!(graph.out_degree(a), 2);
        assert_eq!(graph.in_degree(c), 2);
        assert_eq!(graph.neighbors(a).len(), 2);
    }

    #[test]
    fn test_edge_index() {
        let mut graph = Graph::new("test", GraphType::Transaction);
        let a = add_account(&mut graph, "1", "A");
        let b = add_account(&mut graph, "2", "B");
        link(&mut graph, a, b);

        let (sources, targets) = graph.edge_index();
        assert_eq!(sources.len(), 1);
        assert_eq!(targets.len(), 1);
        assert_eq!(sources[0], a);
        assert_eq!(targets[0], b);
    }
}