Skip to main content

phago_runtime/
community.rs

1//! Community detection via label propagation.
2//!
3//! Detects communities in the knowledge graph using a simple
4//! label propagation algorithm. Used to evaluate whether the
5//! self-organized Hebbian graph recovers ground-truth topic clusters.
6
7use crate::colony::Colony;
8use phago_core::topology::TopologyGraph;
9use phago_core::types::NodeId;
10use serde::Serialize;
11use std::collections::HashMap;
12
/// A detected community in the knowledge graph.
///
/// Produced by [`detect_communities`]; each community groups graph nodes by
/// their node label.
#[derive(Debug, Clone, Serialize)]
pub struct Community {
    /// Community identifier, renumbered densely from 0. The same value is
    /// stored for each member in [`CommunityResult::assignments`].
    pub id: usize,
    /// Labels of the graph nodes that belong to this community.
    pub members: Vec<String>,
    /// Number of members; `detect_communities` sets this to `members.len()`.
    pub size: usize,
}
20
/// Result of community detection.
#[derive(Debug, Clone, Serialize)]
pub struct CommunityResult {
    /// Detected communities, sorted by size (largest first).
    pub communities: Vec<Community>,
    /// Node label → community ID mapping.
    ///
    /// Keyed by node label, so nodes that share a label collapse into a
    /// single entry.
    pub assignments: HashMap<String, usize>,
    /// Number of nodes the graph reported (length of its `all_nodes()` list).
    pub total_nodes: usize,
    /// Number of detected communities (equals `communities.len()`).
    pub num_communities: usize,
}
30
31/// Run label propagation community detection.
32///
33/// Each node starts with its own label. In each iteration, each node
34/// adopts the label most common among its neighbors (weighted by edge weight).
35/// Converges when no labels change.
36///
37/// Uses edge weight thresholding: only edges above the median weight are
38/// considered during neighbor voting. This prunes weak cross-topic edges
39/// and preserves within-topic clusters, improving NMI.
40pub fn detect_communities(colony: &Colony, max_iterations: usize) -> CommunityResult {
41    let graph = colony.substrate().graph();
42    let all_nodes = graph.all_nodes();
43
44    if all_nodes.is_empty() {
45        return CommunityResult {
46            communities: Vec::new(),
47            assignments: HashMap::new(),
48            total_nodes: 0,
49            num_communities: 0,
50        };
51    }
52
53    // Compute edge weight threshold adaptively based on graph density.
54    // Dense graphs need aggressive pruning (90th percentile) to reveal
55    // community structure; sparse graphs use median.
56    let all_edges = graph.all_edges();
57    let weight_threshold = if all_edges.is_empty() {
58        0.0
59    } else {
60        let mut weights: Vec<f64> = all_edges.iter().map(|(_, _, e)| e.weight).collect();
61        weights.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
62        let n = all_nodes.len() as f64;
63        let density = if n > 1.0 {
64            (2.0 * all_edges.len() as f64) / (n * (n - 1.0))
65        } else {
66            0.0
67        };
68        // For dense graphs (density > 0.05), use 90th percentile to aggressively
69        // prune weak edges. Otherwise use 75th percentile.
70        let percentile = if density > 0.05 { 90 } else { 75 };
71        let idx = (weights.len() * percentile / 100).min(weights.len() - 1);
72        weights[idx]
73    };
74
75    // Initialize: each node gets its own label
76    let mut labels: HashMap<NodeId, usize> = HashMap::new();
77    let node_list: Vec<NodeId> = all_nodes.clone();
78    for (i, nid) in node_list.iter().enumerate() {
79        labels.insert(*nid, i);
80    }
81
82    // Iterate with shuffled node order per iteration (asynchronous LP)
83    for iter in 0..max_iterations {
84        let mut changed = false;
85
86        // Shuffle node processing order using Fisher-Yates with deterministic seed
87        let mut order: Vec<usize> = (0..node_list.len()).collect();
88        let mut seed: u64 = (iter as u64).wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
89        for i in (1..order.len()).rev() {
90            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
91            let j = (seed >> 33) as usize % (i + 1);
92            order.swap(i, j);
93        }
94
95        for &idx in &order {
96            let nid = &node_list[idx];
97            let neighbors = graph.neighbors(nid);
98            if neighbors.is_empty() {
99                continue;
100            }
101
102            // Weight-vote for neighbor labels, only considering edges above the median weight
103            let mut label_weights: HashMap<usize, f64> = HashMap::new();
104            for (neighbor_id, edge) in &neighbors {
105                if edge.weight < weight_threshold {
106                    continue; // Skip weak cross-topic edges
107                }
108                if let Some(&label) = labels.get(neighbor_id) {
109                    *label_weights.entry(label).or_insert(0.0) += edge.weight;
110                }
111            }
112
113            if label_weights.is_empty() {
114                continue; // No strong neighbors, keep current label
115            }
116
117            // Adopt the highest-weighted label
118            if let Some((&best_label, _)) = label_weights.iter()
119                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
120            {
121                let current = labels.get(nid).copied().unwrap_or(0);
122                if best_label != current {
123                    labels.insert(*nid, best_label);
124                    changed = true;
125                }
126            }
127        }
128
129        if !changed {
130            break;
131        }
132    }
133
134    // Build communities from final labels
135    let mut community_members: HashMap<usize, Vec<String>> = HashMap::new();
136    let mut assignments: HashMap<String, usize> = HashMap::new();
137
138    for nid in &node_list {
139        if let (Some(&label), Some(node)) = (labels.get(nid), graph.get_node(nid)) {
140            community_members.entry(label).or_default().push(node.label.clone());
141            assignments.insert(node.label.clone(), label);
142        }
143    }
144
145    // Renumber communities 0, 1, 2, ...
146    let mut renumber: HashMap<usize, usize> = HashMap::new();
147    let mut next_id = 0;
148    for old_id in community_members.keys() {
149        renumber.entry(*old_id).or_insert_with(|| {
150            let id = next_id;
151            next_id += 1;
152            id
153        });
154    }
155
156    let mut communities: Vec<Community> = community_members.into_iter()
157        .map(|(old_id, members)| {
158            let new_id = renumber[&old_id];
159            Community {
160                id: new_id,
161                size: members.len(),
162                members,
163            }
164        })
165        .collect();
166    communities.sort_by(|a, b| b.size.cmp(&a.size));
167
168    // Update assignments with new IDs
169    for val in assignments.values_mut() {
170        *val = renumber[val];
171    }
172
173    CommunityResult {
174        num_communities: communities.len(),
175        total_nodes: node_list.len(),
176        communities,
177        assignments,
178    }
179}
180
/// Compute Normalized Mutual Information (NMI) between detected and ground-truth labels.
///
/// Only nodes present in both `assignments` and `ground_truth` are scored.
/// Uses the arithmetic-mean normalisation `NMI = 2 * I(D; G) / (H(D) + H(G))`.
///
/// NMI ranges from 0 (no correlation) to 1 (perfect match). Returns 0.0 when
/// the two maps share no nodes. When both partitions of the shared nodes
/// consist of a single cluster they are identical, and the result is 1.0. This
/// is the same convention as scikit-learn.
pub fn compute_nmi(
    assignments: &HashMap<String, usize>,
    ground_truth: &HashMap<String, String>,
) -> f64 {
    // Ground-truth category names interned as dense numeric ids.
    let mut category_ids: HashMap<&str, usize> = HashMap::new();
    let mut detected_counts: HashMap<usize, f64> = HashMap::new();
    let mut gt_counts: HashMap<usize, f64> = HashMap::new();
    let mut joint_counts: HashMap<(usize, usize), f64> = HashMap::new();
    let mut shared = 0usize;

    for (node, &d) in assignments {
        if let Some(category) = ground_truth.get(node) {
            let next_id = category_ids.len();
            let g = *category_ids.entry(category.as_str()).or_insert(next_id);
            *detected_counts.entry(d).or_insert(0.0) += 1.0;
            *gt_counts.entry(g).or_insert(0.0) += 1.0;
            *joint_counts.entry((d, g)).or_insert(0.0) += 1.0;
            shared += 1;
        }
    }

    if shared == 0 {
        return 0.0;
    }
    let n = shared as f64;

    // I(D; G) = Σ p(d, g) · ln(p(d, g) / (p(d) · p(g))); every count here is ≥ 1.
    let mi: f64 = joint_counts
        .iter()
        .map(|(&(d, g), &nij)| {
            let ni = detected_counts[&d];
            let nj = gt_counts[&g];
            (nij / n) * ((n * nij) / (ni * nj)).ln()
        })
        .sum();

    let entropy = |counts: &HashMap<usize, f64>| -> f64 {
        counts
            .values()
            .map(|&c| {
                let p = c / n;
                -p * p.ln()
            })
            .sum()
    };
    let h_detected = entropy(&detected_counts);
    let h_gt = entropy(&gt_counts);

    // NMI = 2 * MI / (H_detected + H_gt)
    let denominator = h_detected + h_gt;
    if denominator < 1e-10 {
        // Both entropies are zero: all shared nodes are in one detected
        // community and in one category, so the partitions agree exactly.
        1.0
    } else {
        // Clamp absorbs tiny floating-point excursions outside [0, 1].
        (2.0 * mi / denominator).clamp(0.0, 1.0)
    }
}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a detected-assignment map from `(node, community)` pairs.
    fn detected(pairs: &[(&str, usize)]) -> HashMap<String, usize> {
        pairs.iter().map(|&(k, v)| (k.to_string(), v)).collect()
    }

    /// Builds a ground-truth map from `(node, category)` pairs.
    fn truth(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn nmi_perfect_match() {
        let mut detected: HashMap<String, usize> = HashMap::new();
        let mut gt: HashMap<String, String> = HashMap::new();
        for i in 0..10 {
            let name = format!("node_{}", i);
            let cluster = i / 5;
            let category = format!("cat_{}", cluster);
            detected.insert(name.clone(), cluster);
            gt.insert(name, category);
        }
        let nmi = compute_nmi(&detected, &gt);
        assert!(nmi > 0.99, "NMI should be ~1.0 for perfect match: {}", nmi);
    }

    #[test]
    fn nmi_is_invariant_to_label_names() {
        // Same partition, arbitrary numeric ids and category names.
        let d = detected(&[("a", 7), ("b", 7), ("c", 3), ("d", 3)]);
        let g = truth(&[("a", "x"), ("b", "x"), ("c", "y"), ("d", "y")]);
        let nmi = compute_nmi(&d, &g);
        assert!((nmi - 1.0).abs() < 1e-9, "expected ~1.0, got {}", nmi);
    }

    #[test]
    fn nmi_independent_partitions_is_zero() {
        // Each detected community holds exactly one node of each category.
        let d = detected(&[("a", 0), ("b", 0), ("c", 1), ("d", 1)]);
        let g = truth(&[("a", "x"), ("b", "y"), ("c", "x"), ("d", "y")]);
        let nmi = compute_nmi(&d, &g);
        assert!(nmi.abs() < 1e-12, "expected ~0.0, got {}", nmi);
    }

    #[test]
    fn nmi_no_shared_nodes_is_zero() {
        let d = detected(&[("a", 0)]);
        let g = truth(&[("b", "x")]);
        assert_eq!(compute_nmi(&d, &g), 0.0);
    }

    #[test]
    fn nmi_ignores_unshared_nodes() {
        let d = detected(&[("a", 0), ("b", 1), ("extra", 0)]);
        let g = truth(&[("a", "x"), ("b", "y"), ("other", "z")]);
        let nmi = compute_nmi(&d, &g);
        assert!((nmi - 1.0).abs() < 1e-9, "expected ~1.0, got {}", nmi);
    }
}