phago_runtime/
community.rs1use crate::colony::Colony;
8use phago_core::topology::TopologyGraph;
9use phago_core::types::NodeId;
10use serde::Serialize;
11use std::collections::HashMap;
12
/// A single detected community: a group of graph nodes that ended up sharing
/// the same propagated label.
#[derive(Debug, Clone, Serialize)]
pub struct Community {
    /// Dense community identifier, numbered from 0 within one detection run.
    pub id: usize,
    /// Labels (display text, not `NodeId`s) of the nodes in this community.
    pub members: Vec<String>,
    /// Number of members; equals `members.len()` when built by
    /// `detect_communities`.
    pub size: usize,
}
20
/// Outcome of one community-detection run over a colony's graph.
#[derive(Debug, Clone, Serialize)]
pub struct CommunityResult {
    /// Detected communities, ordered from largest to smallest.
    pub communities: Vec<Community>,
    /// Maps a node's label text to the `id` of the community it belongs to.
    pub assignments: HashMap<String, usize>,
    /// Number of nodes that were present in the graph.
    pub total_nodes: usize,
    /// Number of communities found; equals `communities.len()`.
    pub num_communities: usize,
}
30
31pub fn detect_communities(colony: &Colony, max_iterations: usize) -> CommunityResult {
41 let graph = colony.substrate().graph();
42 let all_nodes = graph.all_nodes();
43
44 if all_nodes.is_empty() {
45 return CommunityResult {
46 communities: Vec::new(),
47 assignments: HashMap::new(),
48 total_nodes: 0,
49 num_communities: 0,
50 };
51 }
52
53 let all_edges = graph.all_edges();
57 let weight_threshold = if all_edges.is_empty() {
58 0.0
59 } else {
60 let mut weights: Vec<f64> = all_edges.iter().map(|(_, _, e)| e.weight).collect();
61 weights.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
62 let n = all_nodes.len() as f64;
63 let density = if n > 1.0 {
64 (2.0 * all_edges.len() as f64) / (n * (n - 1.0))
65 } else {
66 0.0
67 };
68 let percentile = if density > 0.05 { 90 } else { 75 };
71 let idx = (weights.len() * percentile / 100).min(weights.len() - 1);
72 weights[idx]
73 };
74
75 let mut labels: HashMap<NodeId, usize> = HashMap::new();
77 let node_list: Vec<NodeId> = all_nodes.clone();
78 for (i, nid) in node_list.iter().enumerate() {
79 labels.insert(*nid, i);
80 }
81
82 for iter in 0..max_iterations {
84 let mut changed = false;
85
86 let mut order: Vec<usize> = (0..node_list.len()).collect();
88 let mut seed: u64 = (iter as u64).wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
89 for i in (1..order.len()).rev() {
90 seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
91 let j = (seed >> 33) as usize % (i + 1);
92 order.swap(i, j);
93 }
94
95 for &idx in &order {
96 let nid = &node_list[idx];
97 let neighbors = graph.neighbors(nid);
98 if neighbors.is_empty() {
99 continue;
100 }
101
102 let mut label_weights: HashMap<usize, f64> = HashMap::new();
104 for (neighbor_id, edge) in &neighbors {
105 if edge.weight < weight_threshold {
106 continue; }
108 if let Some(&label) = labels.get(neighbor_id) {
109 *label_weights.entry(label).or_insert(0.0) += edge.weight;
110 }
111 }
112
113 if label_weights.is_empty() {
114 continue; }
116
117 if let Some((&best_label, _)) = label_weights.iter()
119 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
120 {
121 let current = labels.get(nid).copied().unwrap_or(0);
122 if best_label != current {
123 labels.insert(*nid, best_label);
124 changed = true;
125 }
126 }
127 }
128
129 if !changed {
130 break;
131 }
132 }
133
134 let mut community_members: HashMap<usize, Vec<String>> = HashMap::new();
136 let mut assignments: HashMap<String, usize> = HashMap::new();
137
138 for nid in &node_list {
139 if let (Some(&label), Some(node)) = (labels.get(nid), graph.get_node(nid)) {
140 community_members.entry(label).or_default().push(node.label.clone());
141 assignments.insert(node.label.clone(), label);
142 }
143 }
144
145 let mut renumber: HashMap<usize, usize> = HashMap::new();
147 let mut next_id = 0;
148 for old_id in community_members.keys() {
149 renumber.entry(*old_id).or_insert_with(|| {
150 let id = next_id;
151 next_id += 1;
152 id
153 });
154 }
155
156 let mut communities: Vec<Community> = community_members.into_iter()
157 .map(|(old_id, members)| {
158 let new_id = renumber[&old_id];
159 Community {
160 id: new_id,
161 size: members.len(),
162 members,
163 }
164 })
165 .collect();
166 communities.sort_by(|a, b| b.size.cmp(&a.size));
167
168 for val in assignments.values_mut() {
170 *val = renumber[val];
171 }
172
173 CommunityResult {
174 num_communities: communities.len(),
175 total_nodes: node_list.len(),
176 communities,
177 assignments,
178 }
179}
180
/// Computes the Normalized Mutual Information between detected community
/// assignments and ground-truth categories.
///
/// Only nodes present in both maps are considered. The score is
/// `2 * I(D; G) / (H(D) + H(G))`, clamped to `[0.0, 1.0]`, where `D` is the
/// detected partition and `G` the ground-truth partition of the shared nodes.
///
/// # Arguments
///
/// * `assignments` - node name to detected community id.
/// * `ground_truth` - node name to ground-truth category name.
///
/// # Returns
///
/// * `0.0` when the two maps share no node.
/// * `1.0` when both partitions place every shared node in a single group:
///   the partitions are identical even though both entropies are zero.
/// * Otherwise the normalized mutual information in `[0.0, 1.0]`.
pub fn compute_nmi(
    assignments: &HashMap<String, usize>,
    ground_truth: &HashMap<String, String>,
) -> f64 {
    // Dense integer ids for the category names seen on shared nodes.
    let mut category_ids: HashMap<&str, usize> = HashMap::new();
    let mut detected_counts: HashMap<usize, f64> = HashMap::new();
    let mut gt_counts: HashMap<usize, f64> = HashMap::new();
    let mut joint_counts: HashMap<(usize, usize), f64> = HashMap::new();
    let mut common = 0usize;

    for (node, &d) in assignments {
        if let Some(category) = ground_truth.get(node) {
            let next_id = category_ids.len();
            let g = *category_ids.entry(category.as_str()).or_insert(next_id);
            *detected_counts.entry(d).or_insert(0.0) += 1.0;
            *gt_counts.entry(g).or_insert(0.0) += 1.0;
            *joint_counts.entry((d, g)).or_insert(0.0) += 1.0;
            common += 1;
        }
    }

    if common == 0 {
        return 0.0;
    }

    // Two single-group partitions are identical; without this guard the
    // zero-entropy denominator would report them as a complete mismatch.
    if detected_counts.len() == 1 && gt_counts.len() == 1 {
        return 1.0;
    }

    let n = common as f64;

    // Mutual information: sum over joint cells of p(d,g) * ln(p(d,g) / (p(d) p(g))).
    let mut mi = 0.0;
    for (&(d, g), &nij) in &joint_counts {
        if nij > 0.0 {
            let ni = detected_counts[&d];
            let nj = gt_counts[&g];
            mi += (nij / n) * ((n * nij) / (ni * nj)).ln();
        }
    }

    // Shannon entropy (natural log) of a partition given its group sizes.
    let entropy = |counts: &HashMap<usize, f64>| -> f64 {
        counts
            .values()
            .map(|&c| if c > 0.0 { -(c / n) * (c / n).ln() } else { 0.0 })
            .sum()
    };
    let h_detected = entropy(&detected_counts);
    let h_gt = entropy(&gt_counts);

    let denominator = h_detected + h_gt;
    if denominator < 1e-10 {
        // Defensive guard against division by (near) zero.
        0.0
    } else {
        (2.0 * mi / denominator).clamp(0.0, 1.0)
    }
}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Two clusters of five nodes that line up exactly with two categories
    /// must score (approximately) 1.0.
    #[test]
    fn nmi_perfect_match() {
        let mut detected: HashMap<String, usize> = HashMap::new();
        let mut gt: HashMap<String, String> = HashMap::new();
        for i in 0..10 {
            let name = format!("node_{}", i);
            let cluster = i / 5;
            let category = format!("cat_{}", cluster);
            detected.insert(name.clone(), cluster);
            gt.insert(name, category);
        }
        let nmi = compute_nmi(&detected, &gt);
        assert!(nmi > 0.99, "NMI should be ~1.0 for perfect match: {}", nmi);
    }

    /// With no node present in both maps there is nothing to compare.
    #[test]
    fn nmi_no_common_nodes_is_zero() {
        let mut detected: HashMap<String, usize> = HashMap::new();
        let mut gt: HashMap<String, String> = HashMap::new();
        detected.insert("only_detected".to_string(), 0);
        gt.insert("only_truth".to_string(), "cat".to_string());
        let nmi = compute_nmi(&detected, &gt);
        assert_eq!(nmi, 0.0, "NMI should be 0.0 without shared nodes: {}", nmi);
    }

    /// Partitions that carry no information about each other score ~0.0.
    #[test]
    fn nmi_independent_partitions_is_zero() {
        let mut detected: HashMap<String, usize> = HashMap::new();
        let mut gt: HashMap<String, String> = HashMap::new();
        for i in 0..4usize {
            let name = format!("node_{}", i);
            detected.insert(name.clone(), i / 2);
            gt.insert(name, format!("cat_{}", i % 2));
        }
        let nmi = compute_nmi(&detected, &gt);
        assert!(nmi < 1e-9, "NMI should be ~0.0 for independent partitions: {}", nmi);
    }
}