1use std::collections::{HashMap, HashSet};
2
3use crate::graph::{Edge, GraphDb, Node};
4
5pub fn detect_communities(nodes: &[Node], edges: &[Edge]) -> HashMap<String, i64> {
10 if nodes.is_empty() {
11 return HashMap::new();
12 }
13
14 let mut communities: HashMap<String, i64> = HashMap::new();
15 let node_ids: Vec<&str> = nodes.iter().map(|n| n.id.as_str()).collect();
16 let node_idx: HashMap<&str, usize> = node_ids
17 .iter()
18 .enumerate()
19 .map(|(i, id)| (*id, i))
20 .collect();
21 let n = node_ids.len();
22
23 let mut part: Vec<usize> = (0..n).collect();
24
25 let mut adj: Vec<HashMap<usize, f64>> = vec![HashMap::new(); n];
26 let mut k: Vec<f64> = vec![0.0; n];
27 let mut total_weight: f64 = 0.0;
28
29 for edge in edges {
30 if let (Some(&src_idx), Some(&dst_idx)) = (
31 node_idx.get(edge.src.as_str()),
32 node_idx.get(edge.dst.as_str()),
33 ) {
34 if src_idx == dst_idx {
35 continue;
36 }
37 let w = edge.weight.max(0.01) * edge.confidence.max(0.01);
38 *adj[src_idx].entry(dst_idx).or_insert(0.0) += w;
39 *adj[dst_idx].entry(src_idx).or_insert(0.0) += w;
40 k[src_idx] += w;
41 k[dst_idx] += w;
42 total_weight += w;
43 }
44 }
45
46 if total_weight == 0.0 {
47 for (i, id) in node_ids.iter().enumerate() {
48 communities.insert(id.to_string(), i as i64 + 1);
49 }
50 return communities;
51 }
52
53 let m2 = total_weight;
54
55 let mut improved = true;
56 let mut iteration = 0;
57 let max_iterations = 50;
58 while improved && iteration < max_iterations {
59 improved = false;
60 iteration += 1;
61
62 let mut order: Vec<usize> = (0..n).collect();
63 fastrand::shuffle(&mut order);
64
65 for &node in &order {
66 let current_comm = part[node];
67 let mut best_comm = current_comm;
68 let mut best_delta: f64 = 0.0;
69
70 let ki = k[node];
71 let ki_in_current = community_edge_weight(node, current_comm, &adj, &part);
72 let sigma_tot_current = total_community_weight(current_comm, &part, &k) - ki;
73
74 let mut neighbor_comms: HashSet<usize> = HashSet::new();
75 for &neighbor in adj[node].keys() {
76 neighbor_comms.insert(part[neighbor]);
77 }
78
79 for &neighbor_comm in &neighbor_comms {
80 if neighbor_comm == current_comm {
81 continue;
82 }
83
84 let ki_in_neighbor = community_edge_weight(node, neighbor_comm, &adj, &part);
85 let sigma_tot_neighbor = total_community_weight(neighbor_comm, &part, &k);
86
87 let delta = modularity_delta(ki_in_neighbor, sigma_tot_neighbor, ki, m2)
88 - modularity_delta(ki_in_current, sigma_tot_current, ki, m2);
89
90 if delta > best_delta {
91 best_delta = delta;
92 best_comm = neighbor_comm;
93 }
94 }
95
96 if best_comm != current_comm {
97 part[node] = best_comm;
98 improved = true;
99 }
100 }
101 }
102
103 let mut comm_map: HashMap<usize, i64> = HashMap::new();
104 let mut next_comm_id: i64 = 1;
105 for &c in &part {
106 comm_map.entry(c).or_insert_with(|| {
107 let id = next_comm_id;
108 next_comm_id += 1;
109 id
110 });
111 }
112
113 for (i, id) in node_ids.iter().enumerate() {
114 let comm = comm_map[&part[i]];
115 communities.insert(id.to_string(), comm);
116 }
117
118 communities
119}
120
/// Sums the weights of `node`'s adjacency entries whose neighbour is currently
/// assigned to `community`.
///
/// `adj[node]` maps neighbour index to edge weight and `part[i]` is the
/// community of node `i`. Returns `0.0` when no neighbour is in `community`.
///
/// # Panics
///
/// Panics if `node` is out of bounds for `adj`, or a neighbour index is out of
/// bounds for `part`.
fn community_edge_weight(
    node: usize,
    community: usize,
    adj: &[HashMap<usize, f64>],
    part: &[usize],
) -> f64 {
    adj[node]
        .iter()
        .filter(|&(&neighbor, _)| part[neighbor] == community)
        // Explicit 0.0 seed keeps the accumulation identical to a manual loop.
        .fold(0.0, |acc, (_, &w)| acc + w)
}
135
/// Adds up `k[i]` over every index `i` whose entry in `part` equals
/// `community`.
///
/// Scans all of `part`, so the cost is linear in the number of nodes. Returns
/// `0.0` when no node belongs to `community`.
///
/// # Panics
///
/// Panics if a matching index is out of bounds for `k`.
fn total_community_weight(community: usize, part: &[usize], k: &[f64]) -> f64 {
    part.iter()
        .enumerate()
        .filter(|&(_, &comm)| comm == community)
        .fold(0.0, |total, (i, _)| total + k[i])
}
145
/// Evaluates `ki_in / m2 - sigma_tot * ki / (2 * m2^2)`, the gain term used to
/// compare placing a node in one community against another.
///
/// * `ki_in` - weight between the node and the community.
/// * `sigma_tot` - total weighted degree of the community.
/// * `ki` - weighted degree of the node.
/// * `m2` - normalising weight; `detect_communities` passes the sum of all edge
///   weights here.
///
/// Returns `0.0` when `m2` is exactly zero, to avoid dividing by zero.
fn modularity_delta(ki_in: f64, sigma_tot: f64, ki: f64, m2: f64) -> f64 {
    if m2 != 0.0 {
        let attraction = ki_in / m2;
        let expected = sigma_tot * ki / (2.0 * m2 * m2);
        attraction - expected
    } else {
        0.0
    }
}
152
153pub fn run_clustering(db: &GraphDb) -> anyhow::Result<usize> {
159 db.clear_communities()?;
160
161 let nodes = db.get_all_nodes()?;
162 let edges = db.get_all_edges()?;
163
164 if nodes.is_empty() {
165 return Ok(0);
166 }
167
168 let community_map = detect_communities(&nodes, &edges);
169
170 let community_count = {
171 let mut unique: HashSet<i64> = HashSet::new();
172 for &c in community_map.values() {
173 unique.insert(c);
174 }
175 unique.len()
176 };
177
178 db.update_node_communities(&community_map)?;
179
180 Ok(community_count)
181}