1use std::collections::{HashMap, HashSet, VecDeque};
10
11use serde::{Deserialize, Serialize};
12
13use crate::models::{Graph, NodeId};
14use crate::EdgeType;
15
/// Category assigned to a detected entity group.
///
/// The classifier in this file only ever produces `Intercompany`,
/// `ApprovalChain`, `MuleNetwork`, `VendorRing` and `TransactionCluster`;
/// the remaining variants are available for callers to assign themselves.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GroupType {
    /// Related-party group (not produced by the built-in classifier).
    RelatedParty,
    /// Group in which some member has both an incoming and an outgoing edge
    /// inside the group.
    VendorRing,
    /// Customer cluster (not produced by the built-in classifier).
    CustomerCluster,
    /// Group with a high share of nodes flagged as anomalous.
    MuleNetwork,
    /// Group containing an internal ownership or intercompany edge.
    Intercompany,
    /// Group containing an internal approval or reports-to edge.
    ApprovalChain,
    /// Fallback classification when no other rule matches.
    TransactionCluster,
    /// Caller-defined type; the payload is returned verbatim by `name()`.
    Custom(String),
}
36
37impl GroupType {
38 pub fn name(&self) -> &str {
40 match self {
41 GroupType::RelatedParty => "related_party",
42 GroupType::VendorRing => "vendor_ring",
43 GroupType::CustomerCluster => "customer_cluster",
44 GroupType::MuleNetwork => "mule_network",
45 GroupType::Intercompany => "intercompany",
46 GroupType::ApprovalChain => "approval_chain",
47 GroupType::TransactionCluster => "transaction_cluster",
48 GroupType::Custom(s) => s.as_str(),
49 }
50 }
51}
52
/// Algorithms that `detect_entity_groups` can run.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GroupDetectionAlgorithm {
    /// Breadth-first traversal over `Graph::neighbors` to collect components.
    ConnectedComponents,
    /// Iterative label propagation: each node adopts the most common label
    /// among its neighbors.
    LabelPropagation,
    /// Greedy growth from high-degree seed nodes.
    DenseSubgraph,
    /// Clique search; the implementation in this file enumerates triangles
    /// (3-node cliques) only.
    CliqueDetection,
}
65
/// Tuning parameters for entity group detection.
#[derive(Debug, Clone)]
pub struct GroupDetectionConfig {
    /// Smallest number of members a group must have to be reported.
    pub min_group_size: usize,
    /// Largest number of members a reported group may have.
    pub max_group_size: usize,
    /// Minimum cohesion (internal edges / `n * (n - 1)`) used as a filter by
    /// the connected-components and label-propagation detectors; the
    /// dense-subgraph detector stops growing below twice this value.
    pub min_cohesion: f64,
    /// Algorithms to run, in order.
    pub algorithms: Vec<GroupDetectionAlgorithm>,
    /// Upper bound on the number of groups in the final result.
    pub max_groups: usize,
    /// When `false`, every group is labelled `GroupType::TransactionCluster`
    /// instead of being classified.
    pub classify_types: bool,
    /// Optional edge-type filter.
    /// NOTE(review): no detector in this file reads this field — confirm
    /// whether filtering is applied elsewhere or is still unimplemented.
    pub edge_types: Option<Vec<EdgeType>>,
}
84
85impl Default for GroupDetectionConfig {
86 fn default() -> Self {
87 Self {
88 min_group_size: 3,
89 max_group_size: 50,
90 min_cohesion: 0.1,
91 algorithms: vec![GroupDetectionAlgorithm::ConnectedComponents],
92 max_groups: 1000,
93 classify_types: true,
94 edge_types: None,
95 }
96 }
97}
98
/// A detected group of graph nodes together with its summary metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityGroup {
    /// Identifier of the group; `detect_entity_groups` hands these out
    /// starting at 1.
    pub group_id: u64,
    /// Ids of the nodes that belong to the group.
    pub members: Vec<NodeId>,
    /// Classification of the group.
    pub group_type: GroupType,
    /// Confidence score; `EntityGroup::new` initialises it to 1.0.
    pub confidence: f64,
    /// Member with the most neighbors inside the group, if determined.
    pub hub_node: Option<NodeId>,
    /// Sum of the weights of edges whose source and target are both members.
    pub internal_volume: f64,
    /// Sum of the weights of edges that connect a member to a non-member, in
    /// either direction.
    pub external_volume: f64,
    /// Number of internal directed edges divided by `n * (n - 1)`.
    pub cohesion: f64,
}
119
120impl EntityGroup {
121 pub fn new(group_id: u64, members: Vec<NodeId>, group_type: GroupType) -> Self {
123 Self {
124 group_id,
125 members,
126 group_type,
127 confidence: 1.0,
128 hub_node: None,
129 internal_volume: 0.0,
130 external_volume: 0.0,
131 cohesion: 0.0,
132 }
133 }
134
135 pub fn with_hub(mut self, hub: NodeId) -> Self {
137 self.hub_node = Some(hub);
138 self
139 }
140
141 pub fn with_volumes(mut self, internal: f64, external: f64) -> Self {
143 self.internal_volume = internal;
144 self.external_volume = external;
145 self
146 }
147
148 pub fn with_cohesion(mut self, cohesion: f64) -> Self {
150 self.cohesion = cohesion;
151 self
152 }
153
154 pub fn size(&self) -> usize {
156 self.members.len()
157 }
158
159 pub fn contains(&self, node_id: NodeId) -> bool {
161 self.members.contains(&node_id)
162 }
163}
164
/// Outcome of a group detection run.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GroupDetectionResult {
    /// All detected groups.
    pub groups: Vec<EntityGroup>,
    /// For each node, the ids of the groups that contain it.
    pub node_groups: HashMap<NodeId, Vec<u64>>,
    /// Number of entries in `groups`.
    pub total_groups: usize,
    /// Number of groups per type, keyed by `GroupType::name()`.
    pub groups_by_type: HashMap<String, usize>,
}
177
178impl GroupDetectionResult {
179 pub fn groups_for_node(&self, node_id: NodeId) -> Vec<&EntityGroup> {
181 self.node_groups
182 .get(&node_id)
183 .map(|ids| {
184 ids.iter()
185 .filter_map(|&id| self.groups.iter().find(|g| g.group_id == id))
186 .collect()
187 })
188 .unwrap_or_default()
189 }
190
191 pub fn node_features(&self, node_id: NodeId) -> Vec<f64> {
193 let groups = self.groups_for_node(node_id);
194
195 let group_count = groups.len() as f64;
196 let max_group_size = groups.iter().map(|g| g.size()).max().unwrap_or(0) as f64;
197 let is_hub = groups.iter().any(|g| g.hub_node == Some(node_id));
198
199 vec![group_count, max_group_size, if is_hub { 1.0 } else { 0.0 }]
200 }
201
202 pub fn feature_dim() -> usize {
204 3
205 }
206}
207
208pub fn detect_entity_groups(graph: &Graph, config: &GroupDetectionConfig) -> GroupDetectionResult {
210 let mut all_groups = Vec::new();
211 let mut next_group_id = 1u64;
212
213 for algorithm in &config.algorithms {
214 let groups = match algorithm {
215 GroupDetectionAlgorithm::ConnectedComponents => {
216 detect_connected_components(graph, config, &mut next_group_id)
217 }
218 GroupDetectionAlgorithm::LabelPropagation => {
219 detect_label_propagation(graph, config, &mut next_group_id)
220 }
221 GroupDetectionAlgorithm::DenseSubgraph => {
222 detect_dense_subgraphs(graph, config, &mut next_group_id)
223 }
224 GroupDetectionAlgorithm::CliqueDetection => {
225 detect_cliques(graph, config, &mut next_group_id)
226 }
227 };
228
229 all_groups.extend(groups);
230
231 if all_groups.len() >= config.max_groups {
232 all_groups.truncate(config.max_groups);
233 break;
234 }
235 }
236
237 let mut node_groups: HashMap<NodeId, Vec<u64>> = HashMap::new();
239 for group in &all_groups {
240 for &member in &group.members {
241 node_groups.entry(member).or_default().push(group.group_id);
242 }
243 }
244
245 let mut groups_by_type: HashMap<String, usize> = HashMap::new();
247 for group in &all_groups {
248 *groups_by_type
249 .entry(group.group_type.name().to_string())
250 .or_insert(0) += 1;
251 }
252
253 GroupDetectionResult {
254 total_groups: all_groups.len(),
255 groups: all_groups,
256 node_groups,
257 groups_by_type,
258 }
259}
260
261fn detect_connected_components(
263 graph: &Graph,
264 config: &GroupDetectionConfig,
265 next_id: &mut u64,
266) -> Vec<EntityGroup> {
267 let mut groups = Vec::new();
268 let mut visited: HashSet<NodeId> = HashSet::new();
269
270 for &start_node in graph.nodes.keys() {
271 if visited.contains(&start_node) {
272 continue;
273 }
274
275 let mut component = Vec::new();
277 let mut queue = VecDeque::new();
278 queue.push_back(start_node);
279 visited.insert(start_node);
280
281 while let Some(node) = queue.pop_front() {
282 component.push(node);
283
284 for neighbor in graph.neighbors(node) {
286 if !visited.contains(&neighbor) {
287 visited.insert(neighbor);
288 queue.push_back(neighbor);
289 }
290 }
291
292 if component.len() >= config.max_group_size {
294 break;
295 }
296 }
297
298 if component.len() >= config.min_group_size && component.len() <= config.max_group_size {
300 let group_type = if config.classify_types {
301 classify_group_type(graph, &component)
302 } else {
303 GroupType::TransactionCluster
304 };
305
306 let mut group = EntityGroup::new(*next_id, component.clone(), group_type);
307 *next_id += 1;
308
309 let (internal, external, cohesion) = calculate_group_metrics(graph, &component);
311 if cohesion >= config.min_cohesion {
312 let hub = find_hub_node(graph, &component);
313 group = group
314 .with_hub(hub)
315 .with_volumes(internal, external)
316 .with_cohesion(cohesion);
317 groups.push(group);
318 }
319 }
320 }
321
322 groups
323}
324
325fn detect_label_propagation(
327 graph: &Graph,
328 config: &GroupDetectionConfig,
329 next_id: &mut u64,
330) -> Vec<EntityGroup> {
331 let nodes: Vec<NodeId> = graph.nodes.keys().copied().collect();
332 if nodes.is_empty() {
333 return Vec::new();
334 }
335
336 let mut labels: HashMap<NodeId, u64> = nodes
338 .iter()
339 .enumerate()
340 .map(|(i, &n)| (n, i as u64))
341 .collect();
342
343 for _ in 0..10 {
345 let mut changed = false;
347
348 for &node in &nodes {
349 let neighbors = graph.neighbors(node);
350 if neighbors.is_empty() {
351 continue;
352 }
353
354 let mut label_counts: HashMap<u64, usize> = HashMap::new();
356 for neighbor in neighbors {
357 if let Some(&label) = labels.get(&neighbor) {
358 *label_counts.entry(label).or_insert(0) += 1;
359 }
360 }
361
362 if let Some((&most_common, _)) = label_counts.iter().max_by_key(|(_, &count)| count) {
364 if labels.get(&node) != Some(&most_common) {
365 labels.insert(node, most_common);
366 changed = true;
367 }
368 }
369 }
370
371 if !changed {
372 break;
373 }
374 }
375
376 let mut communities: HashMap<u64, Vec<NodeId>> = HashMap::new();
378 for (node, label) in labels {
379 communities.entry(label).or_default().push(node);
380 }
381
382 let mut groups = Vec::new();
384 for (_, members) in communities {
385 if members.len() >= config.min_group_size && members.len() <= config.max_group_size {
386 let group_type = if config.classify_types {
387 classify_group_type(graph, &members)
388 } else {
389 GroupType::TransactionCluster
390 };
391
392 let (internal, external, cohesion) = calculate_group_metrics(graph, &members);
393 if cohesion >= config.min_cohesion {
394 let hub = find_hub_node(graph, &members);
395 let group = EntityGroup::new(*next_id, members, group_type)
396 .with_hub(hub)
397 .with_volumes(internal, external)
398 .with_cohesion(cohesion);
399 *next_id += 1;
400 groups.push(group);
401 }
402 }
403 }
404
405 groups
406}
407
408fn detect_dense_subgraphs(
410 graph: &Graph,
411 config: &GroupDetectionConfig,
412 next_id: &mut u64,
413) -> Vec<EntityGroup> {
414 let mut groups = Vec::new();
415
416 let mut nodes_by_degree: Vec<(NodeId, usize)> =
418 graph.nodes.keys().map(|&n| (n, graph.degree(n))).collect();
419 nodes_by_degree.sort_by_key(|(_, d)| std::cmp::Reverse(*d));
420
421 let mut used_nodes: HashSet<NodeId> = HashSet::new();
422
423 for (seed, _) in nodes_by_degree {
424 if used_nodes.contains(&seed) {
425 continue;
426 }
427
428 let mut subgraph = vec![seed];
430 let mut candidates: HashSet<NodeId> = graph.neighbors(seed).into_iter().collect();
431
432 while subgraph.len() < config.max_group_size && !candidates.is_empty() {
433 let best_candidate = candidates
435 .iter()
436 .map(|&c| {
437 let connections = graph
438 .neighbors(c)
439 .iter()
440 .filter(|n| subgraph.contains(n))
441 .count();
442 (c, connections)
443 })
444 .max_by_key(|(_, conn)| *conn);
445
446 match best_candidate {
447 Some((c, conn)) if conn > 0 => {
448 subgraph.push(c);
449 candidates.remove(&c);
450
451 for neighbor in graph.neighbors(c) {
453 if !subgraph.contains(&neighbor) && !used_nodes.contains(&neighbor) {
454 candidates.insert(neighbor);
455 }
456 }
457 }
458 _ => break,
459 }
460
461 let (_, _, cohesion) = calculate_group_metrics(graph, &subgraph);
463 if cohesion < config.min_cohesion * 2.0 {
464 break;
466 }
467 }
468
469 if subgraph.len() >= config.min_group_size {
470 used_nodes.extend(&subgraph);
471
472 let group_type = if config.classify_types {
473 classify_group_type(graph, &subgraph)
474 } else {
475 GroupType::TransactionCluster
476 };
477
478 let (internal, external, cohesion) = calculate_group_metrics(graph, &subgraph);
479 let hub = find_hub_node(graph, &subgraph);
480
481 let group = EntityGroup::new(*next_id, subgraph, group_type)
482 .with_hub(hub)
483 .with_volumes(internal, external)
484 .with_cohesion(cohesion);
485 *next_id += 1;
486 groups.push(group);
487
488 if groups.len() >= config.max_groups {
489 break;
490 }
491 }
492 }
493
494 groups
495}
496
497fn detect_cliques(
499 graph: &Graph,
500 config: &GroupDetectionConfig,
501 next_id: &mut u64,
502) -> Vec<EntityGroup> {
503 let mut groups = Vec::new();
504 let mut seen_cliques: HashSet<Vec<NodeId>> = HashSet::new();
505
506 let mut adjacency: HashMap<NodeId, HashSet<NodeId>> = HashMap::new();
508 for edge in graph.edges.values() {
509 adjacency
510 .entry(edge.source)
511 .or_default()
512 .insert(edge.target);
513 adjacency
514 .entry(edge.target)
515 .or_default()
516 .insert(edge.source);
517 }
518
519 let nodes: Vec<NodeId> = graph.nodes.keys().copied().collect();
521
522 for &a in &nodes {
523 if groups.len() >= config.max_groups {
524 break;
525 }
526 let neighbors_a = match adjacency.get(&a) {
527 Some(n) => n,
528 None => continue,
529 };
530
531 for &b in neighbors_a {
532 if b <= a {
533 continue;
534 }
535
536 let neighbors_b = match adjacency.get(&b) {
537 Some(n) => n,
538 None => continue,
539 };
540
541 for &c in neighbors_a {
542 if c <= b {
543 continue;
544 }
545
546 if neighbors_b.contains(&c) {
547 let mut clique = vec![a, b, c];
548 clique.sort();
549
550 if !seen_cliques.contains(&clique) && clique.len() >= config.min_group_size {
551 seen_cliques.insert(clique.clone());
552
553 let group_type = if config.classify_types {
554 classify_group_type(graph, &clique)
555 } else {
556 GroupType::TransactionCluster
557 };
558
559 let (internal, external, cohesion) =
560 calculate_group_metrics(graph, &clique);
561 let hub = find_hub_node(graph, &clique);
562
563 let group = EntityGroup::new(*next_id, clique, group_type)
564 .with_hub(hub)
565 .with_volumes(internal, external)
566 .with_cohesion(cohesion);
567 *next_id += 1;
568 groups.push(group);
569 }
570 }
571 }
572 }
573 }
574
575 groups
576}
577
578fn classify_group_type(graph: &Graph, members: &[NodeId]) -> GroupType {
580 let member_set: HashSet<NodeId> = members.iter().copied().collect();
581
582 let has_cycles = members.iter().any(|&node| {
584 graph
585 .outgoing_edges(node)
586 .iter()
587 .any(|e| member_set.contains(&e.target))
588 && graph
589 .incoming_edges(node)
590 .iter()
591 .any(|e| member_set.contains(&e.source))
592 });
593
594 let has_ownership = graph.edges.values().any(|e| {
596 member_set.contains(&e.source)
597 && member_set.contains(&e.target)
598 && matches!(e.edge_type, EdgeType::Ownership | EdgeType::Intercompany)
599 });
600
601 let has_approval = graph.edges.values().any(|e| {
603 member_set.contains(&e.source)
604 && member_set.contains(&e.target)
605 && matches!(e.edge_type, EdgeType::Approval | EdgeType::ReportsTo)
606 });
607
608 let anomalous_nodes = members
610 .iter()
611 .filter(|&&n| {
612 graph
613 .get_node(n)
614 .map(|node| node.is_anomaly)
615 .unwrap_or(false)
616 })
617 .count();
618 let anomaly_rate = anomalous_nodes as f64 / members.len() as f64;
619
620 if has_ownership {
622 GroupType::Intercompany
623 } else if has_approval {
624 GroupType::ApprovalChain
625 } else if has_cycles && anomaly_rate > 0.5 {
626 GroupType::MuleNetwork
627 } else if has_cycles {
628 GroupType::VendorRing
629 } else if anomaly_rate > 0.3 {
630 GroupType::MuleNetwork
631 } else {
632 GroupType::TransactionCluster
633 }
634}
635
636fn calculate_group_metrics(graph: &Graph, members: &[NodeId]) -> (f64, f64, f64) {
638 let member_set: HashSet<NodeId> = members.iter().copied().collect();
639
640 let mut internal_volume = 0.0;
641 let mut external_volume = 0.0;
642 let mut internal_edges = 0;
643
644 for &member in members {
645 for edge in graph.outgoing_edges(member) {
646 if member_set.contains(&edge.target) {
647 internal_volume += edge.weight;
648 internal_edges += 1;
649 } else {
650 external_volume += edge.weight;
651 }
652 }
653
654 for edge in graph.incoming_edges(member) {
655 if !member_set.contains(&edge.source) {
656 external_volume += edge.weight;
657 }
658 }
659 }
660
661 let max_possible_edges = members.len() * (members.len() - 1);
663 let cohesion = if max_possible_edges > 0 {
664 internal_edges as f64 / max_possible_edges as f64
665 } else {
666 0.0
667 };
668
669 (internal_volume, external_volume, cohesion)
670}
671
672fn find_hub_node(graph: &Graph, members: &[NodeId]) -> NodeId {
674 let member_set: HashSet<NodeId> = members.iter().copied().collect();
675
676 members
677 .iter()
678 .map(|&n| {
679 let internal_degree = graph
680 .neighbors(n)
681 .iter()
682 .filter(|neighbor| member_set.contains(neighbor))
683 .count();
684 (n, internal_degree)
685 })
686 .max_by_key(|(_, degree)| *degree)
687 .map(|(n, _)| n)
688 .unwrap_or(members[0])
689}
690
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{GraphEdge, GraphNode, GraphType, NodeType};

    /// Builds an account node whose two string fields are both `label`.
    fn account(label: &str) -> GraphNode {
        GraphNode::new(0, NodeType::Account, label.to_string(), label.to_string())
    }

    /// Builds a graph with a directed triangle A -> B -> C -> A and a
    /// separate chain D -> E -> F.
    fn create_test_graph() -> Graph {
        let mut graph = Graph::new("test", GraphType::Transaction);

        let n1 = graph.add_node(account("A"));
        let n2 = graph.add_node(account("B"));
        let n3 = graph.add_node(account("C"));

        graph.add_edge(GraphEdge::new(0, n1, n2, EdgeType::Transaction).with_weight(100.0));
        graph.add_edge(GraphEdge::new(0, n2, n3, EdgeType::Transaction).with_weight(100.0));
        graph.add_edge(GraphEdge::new(0, n3, n1, EdgeType::Transaction).with_weight(100.0));

        let n4 = graph.add_node(account("D"));
        let n5 = graph.add_node(account("E"));
        let n6 = graph.add_node(account("F"));

        graph.add_edge(GraphEdge::new(0, n4, n5, EdgeType::Transaction).with_weight(200.0));
        graph.add_edge(GraphEdge::new(0, n5, n6, EdgeType::Transaction).with_weight(200.0));

        graph
    }

    #[test]
    fn test_connected_components() {
        let graph = create_test_graph();
        let config = GroupDetectionConfig::default();

        let result = detect_entity_groups(&graph, &config);

        assert!(result.total_groups >= 1);
    }

    #[test]
    fn test_label_propagation() {
        let graph = create_test_graph();
        let config = GroupDetectionConfig {
            algorithms: vec![GroupDetectionAlgorithm::LabelPropagation],
            ..Default::default()
        };

        let result = detect_entity_groups(&graph, &config);

        // Propagation may legitimately find nothing; whatever it reports must
        // be internally consistent and respect the configured size bounds.
        assert_eq!(result.total_groups, result.groups.len());
        for group in &result.groups {
            assert!(group.size() >= config.min_group_size);
            assert!(group.size() <= config.max_group_size);
        }
    }

    #[test]
    fn test_clique_detection() {
        let graph = create_test_graph();
        let config = GroupDetectionConfig {
            algorithms: vec![GroupDetectionAlgorithm::CliqueDetection],
            min_cohesion: 0.1,
            ..Default::default()
        };

        let result = detect_entity_groups(&graph, &config);

        let cliques: Vec<_> = result.groups.iter().filter(|g| g.cohesion > 0.4).collect();
        assert!(!cliques.is_empty());
    }

    #[test]
    fn test_node_features() {
        let graph = create_test_graph();
        let config = GroupDetectionConfig::default();

        let result = detect_entity_groups(&graph, &config);
        let features = result.node_features(1);

        assert_eq!(features.len(), GroupDetectionResult::feature_dim());
    }

    #[test]
    fn test_group_metrics() {
        let graph = create_test_graph();
        let members = vec![1, 2, 3];
        let (internal, _external, cohesion) = calculate_group_metrics(&graph, &members);

        assert!(internal > 0.0);
        assert!(cohesion > 0.0);
    }

    #[test]
    fn test_hub_detection() {
        let mut graph = Graph::new("test", GraphType::Transaction);

        let n1 = graph.add_node(account("Hub"));
        let n2 = graph.add_node(account("A"));
        let n3 = graph.add_node(account("B"));
        let n4 = graph.add_node(account("C"));

        graph.add_edge(GraphEdge::new(0, n1, n2, EdgeType::Transaction));
        graph.add_edge(GraphEdge::new(0, n1, n3, EdgeType::Transaction));
        graph.add_edge(GraphEdge::new(0, n1, n4, EdgeType::Transaction));

        let members = vec![n1, n2, n3, n4];
        let hub = find_hub_node(&graph, &members);

        assert_eq!(hub, n1);
    }
}