Skip to main content

oxirs_graphrag/
community_detector.rs

1//! Graph community detection using a greedy label-propagation approach.
2//!
3//! This module implements a Louvain-inspired greedy community detection
4//! algorithm. Each node iteratively adopts the community label of its
5//! most-connected neighbour until the assignment stabilises or the maximum
6//! number of iterations is reached. Small communities below `min_community_size`
7//! are merged into their largest neighbour community.
8//!
9//! # Example
10//!
11//! ```rust
12//! use oxirs_graphrag::community_detector::{CommunityGraph, CommunityDetector};
13//!
14//! let mut graph = CommunityGraph::new();
15//! graph.add_node(1, "Alice");
16//! graph.add_node(2, "Bob");
17//! graph.add_node(3, "Carol");
18//! graph.add_edge(1, 2, 1.0);
19//! graph.add_edge(2, 3, 1.0);
20//!
21//! let detector = CommunityDetector::new(1, 100);
22//! let result = detector.detect(&mut graph);
23//! assert!(!result.communities.is_empty());
24//! assert!(result.modularity >= -1.0);
25//! ```
26
27use std::collections::HashMap;
28
/// A single vertex of the community graph.
#[derive(Debug, Clone)]
pub struct GraphNode {
    /// Identifier of the node; used as the key in the graph's node map
    pub id: u64,
    /// Display label attached to the node
    pub label: String,
    /// Community the node has been assigned to (`None` until a detection run
    /// writes an assignment)
    pub community: Option<u32>,
}

impl GraphNode {
    /// Build a node with the given `id` and `label` that is not yet assigned
    /// to any community.
    pub fn new(id: u64, label: impl Into<String>) -> Self {
        let label = label.into();
        Self {
            id,
            label,
            community: None,
        }
    }
}
50
/// A weighted connection between two node IDs.
#[derive(Debug, Clone)]
pub struct GraphEdge {
    /// ID of the first endpoint
    pub from: u64,
    /// ID of the second endpoint
    pub to: u64,
    /// Weight carried by the edge
    pub weight: f64,
}

impl GraphEdge {
    /// Build an edge joining `from` and `to` with the given `weight`.
    pub fn new(from: u64, to: u64, weight: f64) -> Self {
        GraphEdge { from, to, weight }
    }
}
65
/// One group of nodes produced by community detection.
#[derive(Debug, Clone)]
pub struct Community {
    /// Identifier of this community
    pub id: u32,
    /// IDs of the nodes assigned to this community
    pub members: Vec<u64>,
    /// Number of edges with both endpoints inside the community
    pub internal_edges: u64,
    /// Number of edges with at least one endpoint inside the community
    pub total_edges: u64,
}

impl Community {
    /// Number of member nodes in this community.
    pub fn size(&self) -> usize {
        let Self { members, .. } = self;
        members.len()
    }
}
85
86/// Undirected weighted graph with adjacency lists
87#[derive(Debug, Clone, Default)]
88pub struct CommunityGraph {
89    /// Node collection indexed by ID
90    pub nodes: HashMap<u64, GraphNode>,
91    /// All edges (stored once; adjacency is bidirectional)
92    pub edges: Vec<GraphEdge>,
93    /// Adjacency list: node_id → [(neighbour_id, weight)]
94    pub adjacency: HashMap<u64, Vec<(u64, f64)>>,
95}
96
97impl CommunityGraph {
98    /// Create an empty graph
99    pub fn new() -> Self {
100        Self {
101            nodes: HashMap::new(),
102            edges: Vec::new(),
103            adjacency: HashMap::new(),
104        }
105    }
106
107    /// Add a node; overwrites if the ID already exists
108    pub fn add_node(&mut self, id: u64, label: &str) {
109        self.nodes.insert(id, GraphNode::new(id, label));
110        self.adjacency.entry(id).or_default();
111    }
112
113    /// Add an undirected edge between `from` and `to` with the given weight.
114    /// If either node does not exist it is created with an empty label.
115    pub fn add_edge(&mut self, from: u64, to: u64, weight: f64) {
116        self.adjacency.entry(from).or_default();
117        self.adjacency.entry(to).or_default();
118
119        self.adjacency
120            .get_mut(&from)
121            .expect("from adjacency entry just inserted")
122            .push((to, weight));
123        self.adjacency
124            .get_mut(&to)
125            .expect("to adjacency entry just inserted")
126            .push((from, weight));
127
128        self.edges.push(GraphEdge::new(from, to, weight));
129    }
130
131    /// Sum of edge weights incident to a node (weighted degree)
132    pub fn degree(&self, node_id: u64) -> f64 {
133        self.adjacency
134            .get(&node_id)
135            .map(|neighbours| neighbours.iter().map(|(_, w)| w).sum())
136            .unwrap_or(0.0)
137    }
138
139    /// Sum of all edge weights in the graph (counted once per edge)
140    pub fn total_weight(&self) -> f64 {
141        self.edges.iter().map(|e| e.weight).sum()
142    }
143}
144
145/// Outcome of a community detection run
146#[derive(Debug, Clone)]
147pub struct DetectionResult {
148    /// Final communities (after merging small ones)
149    pub communities: Vec<Community>,
150    /// Newman–Girvan modularity Q ∈ [-1, 1]
151    pub modularity: f64,
152    /// Number of label-propagation iterations performed
153    pub iterations: u32,
154}
155
/// Configuration and entry point for community detection.
#[derive(Debug, Clone)]
pub struct CommunityDetector {
    /// Size threshold: communities with fewer members are merged into a
    /// neighbouring community
    pub min_community_size: usize,
    /// Upper bound on the number of label-propagation iterations
    pub max_iterations: u32,
    /// Resolution parameter γ of the modularity formula
    /// (1.0 gives standard modularity)
    pub resolution: f64,
}
166
167impl CommunityDetector {
168    /// Create a detector with the given parameters.
169    /// `resolution` defaults to 1.0 (standard Newman–Girvan modularity).
170    pub fn new(min_community_size: usize, max_iterations: u32) -> Self {
171        Self {
172            min_community_size,
173            max_iterations,
174            resolution: 1.0,
175        }
176    }
177
178    /// Create a detector with a custom resolution parameter
179    pub fn with_resolution(mut self, resolution: f64) -> Self {
180        self.resolution = resolution;
181        self
182    }
183
184    /// Run community detection on `graph`.
185    ///
186    /// Each node's `community` field is updated in place.
187    /// Returns a [`DetectionResult`] with the final partition and quality metrics.
188    pub fn detect(&self, graph: &mut CommunityGraph) -> DetectionResult {
189        // Initialise: each node belongs to its own community
190        let node_ids: Vec<u64> = graph.nodes.keys().copied().collect();
191
192        if node_ids.is_empty() {
193            return DetectionResult {
194                communities: vec![],
195                modularity: 0.0,
196                iterations: 0,
197            };
198        }
199
200        // Assign initial community IDs (community_id == node index for simplicity)
201        let mut community_map: HashMap<u64, u32> = node_ids
202            .iter()
203            .enumerate()
204            .map(|(i, &id)| (id, i as u32))
205            .collect();
206
207        let mut iterations = 0u32;
208        let mut changed = true;
209
210        while changed && iterations < self.max_iterations {
211            changed = false;
212            iterations += 1;
213
214            // Process nodes in a deterministic order
215            let mut sorted_ids = node_ids.clone();
216            sorted_ids.sort_unstable();
217
218            for &node_id in &sorted_ids {
219                let best_community = self.best_neighbour_community(graph, node_id, &community_map);
220                let current = *community_map.get(&node_id).unwrap_or(&0);
221                if best_community != current {
222                    community_map.insert(node_id, best_community);
223                    changed = true;
224                }
225            }
226        }
227
228        // Write community assignments back to nodes
229        for (&node_id, &comm) in &community_map {
230            if let Some(node) = graph.nodes.get_mut(&node_id) {
231                node.community = Some(comm);
232            }
233        }
234
235        // Build Community structs
236        let mut community_members: HashMap<u32, Vec<u64>> = HashMap::new();
237        for (&node_id, &comm) in &community_map {
238            community_members.entry(comm).or_default().push(node_id);
239        }
240
241        let mut communities: Vec<Community> = community_members
242            .into_iter()
243            .map(|(comm_id, members)| {
244                let member_set: std::collections::HashSet<u64> = members.iter().copied().collect();
245                let internal = graph
246                    .edges
247                    .iter()
248                    .filter(|e| member_set.contains(&e.from) && member_set.contains(&e.to))
249                    .count() as u64;
250                let total = graph
251                    .edges
252                    .iter()
253                    .filter(|e| member_set.contains(&e.from) || member_set.contains(&e.to))
254                    .count() as u64;
255                Community {
256                    id: comm_id,
257                    members,
258                    internal_edges: internal,
259                    total_edges: total,
260                }
261            })
262            .collect();
263
264        // Sort for deterministic output
265        communities.sort_by_key(|c| c.id);
266
267        let mut result = DetectionResult {
268            communities,
269            modularity: 0.0,
270            iterations,
271        };
272
273        // Merge communities that are too small
274        self.merge_small_communities(&mut result, graph);
275
276        // Compute modularity after merging
277        result.modularity = self.compute_modularity(graph);
278
279        result
280    }
281
282    /// Find the community with the highest connection weight from `node_id`.
283    /// Falls back to the node's own current community if it has no neighbours.
284    fn best_neighbour_community(
285        &self,
286        graph: &CommunityGraph,
287        node_id: u64,
288        community_map: &HashMap<u64, u32>,
289    ) -> u32 {
290        let current = *community_map.get(&node_id).unwrap_or(&0);
291        let neighbours = match graph.adjacency.get(&node_id) {
292            Some(n) => n,
293            None => return current,
294        };
295
296        if neighbours.is_empty() {
297            return current;
298        }
299
300        // Accumulate total weight per neighbouring community
301        let mut comm_weight: HashMap<u32, f64> = HashMap::new();
302        for &(nb_id, weight) in neighbours {
303            let nb_comm = *community_map.get(&nb_id).unwrap_or(&0);
304            *comm_weight.entry(nb_comm).or_insert(0.0) += weight;
305        }
306
307        // Choose community with maximum weight
308        comm_weight
309            .into_iter()
310            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
311            .map(|(c, _)| c)
312            .unwrap_or(current)
313    }
314
315    /// Compute the Newman–Girvan modularity Q for the current community assignment.
316    ///
317    /// Q = (1/2m) Σ_{ij} [ A_ij − γ k_i k_j / 2m ] δ(c_i, c_j)
318    ///
319    /// where m is the total weight, k_i is the degree of node i, A_ij is the
320    /// edge weight between i and j (0 if no edge), and δ is 1 iff same community.
321    pub fn compute_modularity(&self, graph: &CommunityGraph) -> f64 {
322        let two_m = graph.total_weight() * 2.0;
323        if two_m == 0.0 {
324            return 0.0;
325        }
326
327        // Build community assignment lookup
328        let comm_of: HashMap<u64, u32> = graph
329            .nodes
330            .values()
331            .filter_map(|n| n.community.map(|c| (n.id, c)))
332            .collect();
333
334        // Build edge-weight lookup (sum for multi-edges)
335        let mut edge_weight_map: HashMap<(u64, u64), f64> = HashMap::new();
336        for e in &graph.edges {
337            let key = if e.from <= e.to {
338                (e.from, e.to)
339            } else {
340                (e.to, e.from)
341            };
342            *edge_weight_map.entry(key).or_insert(0.0) += e.weight;
343        }
344
345        let node_ids: Vec<u64> = graph.nodes.keys().copied().collect();
346        let mut q = 0.0_f64;
347
348        for &i in &node_ids {
349            let ci = match comm_of.get(&i) {
350                Some(&c) => c,
351                None => continue,
352            };
353            let ki = graph.degree(i);
354            for &j in &node_ids {
355                let cj = match comm_of.get(&j) {
356                    Some(&c) => c,
357                    None => continue,
358                };
359                if ci != cj {
360                    continue;
361                }
362                let key = if i <= j { (i, j) } else { (j, i) };
363                let a_ij = edge_weight_map.get(&key).copied().unwrap_or(0.0);
364                let kj = graph.degree(j);
365                q += a_ij - self.resolution * ki * kj / two_m;
366            }
367        }
368
369        q / two_m
370    }
371
372    /// Compute the modularity gain of moving `node_id` into `target_community`.
373    ///
374    /// Returns a positive value if the move improves modularity.
375    pub fn modularity_gain(
376        &self,
377        graph: &CommunityGraph,
378        node_id: u64,
379        target_community: u32,
380    ) -> f64 {
381        let two_m = graph.total_weight() * 2.0;
382        if two_m == 0.0 {
383            return 0.0;
384        }
385
386        let ki = graph.degree(node_id);
387
388        // Weight of edges from node_id to target community members
389        let k_in: f64 = graph
390            .adjacency
391            .get(&node_id)
392            .map(|neighbours| {
393                neighbours
394                    .iter()
395                    .filter_map(|(nb_id, w)| {
396                        graph.nodes.get(nb_id).and_then(|n| {
397                            if n.community == Some(target_community) {
398                                Some(*w)
399                            } else {
400                                None
401                            }
402                        })
403                    })
404                    .sum()
405            })
406            .unwrap_or(0.0);
407
408        // Total weight of edges incident to target community
409        let sigma_tot: f64 = graph
410            .nodes
411            .values()
412            .filter(|n| n.community == Some(target_community))
413            .map(|n| graph.degree(n.id))
414            .sum();
415
416        // ΔQ = [ k_in/m - γ * sigma_tot * ki / (2m²) ]
417        2.0 * k_in / two_m - self.resolution * sigma_tot * ki / (two_m * two_m)
418    }
419
420    /// Merge communities whose member count is below `min_community_size`.
421    ///
422    /// Each small community's members are reassigned to the community that
423    /// shares the most edge weight with them.  The reassignment is reflected
424    /// in `graph.nodes[*].community` and in `result.communities`.
425    pub fn merge_small_communities(
426        &self,
427        result: &mut DetectionResult,
428        graph: &mut CommunityGraph,
429    ) {
430        if self.min_community_size <= 1 {
431            return;
432        }
433
434        let small_ids: Vec<u32> = result
435            .communities
436            .iter()
437            .filter(|c| c.members.len() < self.min_community_size)
438            .map(|c| c.id)
439            .collect();
440
441        for small_id in small_ids {
442            // Find the target community with most edge-weight connection
443            let members: Vec<u64> = result
444                .communities
445                .iter()
446                .find(|c| c.id == small_id)
447                .map(|c| c.members.clone())
448                .unwrap_or_default();
449
450            if members.is_empty() {
451                continue;
452            }
453
454            // Accumulate weight to every other community
455            let mut target_weight: HashMap<u32, f64> = HashMap::new();
456            for &m in &members {
457                if let Some(neighbours) = graph.adjacency.get(&m) {
458                    for &(nb_id, weight) in neighbours {
459                        if let Some(node) = graph.nodes.get(&nb_id) {
460                            if let Some(nb_comm) = node.community {
461                                if nb_comm != small_id {
462                                    *target_weight.entry(nb_comm).or_insert(0.0) += weight;
463                                }
464                            }
465                        }
466                    }
467                }
468            }
469
470            let best_target = target_weight
471                .into_iter()
472                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
473                .map(|(c, _)| c);
474
475            let Some(target_id) = best_target else {
476                // No neighbour community found — skip
477                continue;
478            };
479
480            // Reassign community in nodes
481            for &m in &members {
482                if let Some(node) = graph.nodes.get_mut(&m) {
483                    node.community = Some(target_id);
484                }
485            }
486
487            // Move members in the Community structs
488            let moved_members = members.clone();
489            if let Some(target) = result.communities.iter_mut().find(|c| c.id == target_id) {
490                target.members.extend_from_slice(&moved_members);
491            }
492
493            // Remove the small community
494            result.communities.retain(|c| c.id != small_id);
495        }
496    }
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons.
    const EPS: f64 = 1e-10;

    /// Three nodes joined pairwise by unit-weight edges.
    fn triangle_graph() -> CommunityGraph {
        let mut graph = CommunityGraph::new();
        for (id, label) in [(1, "A"), (2, "B"), (3, "C")] {
            graph.add_node(id, label);
        }
        for (a, b) in [(1, 2), (2, 3), (3, 1)] {
            graph.add_edge(a, b, 1.0);
        }
        graph
    }

    /// Two heavy triangles (1-2-3 and 4-5-6) joined by a weak 3-4 bridge.
    fn two_cliques() -> CommunityGraph {
        let mut graph = CommunityGraph::new();
        for id in 1..=6 {
            graph.add_node(id, &format!("n{}", id));
        }
        let weighted_edges = [
            (1, 2, 10.0),
            (2, 3, 10.0),
            (1, 3, 10.0),
            (4, 5, 10.0),
            (5, 6, 10.0),
            (4, 6, 10.0),
            (3, 4, 0.1), // weak bridge
        ];
        for (a, b, w) in weighted_edges {
            graph.add_edge(a, b, w);
        }
        graph
    }

    /// Number of members summed over all communities of a result.
    fn member_total(result: &DetectionResult) -> usize {
        result.communities.iter().map(|c| c.members.len()).sum()
    }

    /// Whether `list` holds an entry for `neighbour` with (approximately) `weight`.
    fn has_neighbour(list: &[(u64, f64)], neighbour: u64, weight: f64) -> bool {
        list.iter()
            .any(|(nb, w)| *nb == neighbour && (*w - weight).abs() < EPS)
    }

    // --- CommunityGraph construction ---

    #[test]
    fn test_add_node_stores_label() {
        let mut graph = CommunityGraph::new();
        graph.add_node(42, "TestNode");
        let stored = graph.nodes.get(&42).expect("node should exist");
        assert_eq!(stored.label, "TestNode");
        assert_eq!(stored.community, None);
    }

    #[test]
    fn test_add_edge_updates_adjacency_both_directions() {
        let mut graph = CommunityGraph::new();
        graph.add_node(1, "A");
        graph.add_node(2, "B");
        graph.add_edge(1, 2, 3.0);

        let adj1 = graph.adjacency.get(&1).expect("adj[1] should exist");
        assert!(has_neighbour(adj1, 2, 3.0));

        let adj2 = graph.adjacency.get(&2).expect("adj[2] should exist");
        assert!(has_neighbour(adj2, 1, 3.0));
    }

    #[test]
    fn test_degree_empty_node() {
        let mut graph = CommunityGraph::new();
        graph.add_node(1, "A");
        assert_eq!(graph.degree(1), 0.0);
    }

    #[test]
    fn test_degree_with_edges() {
        // Node 1 has two incident edges of weight 1.0
        let degree = triangle_graph().degree(1);
        assert!((degree - 2.0).abs() < EPS);
    }

    #[test]
    fn test_total_weight_triangle() {
        let weight = triangle_graph().total_weight();
        assert!((weight - 3.0).abs() < EPS);
    }

    #[test]
    fn test_total_weight_empty() {
        assert_eq!(CommunityGraph::new().total_weight(), 0.0);
    }

    #[test]
    fn test_degree_missing_node_returns_zero() {
        assert_eq!(CommunityGraph::new().degree(9999), 0.0);
    }

    // --- Detection on simple graphs ---

    #[test]
    fn test_detect_empty_graph_returns_empty() {
        let mut graph = CommunityGraph::new();
        let result = CommunityDetector::new(1, 100).detect(&mut graph);
        assert!(result.communities.is_empty());
        assert_eq!(result.iterations, 0);
    }

    #[test]
    fn test_detect_single_node() {
        let mut graph = CommunityGraph::new();
        graph.add_node(1, "solo");
        let result = CommunityDetector::new(1, 100).detect(&mut graph);
        assert_eq!(result.communities.len(), 1);
        assert_eq!(result.communities[0].members.len(), 1);
    }

    #[test]
    fn test_detect_triangle_assigns_communities() {
        let mut graph = triangle_graph();
        let result = CommunityDetector::new(1, 50).detect(&mut graph);
        // Detection must leave no node unassigned
        for node in graph.nodes.values() {
            assert!(
                node.community.is_some(),
                "Node {} has no community",
                node.id
            );
        }
        assert!(!result.communities.is_empty());
    }

    #[test]
    fn test_detect_two_cliques_partition() {
        let mut graph = two_cliques();
        let result = CommunityDetector::new(1, 200).detect(&mut graph);
        // The algorithm may or may not split the cliques, but it must
        // produce at least one community ...
        assert!(!result.communities.is_empty());
        // ... and account for all six nodes.
        assert_eq!(member_total(&result), 6);
    }

    #[test]
    fn test_detect_respects_max_iterations() {
        let mut graph = two_cliques();
        // Deliberately tiny iteration budget
        let result = CommunityDetector::new(1, 3).detect(&mut graph);
        assert!(result.iterations <= 3);
    }

    #[test]
    fn test_detect_covers_all_nodes() {
        let mut graph = two_cliques();
        let result = CommunityDetector::new(1, 100).detect(&mut graph);
        assert_eq!(member_total(&result), graph.nodes.len());
    }

    // --- Modularity ---

    #[test]
    fn test_modularity_empty_graph() {
        let graph = CommunityGraph::new();
        let q = CommunityDetector::new(1, 10).compute_modularity(&graph);
        assert_eq!(q, 0.0);
    }

    #[test]
    fn test_modularity_in_range_after_detection() {
        let mut graph = two_cliques();
        let result = CommunityDetector::new(1, 100).detect(&mut graph);
        // Q has to stay within [-1, 1]
        assert!(
            result.modularity >= -1.0 && result.modularity <= 1.0,
            "Q={} out of range",
            result.modularity
        );
    }

    #[test]
    fn test_modularity_single_community_is_non_positive() {
        let mut graph = triangle_graph();
        // Put every node into community 0
        graph
            .nodes
            .values_mut()
            .for_each(|node| node.community = Some(0));
        let q = CommunityDetector::new(1, 10).compute_modularity(&graph);
        // A single all-encompassing community should not score above zero
        assert!(q <= 0.0 + 1e-10, "Expected Q <= 0, got {}", q);
    }

    // --- Modularity gain ---

    #[test]
    fn test_modularity_gain_returns_finite() {
        let mut graph = triangle_graph();
        let detector = CommunityDetector::new(1, 50);
        detector.detect(&mut graph);
        assert!(detector.modularity_gain(&graph, 1, 0).is_finite());
    }

    #[test]
    fn test_modularity_gain_empty_graph() {
        let graph = CommunityGraph::new();
        let gain = CommunityDetector::new(1, 10).modularity_gain(&graph, 1, 0);
        assert_eq!(gain, 0.0);
    }

    // --- Small community merging ---

    #[test]
    fn test_merge_small_communities_removes_tiny() {
        let mut graph = CommunityGraph::new();
        for (id, label) in [(1, "A"), (2, "B"), (3, "C")] {
            graph.add_node(id, label);
        }
        graph.add_edge(1, 2, 5.0);
        graph.add_edge(2, 3, 1.0);

        // Minimum community size of 2
        let result = CommunityDetector::new(2, 100).detect(&mut graph);

        // Every community has at least 2 members, unless only one is left
        let only_one = result.communities.len() == 1;
        for comm in &result.communities {
            assert!(
                comm.members.len() >= 2 || only_one,
                "Community {} has only {} member(s)",
                comm.id,
                comm.members.len()
            );
        }
    }

    #[test]
    fn test_merge_preserves_total_node_count() {
        let mut graph = two_cliques();
        let result = CommunityDetector::new(3, 100).detect(&mut graph);
        assert_eq!(
            member_total(&result),
            6,
            "All 6 nodes must appear in some community"
        );
    }

    // --- CommunityDetector builder ---

    #[test]
    fn test_with_resolution_sets_field() {
        let detector = CommunityDetector::new(1, 50).with_resolution(0.5);
        assert!((detector.resolution - 0.5).abs() < EPS);
    }

    #[test]
    fn test_default_resolution_is_one() {
        let detector = CommunityDetector::new(1, 50);
        assert!((detector.resolution - 1.0).abs() < EPS);
    }

    // --- Community struct ---

    #[test]
    fn test_community_size() {
        let community = Community {
            id: 0,
            members: vec![1, 2, 3],
            internal_edges: 3,
            total_edges: 4,
        };
        assert_eq!(community.size(), 3);
    }

    // --- GraphNode ---

    #[test]
    fn test_graph_node_initial_community_is_none() {
        assert_eq!(GraphNode::new(7, "test").community, None);
    }

    // --- Disconnected graph ---

    #[test]
    fn test_detect_disconnected_graph() {
        let mut graph = CommunityGraph::new();
        // Two nodes, no edges
        graph.add_node(1, "X");
        graph.add_node(2, "Y");
        let result = CommunityDetector::new(1, 50).detect(&mut graph);
        assert_eq!(member_total(&result), 2);
    }

    #[test]
    fn test_detect_star_graph() {
        let mut graph = CommunityGraph::new();
        graph.add_node(0, "center");
        for leaf in 1..=5 {
            graph.add_node(leaf, &format!("leaf{}", leaf));
            graph.add_edge(0, leaf, 1.0);
        }
        let result = CommunityDetector::new(1, 100).detect(&mut graph);
        assert_eq!(member_total(&result), 6);
    }
}
811
812// ---------------------------------------------------------------------------
813// v1.1.0 round 16: String-keyed community detection API
814// ---------------------------------------------------------------------------
815
/// An edge of the knowledge graph whose endpoints are named by `String`s;
/// it may represent a directed or an undirected relation.
#[derive(Debug, Clone, PartialEq)]
pub struct KgEdge {
    /// Name of the source node.
    pub from: String,
    /// Name of the target node.
    pub to: String,
    /// Weight of the edge (for example a predicate frequency or a confidence).
    pub weight: f64,
}

impl KgEdge {
    /// Build an edge from anything convertible into the endpoint names.
    pub fn new(from: impl Into<String>, to: impl Into<String>, weight: f64) -> Self {
        let (from, to) = (from.into(), to.into());
        Self { from, to, weight }
    }
}
837
/// One community produced by the string-keyed detector.
#[derive(Debug, Clone)]
pub struct KgCommunity {
    /// Identifier of the community.
    pub community_id: usize,
    /// Node names assigned to this community.
    pub members: Vec<String>,
}
846
/// Settings for the string-keyed community detector.
#[derive(Debug, Clone)]
pub struct KgDetectionConfig {
    /// Smallest community size to keep; smaller communities are merged into
    /// the largest neighbouring community.
    pub min_community_size: usize,
    /// Upper bound on the number of returned communities (0 = unlimited).
    pub max_communities: usize,
}

impl Default for KgDetectionConfig {
    /// Defaults: minimum community size 1, no limit on the number of communities.
    fn default() -> Self {
        KgDetectionConfig {
            min_community_size: 1,
            max_communities: 0,
        }
    }
}
865
866/// Greedy community detector operating on String-keyed node sets.
867///
868/// Uses iterative label propagation: each node starts in its own community
869/// and repeatedly adopts the most frequent label among its neighbours.
870/// The process continues for a fixed number of rounds.
871pub struct KgCommunityDetector;
872
873impl Default for KgCommunityDetector {
874    fn default() -> Self {
875        Self::new()
876    }
877}
878
879impl KgCommunityDetector {
880    /// Create a new detector.
881    pub fn new() -> Self {
882        KgCommunityDetector
883    }
884
885    /// Detect communities using label propagation.
886    ///
887    /// * Nodes with no edges form singleton communities (or are merged if
888    ///   `config.min_community_size > 1`).
889    /// * When `config.max_communities > 0` only the largest communities are
890    ///   kept and the remaining nodes are merged into the largest one.
891    pub fn detect(
892        &self,
893        nodes: &[String],
894        edges: &[KgEdge],
895        config: &KgDetectionConfig,
896    ) -> Vec<KgCommunity> {
897        if nodes.is_empty() {
898            return Vec::new();
899        }
900
901        // Build adjacency: node_name → Vec<(neighbour, weight)>
902        let mut adjacency: std::collections::HashMap<&str, Vec<(&str, f64)>> =
903            std::collections::HashMap::new();
904        for n in nodes {
905            adjacency.entry(n.as_str()).or_default();
906        }
907        for e in edges {
908            adjacency
909                .entry(e.from.as_str())
910                .or_default()
911                .push((e.to.as_str(), e.weight));
912            adjacency
913                .entry(e.to.as_str())
914                .or_default()
915                .push((e.from.as_str(), e.weight));
916        }
917
918        // Initialise: each node in its own label group (indexed by position)
919        let mut labels: std::collections::HashMap<&str, usize> = nodes
920            .iter()
921            .enumerate()
922            .map(|(i, n)| (n.as_str(), i))
923            .collect();
924
925        // Propagation rounds (deterministic order for reproducibility)
926        let max_rounds = 20_usize;
927        for _ in 0..max_rounds {
928            let mut changed = false;
929            let node_order: Vec<&str> = nodes.iter().map(String::as_str).collect();
930            for &node in &node_order {
931                let neighbours = adjacency.get(node).cloned().unwrap_or_default();
932                if neighbours.is_empty() {
933                    continue;
934                }
935                // Tally neighbour labels weighted by edge weight
936                let mut tally: std::collections::HashMap<usize, f64> =
937                    std::collections::HashMap::new();
938                for (nb, w) in &neighbours {
939                    if let Some(&lbl) = labels.get(nb) {
940                        *tally.entry(lbl).or_insert(0.0) += w;
941                    }
942                }
943                // Adopt the label with highest total weight (ties: prefer current)
944                let current = *labels.get(node).unwrap_or(&0);
945                let best = tally
946                    .into_iter()
947                    .max_by(|(la, wa), (lb, wb)| {
948                        wa.partial_cmp(wb)
949                            .unwrap_or(std::cmp::Ordering::Equal)
950                            .then_with(|| {
951                                // Tie-break: prefer current label
952                                if *la == current {
953                                    std::cmp::Ordering::Greater
954                                } else if *lb == current {
955                                    std::cmp::Ordering::Less
956                                } else {
957                                    std::cmp::Ordering::Equal
958                                }
959                            })
960                    })
961                    .map(|(lbl, _)| lbl);
962
963                if let Some(best_lbl) = best {
964                    if best_lbl != current {
965                        if let Some(lbl) = labels.get_mut(node) {
966                            *lbl = best_lbl;
967                        }
968                        changed = true;
969                    }
970                }
971            }
972            if !changed {
973                break;
974            }
975        }
976
977        // Collect label → members
978        let mut groups: std::collections::HashMap<usize, Vec<String>> =
979            std::collections::HashMap::new();
980        for n in nodes {
981            let lbl = *labels.get(n.as_str()).unwrap_or(&0);
982            groups.entry(lbl).or_default().push(n.clone());
983        }
984
985        // Build initial community list (sorted for determinism)
986        let mut communities: Vec<KgCommunity> = {
987            let mut kvs: Vec<(usize, Vec<String>)> = groups.into_iter().collect();
988            kvs.sort_by(|a, b| b.1.len().cmp(&a.1.len()));
989            kvs.into_iter()
990                .enumerate()
991                .map(|(id, (_, members))| KgCommunity {
992                    community_id: id,
993                    members,
994                })
995                .collect()
996        };
997
998        // Apply min_community_size: merge small communities into the largest
999        if config.min_community_size > 1 {
1000            let mut large: Vec<KgCommunity> = Vec::new();
1001            let mut small_members: Vec<String> = Vec::new();
1002            for c in communities {
1003                if c.members.len() >= config.min_community_size {
1004                    large.push(c);
1005                } else {
1006                    small_members.extend(c.members);
1007                }
1008            }
1009            if !small_members.is_empty() {
1010                if large.is_empty() {
1011                    large.push(KgCommunity {
1012                        community_id: 0,
1013                        members: small_members,
1014                    });
1015                } else {
1016                    large[0].members.extend(small_members);
1017                }
1018            }
1019            communities = large;
1020        }
1021
1022        // Apply max_communities cap
1023        if config.max_communities > 0 && communities.len() > config.max_communities {
1024            // Merge excess communities into the first (largest)
1025            let keep = config.max_communities;
1026            let excess: Vec<String> = communities.drain(keep..).flat_map(|c| c.members).collect();
1027            if !excess.is_empty() {
1028                communities[0].members.extend(excess);
1029            }
1030        }
1031
1032        // Re-assign sequential community_ids
1033        for (i, c) in communities.iter_mut().enumerate() {
1034            c.community_id = i;
1035        }
1036
1037        communities
1038    }
1039
1040    /// Return the community a node belongs to, if any.
1041    pub fn node_community<'a>(
1042        &self,
1043        node: &str,
1044        communities: &'a [KgCommunity],
1045    ) -> Option<&'a KgCommunity> {
1046        communities
1047            .iter()
1048            .find(|c| c.members.iter().any(|m| m == node))
1049    }
1050
1051    /// Compute the modularity Q of a community partition.
1052    ///
1053    /// Q = Σ_c [ (e_c / m) - (a_c / (2m))² ]
1054    /// where `e_c` is the fraction of edge weight inside community c,
1055    /// `a_c` is the sum of degrees in c, and `m` is the total edge weight.
1056    /// Returns 0.0 for empty graphs.
1057    pub fn modularity(
1058        &self,
1059        nodes: &[String],
1060        edges: &[KgEdge],
1061        communities: &[KgCommunity],
1062    ) -> f64 {
1063        if edges.is_empty() || nodes.is_empty() {
1064            return 0.0;
1065        }
1066
1067        // Total weight
1068        let total_weight: f64 = edges.iter().map(|e| e.weight).sum::<f64>() * 2.0;
1069        if total_weight < 1e-12 {
1070            return 0.0;
1071        }
1072        let m = total_weight / 2.0;
1073
1074        // Node → community map
1075        let mut node_comm: std::collections::HashMap<&str, usize> =
1076            std::collections::HashMap::new();
1077        for c in communities {
1078            for member in &c.members {
1079                node_comm.insert(member.as_str(), c.community_id);
1080            }
1081        }
1082
1083        // Degree of each node
1084        let mut degree: std::collections::HashMap<&str, f64> = std::collections::HashMap::new();
1085        for e in edges {
1086            *degree.entry(e.from.as_str()).or_insert(0.0) += e.weight;
1087            *degree.entry(e.to.as_str()).or_insert(0.0) += e.weight;
1088        }
1089
1090        let mut q = 0.0_f64;
1091        for e in edges {
1092            let ci = node_comm.get(e.from.as_str()).copied();
1093            let cj = node_comm.get(e.to.as_str()).copied();
1094            if ci == cj && ci.is_some() {
1095                let ki = degree.get(e.from.as_str()).copied().unwrap_or(0.0);
1096                let kj = degree.get(e.to.as_str()).copied().unwrap_or(0.0);
1097                // Count each undirected edge once (stored once → factor ×2 / 2m)
1098                q += e.weight / m - (ki * kj) / (2.0 * m * m);
1099            }
1100        }
1101        // Subtract penalty for cross-community pairs (already handled by absence)
1102        q
1103    }
1104
1105    /// Count intra-community edges (both endpoints in the same community).
1106    pub fn intra_edges(edges: &[KgEdge], community: &KgCommunity) -> usize {
1107        let members: std::collections::HashSet<&str> =
1108            community.members.iter().map(String::as_str).collect();
1109        edges
1110            .iter()
1111            .filter(|e| members.contains(e.from.as_str()) && members.contains(e.to.as_str()))
1112            .count()
1113    }
1114}
1115
#[cfg(test)]
mod kg_tests {
    // Unit tests for `KgCommunityDetector`: detection, size limits,
    // membership lookup, intra-edge counting and modularity.
    use super::*;

    /// Shorthand for building an edge.
    fn edge(from: &str, to: &str, w: f64) -> KgEdge {
        KgEdge::new(from, to, w)
    }

    /// Turn string literals into the owned node list `detect` expects.
    fn nodes(names: &[&str]) -> Vec<String> {
        names.iter().map(|&s| s.to_string()).collect()
    }

    fn default_config() -> KgDetectionConfig {
        KgDetectionConfig::default()
    }

    fn det() -> KgCommunityDetector {
        KgCommunityDetector::new()
    }

    /// Total number of members across all communities.
    fn member_total(communities: &[KgCommunity]) -> usize {
        communities.iter().map(|c| c.members.len()).sum()
    }

    // --- Empty / single node ---

    #[test]
    fn test_empty_graph_returns_no_communities() {
        let d = det();
        let result = d.detect(&[], &[], &default_config());
        assert!(result.is_empty());
    }

    #[test]
    fn test_single_node_forms_own_community() {
        let d = det();
        let ns = nodes(&["A"]);
        let result = d.detect(&ns, &[], &default_config());
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].members, vec!["A"]);
    }

    // --- Isolated nodes ---

    #[test]
    fn test_isolated_nodes_each_own_community() {
        let d = det();
        let ns = nodes(&["A", "B", "C"]);
        let result = d.detect(&ns, &[], &default_config());
        // Each isolated node gets its own community
        assert_eq!(member_total(&result), 3);
        assert_eq!(result.len(), 3);
    }

    // --- Connected cliques ---

    #[test]
    fn test_two_cliques_grouped() {
        let d = det();
        let ns = nodes(&["A", "B", "C", "X", "Y", "Z"]);
        let es = vec![
            edge("A", "B", 10.0),
            edge("B", "C", 10.0),
            edge("A", "C", 10.0),
            edge("X", "Y", 10.0),
            edge("Y", "Z", 10.0),
            edge("X", "Z", 10.0),
            // Weak link between cliques
            edge("C", "X", 0.1),
        ];
        let result = d.detect(&ns, &es, &default_config());
        // All 6 nodes should be covered
        assert_eq!(member_total(&result), 6);
        // The weak bridge must not fuse the cliques: one community per clique.
        assert_eq!(result.len(), 2, "expected one community per clique");
        let id_of = |name: &str| d.node_community(name, &result).map(|c| c.community_id);
        assert!(id_of("A").is_some());
        assert_eq!(id_of("A"), id_of("B"));
        assert_eq!(id_of("B"), id_of("C"));
        assert_eq!(id_of("X"), id_of("Y"));
        assert_eq!(id_of("Y"), id_of("Z"));
        assert_ne!(id_of("A"), id_of("X"));
    }

    #[test]
    fn test_fully_connected_single_community() {
        let d = det();
        let ns = nodes(&["A", "B", "C"]);
        let es = vec![
            edge("A", "B", 5.0),
            edge("B", "C", 5.0),
            edge("A", "C", 5.0),
        ];
        let result = d.detect(&ns, &es, &default_config());
        assert_eq!(member_total(&result), 3);
        // A uniformly weighted triangle collapses into one community.
        assert_eq!(result.len(), 1);
    }

    // --- min_community_size ---

    #[test]
    fn test_min_community_size_merges_singletons() {
        let d = det();
        let ns = nodes(&["A", "B", "C", "D"]);
        // No edges → all singletons
        let config = KgDetectionConfig {
            min_community_size: 2,
            max_communities: 0,
        };
        let result = d.detect(&ns, &[], &config);
        // No singleton meets the threshold, so all four land in one community.
        assert_eq!(member_total(&result), 4);
        assert_eq!(result.len(), 1);
        for c in &result {
            assert!(
                c.members.len() >= 2 || result.len() == 1,
                "community too small: {}",
                c.members.len()
            );
        }
    }

    #[test]
    fn test_min_community_size_one_no_effect() {
        let d = det();
        let ns = nodes(&["A", "B"]);
        let config = KgDetectionConfig {
            min_community_size: 1,
            max_communities: 0,
        };
        let result = d.detect(&ns, &[], &config);
        assert_eq!(result.len(), 2);
    }

    // --- max_communities ---

    #[test]
    fn test_max_communities_cap() {
        let d = det();
        let ns = nodes(&["A", "B", "C", "D", "E"]);
        // No edges → 5 communities; cap at 2
        let config = KgDetectionConfig {
            min_community_size: 1,
            max_communities: 2,
        };
        let result = d.detect(&ns, &[], &config);
        assert_eq!(result.len(), 2);
        // Capping merges the excess communities; it must not drop nodes.
        assert_eq!(member_total(&result), 5);
    }

    #[test]
    fn test_max_communities_zero_means_unlimited() {
        let d = det();
        let ns = nodes(&["A", "B", "C"]);
        let config = KgDetectionConfig {
            min_community_size: 1,
            max_communities: 0,
        };
        let result = d.detect(&ns, &[], &config);
        // 3 isolated nodes → 3 communities
        assert_eq!(result.len(), 3);
    }

    // --- node_community ---

    #[test]
    fn test_node_community_found() {
        let d = det();
        let ns = nodes(&["A", "B"]);
        let result = d.detect(&ns, &[], &default_config());
        let comm = d.node_community("A", &result);
        assert!(comm.is_some());
        // The returned community must actually list the node.
        assert!(comm.map_or(false, |c| c.members.iter().any(|m| m == "A")));
    }

    #[test]
    fn test_node_community_not_found() {
        let d = det();
        let ns = nodes(&["A"]);
        let result = d.detect(&ns, &[], &default_config());
        let comm = d.node_community("Z", &result);
        assert!(comm.is_none());
    }

    // --- intra_edges ---

    #[test]
    fn test_intra_edges_count() {
        let community = KgCommunity {
            community_id: 0,
            members: vec!["A".to_string(), "B".to_string()],
        };
        let edges = vec![
            edge("A", "B", 1.0),
            edge("A", "C", 1.0), // C not in community
        ];
        assert_eq!(KgCommunityDetector::intra_edges(&edges, &community), 1);
    }

    #[test]
    fn test_intra_edges_all_internal() {
        let community = KgCommunity {
            community_id: 0,
            members: vec!["A".to_string(), "B".to_string(), "C".to_string()],
        };
        let edges = vec![
            edge("A", "B", 1.0),
            edge("B", "C", 1.0),
            edge("A", "C", 1.0),
        ];
        assert_eq!(KgCommunityDetector::intra_edges(&edges, &community), 3);
    }

    #[test]
    fn test_intra_edges_none() {
        let community = KgCommunity {
            community_id: 0,
            members: vec!["A".to_string()],
        };
        let edges = vec![edge("B", "C", 1.0)];
        assert_eq!(KgCommunityDetector::intra_edges(&edges, &community), 0);
    }

    // --- modularity ---

    #[test]
    fn test_modularity_empty_graph_zero() {
        let d = det();
        let q = d.modularity(&[], &[], &[]);
        assert!((q).abs() < 1e-9);
    }

    #[test]
    fn test_modularity_in_valid_range() {
        let d = det();
        let ns = nodes(&["A", "B", "C", "D"]);
        let es = vec![edge("A", "B", 1.0), edge("C", "D", 1.0)];
        let config = default_config();
        let comms = d.detect(&ns, &es, &config);
        let q = d.modularity(&ns, &es, &comms);
        // Q ∈ [-0.5, 1.0]
        assert!((-0.5..=1.0).contains(&q), "Q={q} out of range");
    }

    #[test]
    fn test_modularity_no_edges_zero() {
        let d = det();
        let ns = nodes(&["A", "B"]);
        let comms = vec![KgCommunity {
            community_id: 0,
            members: vec!["A".to_string(), "B".to_string()],
        }];
        let q = d.modularity(&ns, &[], &comms);
        assert!((q).abs() < 1e-9);
    }

    // --- KgEdge ---

    #[test]
    fn test_kg_edge_new() {
        let e = KgEdge::new("from", "to", 2.5);
        assert_eq!(e.from, "from");
        assert_eq!(e.to, "to");
        assert!((e.weight - 2.5).abs() < 1e-9);
    }

    #[test]
    fn test_kg_edge_clone() {
        let e = KgEdge::new("a", "b", 1.0);
        let c = e.clone();
        assert_eq!(c.from, "a");
        assert_eq!(c.to, "b");
        assert!((c.weight - 1.0).abs() < 1e-9);
    }

    // --- KgDetectionConfig default ---

    #[test]
    fn test_detection_config_default() {
        let cfg = KgDetectionConfig::default();
        assert_eq!(cfg.min_community_size, 1);
        assert_eq!(cfg.max_communities, 0);
    }

    // --- Coverage: all nodes covered ---

    #[test]
    fn test_all_nodes_in_some_community() {
        let d = det();
        let ns = nodes(&["N1", "N2", "N3", "N4", "N5"]);
        let es = vec![edge("N1", "N2", 1.0), edge("N3", "N4", 1.0)];
        let result = d.detect(&ns, &es, &default_config());
        assert_eq!(member_total(&result), 5);
        // Every input node must be locatable, including the isolated N5.
        for name in &ns {
            assert!(d.node_community(name, &result).is_some(), "{name} missing");
        }
    }
}