use std::collections::{HashMap, HashSet, VecDeque};
use serde::{Deserialize, Serialize};
use crate::models::{Graph, NodeId};
use crate::EdgeType;
/// Category label attached to a detected [`EntityGroup`].
///
/// `classify_group_type` in this module only ever yields `Intercompany`,
/// `ApprovalChain`, `MuleNetwork`, `VendorRing` or `TransactionCluster`; the
/// other variants are not produced by that classifier.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GroupType {
/// Related-party group; `name()` returns `"related_party"`.
RelatedParty,
/// Chosen by the classifier when some member has both an incoming and an
/// outgoing edge inside the group and the anomaly rate is not above 0.5.
VendorRing,
/// Customer cluster; `name()` returns `"customer_cluster"`.
CustomerCluster,
/// Chosen by the classifier for groups with a high share of anomalous nodes
/// (above 0.5 with internal in/out edges, above 0.3 otherwise).
MuleNetwork,
/// Chosen by the classifier when an `Ownership` or `Intercompany` edge links
/// two members.
Intercompany,
/// Chosen by the classifier when an `Approval` or `ReportsTo` edge links two
/// members and no ownership-style edge does.
ApprovalChain,
/// Fallback category, also used for every group when classification is
/// disabled in the configuration.
TransactionCluster,
/// Caller-defined category; the wrapped string is returned verbatim by
/// `name()`.
Custom(String),
}
impl GroupType {
pub fn name(&self) -> &str {
match self {
GroupType::RelatedParty => "related_party",
GroupType::VendorRing => "vendor_ring",
GroupType::CustomerCluster => "customer_cluster",
GroupType::MuleNetwork => "mule_network",
GroupType::Intercompany => "intercompany",
GroupType::ApprovalChain => "approval_chain",
GroupType::TransactionCluster => "transaction_cluster",
GroupType::Custom(s) => s.as_str(),
}
}
}
/// Selects which detection routine `detect_entity_groups` runs.
///
/// Each variant dispatches to one private function of this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GroupDetectionAlgorithm {
/// Breadth-first traversal over `Graph::neighbors`
/// (`detect_connected_components`).
ConnectedComponents,
/// Iterative neighbour-majority relabelling (`detect_label_propagation`).
LabelPropagation,
/// Greedy growth around high-degree seeds (`detect_dense_subgraphs`).
DenseSubgraph,
/// Enumeration of three-node cliques (`detect_cliques`).
CliqueDetection,
}
/// Tuning parameters for `detect_entity_groups`.
#[derive(Debug, Clone)]
pub struct GroupDetectionConfig {
/// Smallest number of members a group needs in order to be reported.
pub min_group_size: usize,
/// Maximum number of members in a group.
pub max_group_size: usize,
/// Minimum cohesion (internal directed edges divided by `n * (n - 1)`).
/// Connected-component and label-propagation groups below this value are
/// discarded; dense-subgraph growth stops once cohesion falls below twice
/// this value.
pub min_cohesion: f64,
/// Algorithms to run, in order; their results are concatenated.
pub algorithms: Vec<GroupDetectionAlgorithm>,
/// Upper bound on the number of groups in the final result; detection stops
/// once it is reached.
pub max_groups: usize,
/// When `false`, every group is labelled `GroupType::TransactionCluster`
/// instead of being classified.
pub classify_types: bool,
/// Optional edge-type restriction.
/// NOTE(review): none of the detection functions visible in this file read
/// this field — confirm whether filtering is applied elsewhere.
pub edge_types: Option<Vec<EdgeType>>,
}
impl Default for GroupDetectionConfig {
fn default() -> Self {
Self {
min_group_size: 3,
max_group_size: 50,
min_cohesion: 0.1,
algorithms: vec![GroupDetectionAlgorithm::ConnectedComponents],
max_groups: 1000,
classify_types: true,
edge_types: None,
}
}
}
/// A set of graph nodes reported as belonging together, plus summary metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityGroup {
/// Identifier assigned by the detector that produced the group.
pub group_id: u64,
/// Nodes that make up the group.
pub members: Vec<NodeId>,
/// Category of the group.
pub group_type: GroupType,
/// Confidence score; `EntityGroup::new` initialises it to `1.0`.
pub confidence: f64,
/// Member with the most neighbours inside the group, when one was computed.
pub hub_node: Option<NodeId>,
/// Summed weight of edges whose source and target are both members.
pub internal_volume: f64,
/// Summed weight of edges that connect a member with a non-member.
pub external_volume: f64,
/// Internal directed edge count divided by `n * (n - 1)`.
pub cohesion: f64,
}
impl EntityGroup {
pub fn new(group_id: u64, members: Vec<NodeId>, group_type: GroupType) -> Self {
Self {
group_id,
members,
group_type,
confidence: 1.0,
hub_node: None,
internal_volume: 0.0,
external_volume: 0.0,
cohesion: 0.0,
}
}
pub fn with_hub(mut self, hub: NodeId) -> Self {
self.hub_node = Some(hub);
self
}
pub fn with_volumes(mut self, internal: f64, external: f64) -> Self {
self.internal_volume = internal;
self.external_volume = external;
self
}
pub fn with_cohesion(mut self, cohesion: f64) -> Self {
self.cohesion = cohesion;
self
}
pub fn size(&self) -> usize {
self.members.len()
}
pub fn contains(&self, node_id: NodeId) -> bool {
self.members.contains(&node_id)
}
}
/// Output of `detect_entity_groups`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GroupDetectionResult {
/// All reported groups.
pub groups: Vec<EntityGroup>,
/// For each member node, the ids of the groups that contain it.
pub node_groups: HashMap<NodeId, Vec<u64>>,
/// Number of entries in `groups` at the time the result was built.
pub total_groups: usize,
/// Number of groups per type, keyed by `GroupType::name()`.
pub groups_by_type: HashMap<String, usize>,
}
impl GroupDetectionResult {
    /// Returns references to every group that lists `node_id` as a member.
    ///
    /// Ids recorded in `node_groups` that have no matching entry in `groups`
    /// are skipped. A node without any recorded group yields an empty vector.
    pub fn groups_for_node(&self, node_id: NodeId) -> Vec<&EntityGroup> {
        match self.node_groups.get(&node_id) {
            None => Vec::new(),
            Some(ids) => {
                let mut found = Vec::new();
                for id in ids {
                    let hit = self.groups.iter().find(|group| group.group_id == *id);
                    if let Some(group) = hit {
                        found.push(group);
                    }
                }
                found
            }
        }
    }

    /// Builds the per-node feature vector
    /// `[group count, size of largest group, hub flag]`.
    ///
    /// The hub flag is `1.0` when the node is the hub of at least one of its
    /// groups and `0.0` otherwise.
    pub fn node_features(&self, node_id: NodeId) -> Vec<f64> {
        let memberships = self.groups_for_node(node_id);

        let mut largest = 0usize;
        let mut hub_flag = 0.0;
        for group in &memberships {
            largest = largest.max(group.size());
            if group.hub_node == Some(node_id) {
                hub_flag = 1.0;
            }
        }

        vec![memberships.len() as f64, largest as f64, hub_flag]
    }

    /// Length of the vector returned by `node_features`.
    pub fn feature_dim() -> usize {
        3
    }
}
pub fn detect_entity_groups(graph: &Graph, config: &GroupDetectionConfig) -> GroupDetectionResult {
let mut all_groups = Vec::new();
let mut next_group_id = 1u64;
for algorithm in &config.algorithms {
let groups = match algorithm {
GroupDetectionAlgorithm::ConnectedComponents => {
detect_connected_components(graph, config, &mut next_group_id)
}
GroupDetectionAlgorithm::LabelPropagation => {
detect_label_propagation(graph, config, &mut next_group_id)
}
GroupDetectionAlgorithm::DenseSubgraph => {
detect_dense_subgraphs(graph, config, &mut next_group_id)
}
GroupDetectionAlgorithm::CliqueDetection => {
detect_cliques(graph, config, &mut next_group_id)
}
};
all_groups.extend(groups);
if all_groups.len() >= config.max_groups {
all_groups.truncate(config.max_groups);
break;
}
}
let mut node_groups: HashMap<NodeId, Vec<u64>> = HashMap::new();
for group in &all_groups {
for &member in &group.members {
node_groups.entry(member).or_default().push(group.group_id);
}
}
let mut groups_by_type: HashMap<String, usize> = HashMap::new();
for group in &all_groups {
*groups_by_type
.entry(group.group_type.name().to_string())
.or_insert(0) += 1;
}
GroupDetectionResult {
total_groups: all_groups.len(),
groups: all_groups,
node_groups,
groups_by_type,
}
}
/// Reports each connected component (as reachable through
/// `Graph::neighbors`) whose size lies within
/// `[config.min_group_size, config.max_group_size]` and whose cohesion is at
/// least `config.min_cohesion`.
///
/// Every component is traversed to completion so that all of its nodes are
/// marked visited. Stopping the traversal early at `max_group_size` would
/// report a truncated slice of an oversized component and let the remaining
/// nodes seed further, spurious "components". Oversized components are
/// therefore skipped as a whole, which matches how label propagation treats
/// oversized communities.
///
/// `next_id` is advanced only for groups that are actually returned, so ids
/// stay contiguous.
fn detect_connected_components(
    graph: &Graph,
    config: &GroupDetectionConfig,
    next_id: &mut u64,
) -> Vec<EntityGroup> {
    let mut groups = Vec::new();
    let mut visited: HashSet<NodeId> = HashSet::new();

    for &start_node in graph.nodes.keys() {
        // `insert` returns false when the node was already reached earlier.
        if !visited.insert(start_node) {
            continue;
        }

        let mut component = Vec::new();
        let mut queue = VecDeque::new();
        queue.push_back(start_node);

        while let Some(node) = queue.pop_front() {
            // Keep at most `max_group_size + 1` members: enough to detect an
            // oversized component without storing all of it.
            if component.len() <= config.max_group_size {
                component.push(node);
            }
            for neighbor in graph.neighbors(node) {
                if visited.insert(neighbor) {
                    queue.push_back(neighbor);
                }
            }
        }

        let size = component.len();
        if size < config.min_group_size || size > config.max_group_size {
            continue;
        }

        let (internal, external, cohesion) = calculate_group_metrics(graph, &component);
        if cohesion < config.min_cohesion {
            continue;
        }

        let group_type = if config.classify_types {
            classify_group_type(graph, &component)
        } else {
            GroupType::TransactionCluster
        };
        let hub = find_hub_node(graph, &component);

        groups.push(
            EntityGroup::new(*next_id, component, group_type)
                .with_hub(hub)
                .with_volumes(internal, external)
                .with_cohesion(cohesion),
        );
        *next_id += 1;
    }

    groups
}
/// Detects communities with asynchronous label propagation.
///
/// Every node starts with a unique label (its index in the node list). In
/// each of at most ten sweeps a node adopts the label that is most frequent
/// among its neighbours; sweeping stops early once no label changes. Nodes
/// sharing a label form a community, which is reported when its size lies
/// within `[config.min_group_size, config.max_group_size]` and its cohesion
/// is at least `config.min_cohesion`.
///
/// Ties between equally frequent labels are resolved in favour of the
/// smallest label, community members are listed in node order, and
/// communities are emitted in ascending label order. The outcome therefore
/// depends only on the iteration order of `graph.nodes`, not on the
/// randomised iteration order of the local hash maps.
fn detect_label_propagation(
    graph: &Graph,
    config: &GroupDetectionConfig,
    next_id: &mut u64,
) -> Vec<EntityGroup> {
    const MAX_SWEEPS: usize = 10;

    let nodes: Vec<NodeId> = graph.nodes.keys().copied().collect();
    if nodes.is_empty() {
        return Vec::new();
    }

    let mut labels: HashMap<NodeId, u64> = nodes
        .iter()
        .enumerate()
        .map(|(i, &n)| (n, i as u64))
        .collect();

    for _ in 0..MAX_SWEEPS {
        let mut changed = false;
        for &node in &nodes {
            let mut label_counts: HashMap<u64, usize> = HashMap::new();
            for neighbor in graph.neighbors(node) {
                if let Some(&label) = labels.get(&neighbor) {
                    *label_counts.entry(label).or_insert(0) += 1;
                }
            }

            // Highest count wins; `Reverse` makes the smallest label win ties.
            // Isolated nodes have no counts and keep their label.
            let winner = label_counts
                .iter()
                .map(|(&label, &count)| (count, std::cmp::Reverse(label)))
                .max();
            if let Some((_, std::cmp::Reverse(most_common))) = winner {
                if labels.insert(node, most_common) != Some(most_common) {
                    changed = true;
                }
            }
        }
        if !changed {
            break;
        }
    }

    // Group nodes by final label, walking `nodes` so member order is stable.
    let mut communities: HashMap<u64, Vec<NodeId>> = HashMap::new();
    for &node in &nodes {
        if let Some(&label) = labels.get(&node) {
            communities.entry(label).or_default().push(node);
        }
    }
    let mut ordered: Vec<(u64, Vec<NodeId>)> = communities.into_iter().collect();
    ordered.sort_unstable_by_key(|(label, _)| *label);

    let mut groups = Vec::new();
    for (_, members) in ordered {
        if members.len() < config.min_group_size || members.len() > config.max_group_size {
            continue;
        }

        let (internal, external, cohesion) = calculate_group_metrics(graph, &members);
        if cohesion < config.min_cohesion {
            continue;
        }

        let group_type = if config.classify_types {
            classify_group_type(graph, &members)
        } else {
            GroupType::TransactionCluster
        };
        let hub = find_hub_node(graph, &members);

        groups.push(
            EntityGroup::new(*next_id, members, group_type)
                .with_hub(hub)
                .with_volumes(internal, external)
                .with_cohesion(cohesion),
        );
        *next_id += 1;
    }

    groups
}
/// Greedily grows dense subgraphs around high-degree seed nodes.
///
/// Seeds are visited by descending degree (ties by ascending node id). From
/// each unused seed the subgraph repeatedly absorbs the candidate with the
/// most connections into it (ties by ascending node id) until it reaches
/// `config.max_group_size`, runs out of connected candidates, or its cohesion
/// drops below `2 * config.min_cohesion`. The node whose addition caused the
/// drop stays in the subgraph. Subgraphs with at least
/// `config.min_group_size` members are reported and their nodes are marked
/// as used. At most `config.max_groups` groups are returned.
///
/// The initial candidate set excludes the seed itself and nodes already used
/// by an earlier group, the same rule applied to candidates added later.
/// Without that filter a self-loop could add the seed twice and groups could
/// overlap through the seed's direct neighbours.
fn detect_dense_subgraphs(
    graph: &Graph,
    config: &GroupDetectionConfig,
    next_id: &mut u64,
) -> Vec<EntityGroup> {
    let mut groups = Vec::new();

    let mut nodes_by_degree: Vec<(NodeId, usize)> =
        graph.nodes.keys().map(|&n| (n, graph.degree(n))).collect();
    nodes_by_degree.sort_by_key(|&(node, degree)| (std::cmp::Reverse(degree), node));

    let mut used_nodes: HashSet<NodeId> = HashSet::new();

    for (seed, _) in nodes_by_degree {
        if used_nodes.contains(&seed) {
            continue;
        }

        let mut subgraph = vec![seed];
        let mut candidates: HashSet<NodeId> = graph
            .neighbors(seed)
            .into_iter()
            .filter(|n| *n != seed && !used_nodes.contains(n))
            .collect();

        while subgraph.len() < config.max_group_size && !candidates.is_empty() {
            let best_candidate = candidates
                .iter()
                .map(|&c| {
                    let connections = graph
                        .neighbors(c)
                        .iter()
                        .filter(|n| subgraph.contains(n))
                        .count();
                    (c, connections)
                })
                // Most connections first; smallest node id breaks ties so the
                // choice does not depend on hash-set iteration order.
                .max_by_key(|&(c, conn)| (conn, std::cmp::Reverse(c)));

            match best_candidate {
                Some((c, conn)) if conn > 0 => {
                    subgraph.push(c);
                    candidates.remove(&c);
                    for neighbor in graph.neighbors(c) {
                        if !subgraph.contains(&neighbor) && !used_nodes.contains(&neighbor) {
                            candidates.insert(neighbor);
                        }
                    }
                }
                _ => break,
            }

            let (_, _, cohesion) = calculate_group_metrics(graph, &subgraph);
            if cohesion < config.min_cohesion * 2.0 {
                break;
            }
        }

        if subgraph.len() < config.min_group_size {
            continue;
        }

        used_nodes.extend(&subgraph);

        let group_type = if config.classify_types {
            classify_group_type(graph, &subgraph)
        } else {
            GroupType::TransactionCluster
        };
        let (internal, external, cohesion) = calculate_group_metrics(graph, &subgraph);
        let hub = find_hub_node(graph, &subgraph);

        groups.push(
            EntityGroup::new(*next_id, subgraph, group_type)
                .with_hub(hub)
                .with_volumes(internal, external)
                .with_cohesion(cohesion),
        );
        *next_id += 1;

        if groups.len() >= config.max_groups {
            break;
        }
    }

    groups
}
/// Enumerates three-node cliques (triangles) over the undirected view of the
/// graph's edges.
///
/// Each triangle is generated exactly once as `a < b < c`, so no separate
/// de-duplication is needed. Because every reported clique has exactly three
/// members, nothing is returned when `config.min_group_size` exceeds 3.
/// Enumeration stops as soon as `config.max_groups` cliques have been
/// collected. For a given `a`, candidates are visited in ascending id order
/// rather than in hash-set order.
fn detect_cliques(
    graph: &Graph,
    config: &GroupDetectionConfig,
    next_id: &mut u64,
) -> Vec<EntityGroup> {
    const CLIQUE_SIZE: usize = 3;

    let mut groups = Vec::new();
    if CLIQUE_SIZE < config.min_group_size {
        return groups;
    }

    // Undirected adjacency: an edge in either direction links both endpoints.
    let mut adjacency: HashMap<NodeId, HashSet<NodeId>> = HashMap::new();
    for edge in graph.edges.values() {
        adjacency
            .entry(edge.source)
            .or_default()
            .insert(edge.target);
        adjacency
            .entry(edge.target)
            .or_default()
            .insert(edge.source);
    }

    let nodes: Vec<NodeId> = graph.nodes.keys().copied().collect();

    'search: for &a in &nodes {
        let neighbors_a = match adjacency.get(&a) {
            Some(n) => n,
            None => continue,
        };

        // Only neighbours with a larger id can play the role of `b` or `c`.
        let mut higher: Vec<NodeId> = neighbors_a.iter().copied().filter(|&n| n > a).collect();
        higher.sort_unstable();

        for (i, &b) in higher.iter().enumerate() {
            let neighbors_b = match adjacency.get(&b) {
                Some(n) => n,
                None => continue,
            };

            for &c in &higher[i + 1..] {
                if !neighbors_b.contains(&c) {
                    continue;
                }
                if groups.len() >= config.max_groups {
                    break 'search;
                }

                let clique = vec![a, b, c];
                let group_type = if config.classify_types {
                    classify_group_type(graph, &clique)
                } else {
                    GroupType::TransactionCluster
                };
                let (internal, external, cohesion) = calculate_group_metrics(graph, &clique);
                let hub = find_hub_node(graph, &clique);

                groups.push(
                    EntityGroup::new(*next_id, clique, group_type)
                        .with_hub(hub)
                        .with_volumes(internal, external)
                        .with_cohesion(cohesion),
                );
                *next_id += 1;
            }
        }
    }

    groups
}
/// Assigns a [`GroupType`] to `members` from the edges and anomaly flags found
/// inside the group.
///
/// Rules, in priority order:
/// 1. an `Ownership`/`Intercompany` edge between two members → `Intercompany`
/// 2. an `Approval`/`ReportsTo` edge between two members → `ApprovalChain`
/// 3. some member has both an outgoing and an incoming edge inside the group:
///    `MuleNetwork` if more than half of the members are anomalous, otherwise
///    `VendorRing`
/// 4. more than 30% of the members are anomalous → `MuleNetwork`
/// 5. otherwise → `TransactionCluster`
fn classify_group_type(graph: &Graph, members: &[NodeId]) -> GroupType {
    let inside: HashSet<NodeId> = members.iter().copied().collect();

    // One pass over all edges collects both edge-type signals.
    let mut has_ownership = false;
    let mut has_approval = false;
    for edge in graph.edges.values() {
        if !(inside.contains(&edge.source) && inside.contains(&edge.target)) {
            continue;
        }
        if matches!(edge.edge_type, EdgeType::Ownership | EdgeType::Intercompany) {
            has_ownership = true;
        }
        if matches!(edge.edge_type, EdgeType::Approval | EdgeType::ReportsTo) {
            has_approval = true;
        }
        if has_ownership && has_approval {
            break;
        }
    }
    if has_ownership {
        return GroupType::Intercompany;
    }
    if has_approval {
        return GroupType::ApprovalChain;
    }

    // A member that both sends to and receives from other members.
    let mut has_cycles = false;
    for &node in members {
        let sends_inside = graph
            .outgoing_edges(node)
            .iter()
            .any(|e| inside.contains(&e.target));
        if sends_inside
            && graph
                .incoming_edges(node)
                .iter()
                .any(|e| inside.contains(&e.source))
        {
            has_cycles = true;
            break;
        }
    }

    let mut flagged = 0usize;
    for &n in members {
        let anomalous = graph
            .get_node(n)
            .map(|node| node.is_anomaly)
            .unwrap_or(false);
        if anomalous {
            flagged += 1;
        }
    }
    let anomaly_rate = flagged as f64 / members.len() as f64;

    match (has_cycles, anomaly_rate) {
        (true, rate) if rate > 0.5 => GroupType::MuleNetwork,
        (true, _) => GroupType::VendorRing,
        (false, rate) if rate > 0.3 => GroupType::MuleNetwork,
        (false, _) => GroupType::TransactionCluster,
    }
}
/// Computes `(internal_volume, external_volume, cohesion)` for `members`.
///
/// * `internal_volume` – summed weight of edges whose source and target are
///   both members (each edge counted once, from its source side).
/// * `external_volume` – summed weight of edges leaving a member towards a
///   non-member plus edges entering a member from a non-member.
/// * `cohesion` – number of internal directed edges divided by `n * (n - 1)`,
///   the number of ordered member pairs; `0.0` when there are fewer than two
///   members.
///
/// An empty `members` slice yields `(0.0, 0.0, 0.0)`; the pair count is
/// computed with a saturating subtraction so it can no longer underflow.
fn calculate_group_metrics(graph: &Graph, members: &[NodeId]) -> (f64, f64, f64) {
    let member_set: HashSet<NodeId> = members.iter().copied().collect();

    let mut internal_volume = 0.0;
    let mut external_volume = 0.0;
    let mut internal_edges: usize = 0;

    for &member in members {
        for edge in graph.outgoing_edges(member) {
            if member_set.contains(&edge.target) {
                internal_volume += edge.weight;
                internal_edges += 1;
            } else {
                external_volume += edge.weight;
            }
        }
        // Incoming edges from other members were already counted as the
        // outgoing edges of their source, so only external ones are added.
        for edge in graph.incoming_edges(member) {
            if !member_set.contains(&edge.source) {
                external_volume += edge.weight;
            }
        }
    }

    let size = members.len();
    let max_possible_edges = size * size.saturating_sub(1);
    let cohesion = if max_possible_edges > 0 {
        internal_edges as f64 / max_possible_edges as f64
    } else {
        0.0
    };

    (internal_volume, external_volume, cohesion)
}
/// Returns the member with the most neighbours inside the group.
///
/// When several members share the highest internal degree, the one that
/// appears last in `members` is returned.
///
/// # Panics
///
/// Panics if `members` is empty.
fn find_hub_node(graph: &Graph, members: &[NodeId]) -> NodeId {
    let member_set: HashSet<NodeId> = members.iter().copied().collect();

    let mut best: Option<(NodeId, usize)> = None;
    for &candidate in members {
        let internal_degree = graph
            .neighbors(candidate)
            .iter()
            .filter(|neighbor| member_set.contains(neighbor))
            .count();

        // Replace on ties as well, so the last maximal member wins.
        match best {
            Some((_, top)) if internal_degree < top => {}
            _ => best = Some((candidate, internal_degree)),
        }
    }

    match best {
        Some((hub, _)) => hub,
        None => members[0],
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::models::{GraphEdge, GraphNode, GraphType, NodeType};
    use crate::test_helpers::create_entity_group_test_graph;

    /// The default configuration finds at least one group in the fixture graph.
    #[test]
    fn test_connected_components() {
        let graph = create_entity_group_test_graph();
        let config = GroupDetectionConfig::default();
        let result = detect_entity_groups(&graph, &config);
        assert!(result.total_groups >= 1);
    }

    /// Label propagation may legitimately find nothing, so the test checks the
    /// invariants every reported group must satisfy instead of a tautology.
    #[test]
    fn test_label_propagation() {
        let graph = create_entity_group_test_graph();
        let config = GroupDetectionConfig {
            algorithms: vec![GroupDetectionAlgorithm::LabelPropagation],
            ..Default::default()
        };
        let result = detect_entity_groups(&graph, &config);
        assert_eq!(result.total_groups, result.groups.len());
        for group in &result.groups {
            assert!(group.size() >= config.min_group_size);
            assert!(group.size() <= config.max_group_size);
            assert!(group.cohesion >= config.min_cohesion);
        }
    }

    /// Clique detection reports at least one tightly connected group.
    #[test]
    fn test_clique_detection() {
        let graph = create_entity_group_test_graph();
        let config = GroupDetectionConfig {
            algorithms: vec![GroupDetectionAlgorithm::CliqueDetection],
            min_cohesion: 0.1,
            ..Default::default()
        };
        let result = detect_entity_groups(&graph, &config);
        let cliques: Vec<_> = result.groups.iter().filter(|g| g.cohesion > 0.4).collect();
        assert!(!cliques.is_empty());
    }

    /// The feature vector has the advertised dimensionality.
    #[test]
    fn test_node_features() {
        let graph = create_entity_group_test_graph();
        let config = GroupDetectionConfig::default();
        let result = detect_entity_groups(&graph, &config);
        let features = result.node_features(1);
        assert_eq!(features.len(), GroupDetectionResult::feature_dim());
    }

    /// Nodes 1, 2 and 3 of the fixture share weighted internal edges.
    #[test]
    fn test_group_metrics() {
        let graph = create_entity_group_test_graph();
        let members = vec![1, 2, 3];
        let (internal, _external, cohesion) = calculate_group_metrics(&graph, &members);
        assert!(internal > 0.0);
        assert!(cohesion > 0.0);
    }

    /// In a star graph the centre is the hub.
    #[test]
    fn test_hub_detection() {
        let mut graph = Graph::new("test", GraphType::Transaction);
        let n1 = graph.add_node(GraphNode::new(
            0,
            NodeType::Account,
            "Hub".to_string(),
            "Hub".to_string(),
        ));
        let n2 = graph.add_node(GraphNode::new(
            0,
            NodeType::Account,
            "A".to_string(),
            "A".to_string(),
        ));
        let n3 = graph.add_node(GraphNode::new(
            0,
            NodeType::Account,
            "B".to_string(),
            "B".to_string(),
        ));
        let n4 = graph.add_node(GraphNode::new(
            0,
            NodeType::Account,
            "C".to_string(),
            "C".to_string(),
        ));
        graph.add_edge(GraphEdge::new(0, n1, n2, EdgeType::Transaction));
        graph.add_edge(GraphEdge::new(0, n1, n3, EdgeType::Transaction));
        graph.add_edge(GraphEdge::new(0, n1, n4, EdgeType::Transaction));
        let members = vec![n1, n2, n3, n4];
        let hub = find_hub_node(&graph, &members);
        assert_eq!(hub, n1);
    }
}