Skip to main content

graphify_cluster/
lib.rs

//! Community detection (Leiden-style) for graphify.
//!
//! Partitions the knowledge graph into communities by alternating a
//! Louvain-style greedy modularity move phase with a Leiden-inspired
//! refinement phase that splits internally disconnected communities, until
//! neither phase changes the partition or an iteration cap is reached.
//! Communities that end up too small are then merged into their
//! most-connected neighbor, and oversized ones are split again.
7
8use std::collections::{HashMap, HashSet, VecDeque};
9
10use tracing::debug;
11
12use graphify_core::graph::KnowledgeGraph;
13use graphify_core::model::CommunityInfo;
14
/// Upper bound on the share of all nodes one community may hold; a community
/// exceeding this fraction of the graph is split further.
const MAX_COMMUNITY_FRACTION: f64 = 0.25;

/// Communities with fewer nodes than this are never split.
const MIN_SPLIT_SIZE: usize = 10;

/// Communities with fewer nodes than this are merged into their
/// most-connected neighbor.
const MIN_COMMUNITY_SIZE: usize = 5;

/// Modularity resolution parameter. Lower values give fewer, larger
/// communities. Standard Leiden uses 1.0; a lower value is used here to
/// avoid over-fragmentation.
const RESOLUTION: f64 = 0.3;
29
30// ---------------------------------------------------------------------------
31// Public API
32// ---------------------------------------------------------------------------
33
34/// Run community detection on the graph. Returns `{community_id: [node_ids]}`.
35///
36/// Uses the Leiden algorithm: greedy modularity optimization (Louvain phase)
37/// followed by a refinement phase that ensures communities are internally
38/// well-connected.
39pub fn cluster(graph: &KnowledgeGraph) -> HashMap<usize, Vec<String>> {
40    let node_count = graph.node_count();
41    if node_count == 0 {
42        return HashMap::new();
43    }
44
45    // If no edges, each node is its own community
46    if graph.edge_count() == 0 {
47        return graph
48            .node_ids()
49            .into_iter()
50            .enumerate()
51            .map(|(i, id)| (i, vec![id]))
52            .collect();
53    }
54
55    let partition = leiden_partition(graph);
56
57    // Group by community
58    let mut communities: HashMap<usize, Vec<String>> = HashMap::new();
59    for (node_id, cid) in &partition {
60        communities.entry(*cid).or_default().push(node_id.clone());
61    }
62
63    // Merge tiny communities into their most-connected neighbor
64    let adj = build_adjacency(graph);
65    merge_small_communities(&mut communities, &adj);
66
67    // Split oversized communities
68    let max_size = std::cmp::max(
69        MIN_SPLIT_SIZE,
70        (node_count as f64 * MAX_COMMUNITY_FRACTION) as usize,
71    );
72    let mut final_communities: Vec<Vec<String>> = Vec::new();
73    for nodes in communities.values() {
74        if nodes.len() > max_size {
75            final_communities.extend(split_community(graph, nodes));
76        } else {
77            final_communities.push(nodes.clone());
78        }
79    }
80
81    // Re-index by size descending
82    final_communities.sort_by_key(|b| std::cmp::Reverse(b.len()));
83    final_communities
84        .into_iter()
85        .enumerate()
86        .map(|(i, mut nodes)| {
87            nodes.sort();
88            (i, nodes)
89        })
90        .collect()
91}
92
93/// Run community detection and mutate graph in-place, storing community info.
94pub fn cluster_graph(graph: &mut KnowledgeGraph) -> HashMap<usize, Vec<String>> {
95    let communities = cluster(graph);
96
97    // Build CommunityInfo entries
98    let scores = score_all(graph, &communities);
99    let mut infos: Vec<CommunityInfo> = communities
100        .iter()
101        .map(|(&cid, nodes)| CommunityInfo {
102            id: cid,
103            nodes: nodes.clone(),
104            cohesion: scores.get(&cid).copied().unwrap_or(0.0),
105            label: None,
106        })
107        .collect();
108    infos.sort_by_key(|c| c.id);
109    graph.communities = infos;
110
111    communities
112}
113
114/// Cohesion score: ratio of actual intra-community edges to maximum possible.
115///
116/// Returns a value in `[0.0, 1.0]` rounded to two decimal places.
117pub fn cohesion_score(graph: &KnowledgeGraph, community_nodes: &[String]) -> f64 {
118    let n = community_nodes.len();
119    if n <= 1 {
120        return 1.0;
121    }
122
123    let node_set: HashSet<&str> = community_nodes.iter().map(|s| s.as_str()).collect();
124    let mut actual_edges = 0usize;
125
126    // Count edges where both endpoints are in the community
127    for node_id in community_nodes {
128        let neighbors = graph.get_neighbors(node_id);
129        for neighbor in &neighbors {
130            if node_set.contains(neighbor.id.as_str()) {
131                actual_edges += 1;
132            }
133        }
134    }
135    // Each edge is counted twice (once from each endpoint)
136    actual_edges /= 2;
137
138    let possible = n * (n - 1) / 2;
139    if possible == 0 {
140        return 0.0;
141    }
142    ((actual_edges as f64 / possible as f64) * 100.0).round() / 100.0
143}
144
145/// Compute cohesion scores for all communities.
146pub fn score_all(
147    graph: &KnowledgeGraph,
148    communities: &HashMap<usize, Vec<String>>,
149) -> HashMap<usize, f64> {
150    communities
151        .iter()
152        .map(|(&cid, nodes)| (cid, cohesion_score(graph, nodes)))
153        .collect()
154}
155
156// ---------------------------------------------------------------------------
157// Adjacency helpers
158// ---------------------------------------------------------------------------
159
160/// Build an adjacency list from the KnowledgeGraph for efficient lookups.
161fn build_adjacency(graph: &KnowledgeGraph) -> HashMap<String, Vec<(String, f64)>> {
162    let mut adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
163    for id in graph.node_ids() {
164        adj.entry(id).or_default();
165    }
166    for (src, tgt, edge) in graph.edges_with_endpoints() {
167        adj.entry(src.to_string())
168            .or_default()
169            .push((tgt.to_string(), edge.weight));
170        adj.entry(tgt.to_string())
171            .or_default()
172            .push((src.to_string(), edge.weight));
173    }
174    adj
175}
176
/// Total edge weight of the graph described by `adj`.
///
/// The adjacency list stores every edge on both endpoints, so the sum over
/// all entries is halved.
fn total_weight(adj: &HashMap<String, Vec<(String, f64)>>) -> f64 {
    let doubled = adj
        .values()
        .flatten()
        .fold(0.0_f64, |sum, (_, weight)| sum + weight);
    doubled / 2.0
}
187
/// Strength of `node`: the summed weight of all its adjacency entries.
/// Nodes missing from `adj` have strength `0.0`.
fn node_strength(adj: &HashMap<String, Vec<(String, f64)>>, node: &str) -> f64 {
    match adj.get(node) {
        Some(incident) => incident.iter().map(|&(_, weight)| weight).sum(),
        None => 0.0,
    }
}
194
/// Summed weight of the adjacency entries of `node` whose other endpoint
/// belongs to `community`. Nodes missing from `adj` contribute `0.0`.
fn edges_to_community(
    adj: &HashMap<String, Vec<(String, f64)>>,
    node: &str,
    community: &HashSet<&str>,
) -> f64 {
    match adj.get(node) {
        Some(incident) => incident
            .iter()
            .filter_map(|(nbr, weight)| community.contains(nbr.as_str()).then_some(*weight))
            .sum(),
        None => 0.0,
    }
}
211
212/// Sum of strengths of all nodes in a community.
213fn community_strength(adj: &HashMap<String, Vec<(String, f64)>>, members: &HashSet<&str>) -> f64 {
214    members.iter().map(|n| node_strength(adj, n)).sum()
215}
216
217// ---------------------------------------------------------------------------
218// Leiden algorithm
219// ---------------------------------------------------------------------------
220
221/// Leiden algorithm: Louvain phase + refinement phase, iterated until stable.
222///
223/// Reference: Traag, Waltman & van Eck (2019) "From Louvain to Leiden:
224/// guaranteeing well-connected communities"
225fn leiden_partition(graph: &KnowledgeGraph) -> HashMap<String, usize> {
226    let adj = build_adjacency(graph);
227    let m = total_weight(&adj);
228    if m == 0.0 {
229        return graph
230            .node_ids()
231            .into_iter()
232            .enumerate()
233            .map(|(i, id)| (id, i))
234            .collect();
235    }
236
237    let node_ids = graph.node_ids();
238
239    // Initialize: each node in its own community
240    let mut community_of: HashMap<String, usize> = node_ids
241        .iter()
242        .enumerate()
243        .map(|(i, id)| (id.clone(), i))
244        .collect();
245
246    let max_outer_iterations = 10;
247    for _outer in 0..max_outer_iterations {
248        // ── Phase 1: Louvain (greedy modularity move) ──
249        let changed = louvain_phase(&adj, &node_ids, &mut community_of, m);
250
251        // ── Phase 2: Refinement (ensure well-connected communities) ──
252        let refined = refinement_phase(&adj, &mut community_of, m);
253
254        if !changed && !refined {
255            break;
256        }
257    }
258
259    // Compact community IDs
260    compact_ids(&mut community_of);
261    community_of
262}
263
264/// Phase 1: Greedy modularity optimization (Louvain move phase).
265///
266/// Iterates over nodes and moves each to the neighboring community that
267/// yields the greatest modularity gain. Returns true if any move was made.
268fn louvain_phase(
269    adj: &HashMap<String, Vec<(String, f64)>>,
270    node_ids: &[String],
271    community_of: &mut HashMap<String, usize>,
272    m: f64,
273) -> bool {
274    let mut community_members: HashMap<usize, HashSet<String>> = HashMap::new();
275    for (node, &cid) in community_of.iter() {
276        community_members
277            .entry(cid)
278            .or_default()
279            .insert(node.clone());
280    }
281
282    // Pre-compute node strengths (ki) — avoids repeated O(deg) recalculation
283    let ki_cache: HashMap<&str, f64> = adj
284        .keys()
285        .map(|n| (n.as_str(), node_strength(adj, n)))
286        .collect();
287
288    // Pre-compute community strength sums (sigma_c) — maintained incrementally
289    let mut sigma_c: HashMap<usize, f64> = HashMap::new();
290    for (&cid, members) in &community_members {
291        let sum: f64 = members
292            .iter()
293            .map(|n| ki_cache.get(n.as_str()).copied().unwrap_or(0.0))
294            .sum();
295        sigma_c.insert(cid, sum);
296    }
297
298    let max_iterations = 50;
299    let mut any_changed = false;
300
301    for _iteration in 0..max_iterations {
302        let mut improved = false;
303
304        for node in node_ids {
305            let current_community = community_of[node];
306            let ki = ki_cache.get(node.as_str()).copied().unwrap_or(0.0);
307
308            // Aggregate edges to each neighboring community in one pass
309            let mut ki_to: HashMap<usize, f64> = HashMap::new();
310            if let Some(neighbors) = adj.get(node.as_str()) {
311                for (nbr, w) in neighbors {
312                    let nbr_cid = community_of[nbr];
313                    *ki_to.entry(nbr_cid).or_default() += w;
314                }
315            }
316
317            let mut best_community = current_community;
318            let mut best_gain = 0.0f64;
319
320            // ki to current community (excluding self-loop edges already handled)
321            let ki_in_current = ki_to.get(&current_community).copied().unwrap_or(0.0);
322            let sigma_current = sigma_c.get(&current_community).copied().unwrap_or(0.0) - ki;
323
324            for (&target_community, &ki_in_target) in &ki_to {
325                if target_community == current_community {
326                    continue;
327                }
328
329                let sigma_target = sigma_c.get(&target_community).copied().unwrap_or(0.0);
330
331                let gain = (ki_in_target - ki_in_current) / m
332                    - RESOLUTION * ki * (sigma_target - sigma_current) / (2.0 * m * m);
333
334                if gain > best_gain {
335                    best_gain = gain;
336                    best_community = target_community;
337                }
338            }
339
340            if best_community != current_community {
341                // Update community_members
342                community_members
343                    .get_mut(&current_community)
344                    .unwrap()
345                    .remove(node);
346                community_members
347                    .entry(best_community)
348                    .or_default()
349                    .insert(node.clone());
350
351                // Update sigma_c incrementally
352                *sigma_c.entry(current_community).or_default() -= ki;
353                *sigma_c.entry(best_community).or_default() += ki;
354
355                community_of.insert(node.clone(), best_community);
356                improved = true;
357                any_changed = true;
358            }
359        }
360
361        if !improved {
362            break;
363        }
364    }
365
366    any_changed
367}
368
369/// Phase 2: Leiden refinement.
370///
371/// For each community, find its connected components. If a community is
372/// internally disconnected, split it — move each disconnected sub-component
373/// to whichever neighboring community maximizes modularity gain (or keep it
374/// as a new community). This guarantees all resulting communities are
375/// internally connected.
376fn refinement_phase(
377    adj: &HashMap<String, Vec<(String, f64)>>,
378    community_of: &mut HashMap<String, usize>,
379    m: f64,
380) -> bool {
381    // Group nodes by community
382    let mut community_members: HashMap<usize, Vec<String>> = HashMap::new();
383    for (node, &cid) in community_of.iter() {
384        community_members.entry(cid).or_default().push(node.clone());
385    }
386
387    let mut any_refined = false;
388    let mut next_cid = community_members.keys().copied().max().unwrap_or(0) + 1;
389
390    let community_ids: Vec<usize> = community_members.keys().copied().collect();
391    for cid in community_ids {
392        let members = match community_members.get(&cid) {
393            Some(m) if m.len() > 1 => m.clone(),
394            _ => continue,
395        };
396
397        // Find connected components within this community
398        let components = connected_components_within(adj, &members);
399        if components.len() <= 1 {
400            continue; // Already well-connected
401        }
402
403        debug!(
404            "Leiden refinement: community {} has {} disconnected components, splitting",
405            cid,
406            components.len()
407        );
408
409        // Keep the largest component in the original community,
410        // assign each smaller component to the best neighboring community
411        // or a new community.
412        let mut sorted_components = components;
413        sorted_components.sort_by_key(|c| std::cmp::Reverse(c.len()));
414
415        // Largest component stays
416        for component in sorted_components.iter().skip(1) {
417            // For this sub-component, find the best neighboring community
418            let mut neighbor_cid_edges: HashMap<usize, f64> = HashMap::new();
419            for node in component {
420                if let Some(neighbors) = adj.get(node.as_str()) {
421                    for (nbr, w) in neighbors {
422                        let nbr_cid = community_of[nbr];
423                        if nbr_cid != cid {
424                            *neighbor_cid_edges.entry(nbr_cid).or_default() += w;
425                        }
426                    }
427                }
428            }
429
430            // Pick the neighbor community with the strongest connection,
431            // or create a new community if no neighbor exists
432            let target_cid = if let Some((&best_cid, _)) = neighbor_cid_edges
433                .iter()
434                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
435            {
436                // Only merge if modularity gain is positive
437                let _component_set: HashSet<&str> = component.iter().map(|s| s.as_str()).collect();
438                let target_members: HashSet<&str> = community_members
439                    .get(&best_cid)
440                    .map(|s| s.iter().map(|x| x.as_str()).collect())
441                    .unwrap_or_default();
442
443                let ki_sum: f64 = component.iter().map(|n| node_strength(adj, n)).sum();
444                let ki_in = component
445                    .iter()
446                    .map(|n| edges_to_community(adj, n, &target_members))
447                    .sum::<f64>();
448                let sigma_t = community_strength(adj, &target_members);
449
450                let gain = ki_in / m - ki_sum * sigma_t / (2.0 * m * m);
451                if gain > 0.0 {
452                    best_cid
453                } else {
454                    let new_cid = next_cid;
455                    next_cid += 1;
456                    new_cid
457                }
458            } else {
459                let new_cid = next_cid;
460                next_cid += 1;
461                new_cid
462            };
463
464            // Move all nodes in this component
465            for node in component {
466                community_of.insert(node.clone(), target_cid);
467                community_members
468                    .entry(target_cid)
469                    .or_default()
470                    .push(node.clone());
471            }
472            any_refined = true;
473        }
474
475        // Update original community to only keep largest component
476        if any_refined {
477            community_members.insert(cid, sorted_components.into_iter().next().unwrap());
478        }
479    }
480
481    any_refined
482}
483
/// Split `members` into connected components using breadth-first search.
///
/// Only edges whose both endpoints are in `members` are followed. Components
/// are returned in the order their first node appears in `members`, and the
/// nodes inside a component are listed in BFS visiting order.
fn connected_components_within(
    adj: &HashMap<String, Vec<(String, f64)>>,
    members: &[String],
) -> Vec<Vec<String>> {
    let allowed: HashSet<&str> = members.iter().map(String::as_str).collect();
    let mut seen: HashSet<&str> = HashSet::new();
    let mut components: Vec<Vec<String>> = Vec::new();

    for start in members {
        // `insert` returns false when the node was reached earlier.
        if !seen.insert(start.as_str()) {
            continue;
        }

        let mut component: Vec<String> = Vec::new();
        let mut frontier: VecDeque<&str> = VecDeque::new();
        frontier.push_back(start.as_str());

        while let Some(current) = frontier.pop_front() {
            component.push(current.to_owned());

            let in_community_neighbors = adj
                .get(current)
                .into_iter()
                .flatten()
                .map(|(nbr, _)| nbr.as_str())
                .filter(|nbr| allowed.contains(nbr));
            for nbr in in_community_neighbors {
                if seen.insert(nbr) {
                    frontier.push_back(nbr);
                }
            }
        }

        components.push(component);
    }

    components
}
520
/// Renumber community ids so they are contiguous and start at 0.
///
/// The relative order of ids is preserved: the smallest id in use becomes 0,
/// the next one 1, and so on.
fn compact_ids(community_of: &mut HashMap<String, usize>) {
    // Sorted list of distinct ids; an id's position in it is its new value.
    let mut ranked: Vec<usize> = community_of.values().copied().collect();
    ranked.sort_unstable();
    ranked.dedup();

    for cid in community_of.values_mut() {
        // Every id was collected above, so the search always succeeds.
        if let Ok(rank) = ranked.binary_search(&*cid) {
            *cid = rank;
        }
    }
}
539
540/// Merge communities smaller than `MIN_COMMUNITY_SIZE` into their
541/// most-connected neighboring community.
542fn merge_small_communities(
543    communities: &mut HashMap<usize, Vec<String>>,
544    adj: &HashMap<String, Vec<(String, f64)>>,
545) {
546    // Build node → community reverse map once, maintain incrementally
547    let mut node_to_cid: HashMap<String, usize> = communities
548        .iter()
549        .flat_map(|(&cid, nodes)| nodes.iter().map(move |n| (n.clone(), cid)))
550        .collect();
551
552    loop {
553        // Find one small community to merge
554        let merge = communities
555            .iter()
556            .filter(|(_, nodes)| nodes.len() < MIN_COMMUNITY_SIZE)
557            .find_map(|(&small_cid, nodes)| {
558                // Count edges to each neighboring community
559                let mut neighbor_edges: HashMap<usize, f64> = HashMap::new();
560                for node in nodes {
561                    if let Some(neighbors) = adj.get(node.as_str()) {
562                        for (neighbor, weight) in neighbors {
563                            if let Some(&ncid) = node_to_cid.get(neighbor.as_str())
564                                && ncid != small_cid
565                            {
566                                *neighbor_edges.entry(ncid).or_default() += weight;
567                            }
568                        }
569                    }
570                }
571                // Best target
572                neighbor_edges
573                    .iter()
574                    .max_by(|a, b| a.1.total_cmp(b.1))
575                    .map(|(&best_cid, _)| (small_cid, best_cid))
576            });
577
578        match merge {
579            Some((small_cid, best_cid)) => {
580                let nodes = communities.remove(&small_cid).unwrap_or_default();
581                // Update node_to_cid incrementally
582                for node in &nodes {
583                    node_to_cid.insert(node.clone(), best_cid);
584                }
585                communities.entry(best_cid).or_default().extend(nodes);
586            }
587            None => break, // No more small communities to merge
588        }
589    }
590}
591
592/// Try to split an oversized community by running partition on its subgraph.
593fn split_community(graph: &KnowledgeGraph, nodes: &[String]) -> Vec<Vec<String>> {
594    if nodes.len() < MIN_SPLIT_SIZE {
595        return vec![nodes.to_vec()];
596    }
597
598    let node_set: HashSet<&str> = nodes.iter().map(|s| s.as_str()).collect();
599
600    // Build sub-adjacency list
601    let mut sub_adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
602    for node in nodes {
603        sub_adj.entry(node.clone()).or_default();
604    }
605    for (src, tgt, edge) in graph.edges_with_endpoints() {
606        if node_set.contains(src) && node_set.contains(tgt) {
607            sub_adj
608                .entry(src.to_string())
609                .or_default()
610                .push((tgt.to_string(), edge.weight));
611            sub_adj
612                .entry(tgt.to_string())
613                .or_default()
614                .push((src.to_string(), edge.weight));
615        }
616    }
617
618    let m = total_weight(&sub_adj);
619    if m == 0.0 {
620        return nodes.iter().map(|n| vec![n.clone()]).collect();
621    }
622
623    // Run Louvain + refinement on the subgraph
624    let mut community_of: HashMap<String, usize> = nodes
625        .iter()
626        .enumerate()
627        .map(|(i, id)| (id.clone(), i))
628        .collect();
629
630    let node_list: Vec<String> = nodes.to_vec();
631    for _ in 0..5 {
632        let changed = louvain_phase(&sub_adj, &node_list, &mut community_of, m);
633        let refined = refinement_phase(&sub_adj, &mut community_of, m);
634        if !changed && !refined {
635            break;
636        }
637    }
638
639    // Group results
640    let mut groups: HashMap<usize, Vec<String>> = HashMap::new();
641    for (node, cid) in &community_of {
642        groups.entry(*cid).or_default().push(node.clone());
643    }
644
645    let result: Vec<Vec<String>> = groups.into_values().filter(|s| !s.is_empty()).collect();
646
647    if result.len() <= 1 {
648        debug!("could not split community of {} nodes further", nodes.len());
649        return vec![nodes.to_vec()];
650    }
651
652    result
653}
654
655// ---------------------------------------------------------------------------
656// Tests
657// ---------------------------------------------------------------------------
658
#[cfg(test)]
mod tests {
    use super::*;
    use graphify_core::confidence::Confidence;
    use graphify_core::graph::KnowledgeGraph;
    use graphify_core::model::{GraphEdge, GraphNode, NodeType};
    use std::collections::HashMap as StdMap;

    // ── Fixtures ──

    /// Node ids shared by the three-node tests.
    const ABC: [&str; 3] = ["a", "b", "c"];

    /// Edges of the complete graph over `ABC`.
    const TRIANGLE: [(&str, &str); 3] = [("a", "b"), ("b", "c"), ("a", "c")];

    /// Node ids of the two-triangle graphs.
    const TWO_TRIANGLE_NODES: [&str; 6] = ["a1", "a2", "a3", "b1", "b2", "b3"];

    /// A class node whose label equals its id.
    fn make_node(id: &str) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: id.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: StdMap::new(),
        }
    }

    /// A unit-weight `calls` edge from `src` to `tgt`.
    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.into(),
            target: tgt.into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: StdMap::new(),
        }
    }

    /// Build a graph from node ids and `(source, target)` pairs.
    fn build_graph(nodes: &[&str], edges: &[(&str, &str)]) -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();
        nodes
            .iter()
            .for_each(|id| graph.add_node(make_node(id)).unwrap());
        edges
            .iter()
            .for_each(|(src, tgt)| graph.add_edge(make_edge(src, tgt)).unwrap());
        graph
    }

    /// Two triangles (`a1..a3`, `b1..b3`), optionally joined by `a3 - b1`.
    fn two_triangles(with_bridge: bool) -> KnowledgeGraph {
        let mut edges = vec![
            ("a1", "a2"),
            ("a2", "a3"),
            ("a1", "a3"),
            ("b1", "b2"),
            ("b2", "b3"),
            ("b1", "b3"),
        ];
        if with_bridge {
            edges.push(("a3", "b1"));
        }
        build_graph(&TWO_TRIANGLE_NODES, &edges)
    }

    /// Owned copies of string literals.
    fn owned(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|id| id.to_string()).collect()
    }

    /// Unit-weight undirected adjacency list over `nodes`.
    fn undirected_adj(
        nodes: &[&str],
        edges: &[(&str, &str)],
    ) -> HashMap<String, Vec<(String, f64)>> {
        let mut adj: HashMap<String, Vec<(String, f64)>> = nodes
            .iter()
            .map(|id| (id.to_string(), Vec::new()))
            .collect();
        for &(u, v) in edges {
            adj.get_mut(u).unwrap().push((v.into(), 1.0));
            adj.get_mut(v).unwrap().push((u.into(), 1.0));
        }
        adj
    }

    /// Number of nodes across all communities.
    fn total_members(communities: &HashMap<usize, Vec<String>>) -> usize {
        communities.values().map(Vec::len).sum()
    }

    // ── cluster ──

    #[test]
    fn cluster_empty_graph() {
        let result = cluster(&KnowledgeGraph::new());
        assert!(result.is_empty());
    }

    #[test]
    fn cluster_no_edges() {
        let g = build_graph(&ABC, &[]);
        let result = cluster(&g);
        assert_eq!(result.len(), 3);
        for nodes in result.values() {
            assert_eq!(nodes.len(), 1);
        }
    }

    #[test]
    fn cluster_single_clique() {
        let g = build_graph(&ABC, &TRIANGLE);
        let result = cluster(&g);
        assert_eq!(total_members(&result), 3);
        assert!(result.len() <= 3);
    }

    #[test]
    fn cluster_two_cliques() {
        let g = two_triangles(true);
        let result = cluster(&g);
        assert_eq!(total_members(&result), 6);
        assert!(!result.is_empty());
    }

    // ── cohesion ──

    #[test]
    fn cohesion_score_single_node() {
        let g = build_graph(&["a"], &[]);
        let score = cohesion_score(&g, &owned(&["a"]));
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cohesion_score_complete_graph() {
        let g = build_graph(&ABC, &TRIANGLE);
        let score = cohesion_score(&g, &owned(&ABC));
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cohesion_score_no_edges() {
        let g = build_graph(&ABC, &[]);
        let score = cohesion_score(&g, &owned(&ABC));
        assert!((score - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cohesion_score_partial() {
        let g = build_graph(&ABC, &[("a", "b")]);
        let score = cohesion_score(&g, &owned(&ABC));
        assert!((score - 0.33).abs() < 0.01);
    }

    #[test]
    fn score_all_works() {
        let g = build_graph(&["a", "b"], &[("a", "b")]);
        let communities: HashMap<usize, Vec<String>> =
            std::iter::once((0, owned(&["a", "b"]))).collect();
        let scores = score_all(&g, &communities);
        assert_eq!(scores.len(), 1);
        assert!((scores[&0] - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cluster_graph_mutates_communities() {
        let mut g = build_graph(&ABC, &TRIANGLE);
        let result = cluster_graph(&mut g);
        assert!(!result.is_empty());
        assert!(!g.communities.is_empty());
    }

    // ── Leiden-specific tests ──

    #[test]
    fn leiden_splits_disconnected_community() {
        // Two cliques with no bridge between them: they must end up in
        // separate communities, three nodes each.
        let g = two_triangles(false);
        let result = cluster(&g);
        assert_eq!(
            result.len(),
            2,
            "disconnected cliques should form 2 communities"
        );
        for nodes in result.values() {
            assert_eq!(nodes.len(), 3);
        }
    }

    #[test]
    fn leiden_connected_components_within() {
        // a-b connected, c-d connected, no bridge.
        let ids = ["a", "b", "c", "d"];
        let adj = undirected_adj(&ids, &[("a", "b"), ("c", "d")]);
        let components = connected_components_within(&adj, &owned(&ids));
        assert_eq!(components.len(), 2);
    }

    #[test]
    fn leiden_single_component() {
        // Path a-b-c is one component.
        let adj = undirected_adj(&ABC, &[("a", "b"), ("b", "c")]);
        let components = connected_components_within(&adj, &owned(&ABC));
        assert_eq!(components.len(), 1);
    }
}
862}