Skip to main content

graphify_cluster/
lib.rs

//! Community detection (Leiden algorithm) for graphify.
//!
//! Partitions the knowledge graph into communities using the Leiden algorithm,
//! which improves upon Louvain by adding a refinement phase that guarantees
//! well-connected communities. The greedy modularity (Louvain) phase and the
//! refinement phase alternate until neither of them changes the partition.
7
8use std::collections::{HashMap, HashSet, VecDeque};
9
10use tracing::debug;
11
12use graphify_core::graph::KnowledgeGraph;
13use graphify_core::model::CommunityInfo;
14
/// Maximum fraction of total nodes a single community may contain before
/// being split further.
///
/// `cluster` multiplies this by the node count to obtain the size above which
/// a community is handed to `split_community`.
const MAX_COMMUNITY_FRACTION: f64 = 0.25;

/// Minimum community size below which we never attempt a split.
///
/// Also acts as the floor of the "oversized" threshold computed in `cluster`,
/// so communities of at most this many nodes are never passed on for
/// splitting.
const MIN_SPLIT_SIZE: usize = 10;

/// Minimum community size — communities smaller than this get merged into
/// their most-connected neighbor.
///
/// A small community with no edge to any other community has no neighbor to
/// merge into and is left as it is.
const MIN_COMMUNITY_SIZE: usize = 5;

/// Resolution parameter for modularity. Lower = fewer, larger communities.
/// Default Leiden uses 1.0; we use a lower value to avoid over-fragmentation.
///
/// It scales the strength-penalty term of the modularity gain.
const RESOLUTION: f64 = 0.3;
29
30// ---------------------------------------------------------------------------
31// Public API
32// ---------------------------------------------------------------------------
33
34/// Run community detection on the graph. Returns `{community_id: [node_ids]}`.
35///
36/// Uses the Leiden algorithm: greedy modularity optimization (Louvain phase)
37/// followed by a refinement phase that ensures communities are internally
38/// well-connected.
39pub fn cluster(graph: &KnowledgeGraph) -> HashMap<usize, Vec<String>> {
40    let node_count = graph.node_count();
41    if node_count == 0 {
42        return HashMap::new();
43    }
44
45    // If no edges, each node is its own community
46    if graph.edge_count() == 0 {
47        return graph
48            .node_ids()
49            .into_iter()
50            .enumerate()
51            .map(|(i, id)| (i, vec![id]))
52            .collect();
53    }
54
55    let partition = leiden_partition(graph);
56
57    // Group by community
58    let mut communities: HashMap<usize, Vec<String>> = HashMap::new();
59    for (node_id, cid) in &partition {
60        communities.entry(*cid).or_default().push(node_id.clone());
61    }
62
63    // Merge tiny communities into their most-connected neighbor
64    let adj = build_adjacency(graph);
65    merge_small_communities(&mut communities, &adj);
66
67    // Split oversized communities
68    let max_size = std::cmp::max(
69        MIN_SPLIT_SIZE,
70        (node_count as f64 * MAX_COMMUNITY_FRACTION) as usize,
71    );
72    let mut final_communities: Vec<Vec<String>> = Vec::new();
73    for nodes in communities.values() {
74        if nodes.len() > max_size {
75            final_communities.extend(split_community(graph, nodes));
76        } else {
77            final_communities.push(nodes.clone());
78        }
79    }
80
81    // Re-index by size descending
82    final_communities.sort_by_key(|b| std::cmp::Reverse(b.len()));
83    final_communities
84        .into_iter()
85        .enumerate()
86        .map(|(i, mut nodes)| {
87            nodes.sort();
88            (i, nodes)
89        })
90        .collect()
91}
92
93/// Run community detection and mutate graph in-place, storing community info.
94pub fn cluster_graph(graph: &mut KnowledgeGraph) -> HashMap<usize, Vec<String>> {
95    let communities = cluster(graph);
96
97    // Build CommunityInfo entries
98    let scores = score_all(graph, &communities);
99    let mut infos: Vec<CommunityInfo> = communities
100        .iter()
101        .map(|(&cid, nodes)| CommunityInfo {
102            id: cid,
103            nodes: nodes.clone(),
104            cohesion: scores.get(&cid).copied().unwrap_or(0.0),
105            label: None,
106        })
107        .collect();
108    infos.sort_by_key(|c| c.id);
109    graph.communities = infos;
110
111    communities
112}
113
114/// Cohesion score: ratio of actual intra-community edges to maximum possible.
115///
116/// Returns a value in `[0.0, 1.0]` rounded to two decimal places.
117pub fn cohesion_score(graph: &KnowledgeGraph, community_nodes: &[String]) -> f64 {
118    let n = community_nodes.len();
119    if n <= 1 {
120        return 1.0;
121    }
122
123    let node_set: HashSet<&str> = community_nodes.iter().map(|s| s.as_str()).collect();
124    let mut actual_edges = 0usize;
125
126    // Count edges where both endpoints are in the community
127    for node_id in community_nodes {
128        let neighbors = graph.get_neighbors(node_id);
129        for neighbor in &neighbors {
130            if node_set.contains(neighbor.id.as_str()) {
131                actual_edges += 1;
132            }
133        }
134    }
135    // Each edge is counted twice (once from each endpoint)
136    actual_edges /= 2;
137
138    let possible = n * (n - 1) / 2;
139    if possible == 0 {
140        return 0.0;
141    }
142    ((actual_edges as f64 / possible as f64) * 100.0).round() / 100.0
143}
144
145/// Compute cohesion scores for all communities.
146pub fn score_all(
147    graph: &KnowledgeGraph,
148    communities: &HashMap<usize, Vec<String>>,
149) -> HashMap<usize, f64> {
150    communities
151        .iter()
152        .map(|(&cid, nodes)| (cid, cohesion_score(graph, nodes)))
153        .collect()
154}
155
156// ---------------------------------------------------------------------------
157// Adjacency helpers
158// ---------------------------------------------------------------------------
159
160/// Build an adjacency list from the KnowledgeGraph for efficient lookups.
161fn build_adjacency(graph: &KnowledgeGraph) -> HashMap<String, Vec<(String, f64)>> {
162    let mut adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
163    for id in graph.node_ids() {
164        adj.entry(id).or_default();
165    }
166    for (src, tgt, edge) in graph.edges_with_endpoints() {
167        adj.entry(src.to_string())
168            .or_default()
169            .push((tgt.to_string(), edge.weight));
170        adj.entry(tgt.to_string())
171            .or_default()
172            .push((src.to_string(), edge.weight));
173    }
174    adj
175}
176
/// Total edge weight of the graph described by a symmetric adjacency list.
///
/// Every edge appears once in each endpoint's neighbor list, so the sum over
/// all entries is halved.
fn total_weight(adj: &HashMap<String, Vec<(String, f64)>>) -> f64 {
    let doubled = adj
        .values()
        .flatten()
        .fold(0.0_f64, |acc, (_, weight)| acc + weight);
    doubled / 2.0
}
187
/// Strength of `node`: the summed weight of all its adjacency entries.
///
/// A node that is missing from `adj` has strength `0.0`.
fn node_strength(adj: &HashMap<String, Vec<(String, f64)>>, node: &str) -> f64 {
    match adj.get(node) {
        Some(neighbors) => neighbors.iter().map(|(_, weight)| weight).sum::<f64>(),
        None => 0.0,
    }
}
194
/// Summed weight of the adjacency entries of `node` whose other endpoint is
/// a member of `community`.
///
/// Returns `0.0` when `node` has no entry in `adj`.
fn edges_to_community(
    adj: &HashMap<String, Vec<(String, f64)>>,
    node: &str,
    community: &HashSet<&str>,
) -> f64 {
    let neighbors = match adj.get(node) {
        Some(list) => list,
        None => return 0.0,
    };

    neighbors
        .iter()
        .filter_map(|(nbr, weight)| {
            if community.contains(nbr.as_str()) {
                Some(*weight)
            } else {
                None
            }
        })
        .sum::<f64>()
}
211
212/// Sum of strengths of all nodes in a community.
213fn community_strength(adj: &HashMap<String, Vec<(String, f64)>>, members: &HashSet<&str>) -> f64 {
214    members.iter().map(|n| node_strength(adj, n)).sum()
215}
216
217// ---------------------------------------------------------------------------
218// Leiden algorithm
219// ---------------------------------------------------------------------------
220
221/// Leiden algorithm: Louvain phase + refinement phase, iterated until stable.
222///
223/// Reference: Traag, Waltman & van Eck (2019) "From Louvain to Leiden:
224/// guaranteeing well-connected communities"
225fn leiden_partition(graph: &KnowledgeGraph) -> HashMap<String, usize> {
226    let adj = build_adjacency(graph);
227    let m = total_weight(&adj);
228    if m == 0.0 {
229        return graph
230            .node_ids()
231            .into_iter()
232            .enumerate()
233            .map(|(i, id)| (id, i))
234            .collect();
235    }
236
237    let node_ids = graph.node_ids();
238
239    // Initialize: each node in its own community
240    let mut community_of: HashMap<String, usize> = node_ids
241        .iter()
242        .enumerate()
243        .map(|(i, id)| (id.clone(), i))
244        .collect();
245
246    let max_outer_iterations = 10;
247    for _outer in 0..max_outer_iterations {
248        // ── Phase 1: Louvain (greedy modularity move) ──
249        let changed = louvain_phase(&adj, &node_ids, &mut community_of, m);
250
251        // ── Phase 2: Refinement (ensure well-connected communities) ──
252        let refined = refinement_phase(&adj, &mut community_of, m);
253
254        if !changed && !refined {
255            break;
256        }
257    }
258
259    // Compact community IDs
260    compact_ids(&mut community_of);
261    community_of
262}
263
264/// Phase 1: Greedy modularity optimization (Louvain move phase).
265///
266/// Iterates over nodes and moves each to the neighboring community that
267/// yields the greatest modularity gain. Returns true if any move was made.
268fn louvain_phase(
269    adj: &HashMap<String, Vec<(String, f64)>>,
270    node_ids: &[String],
271    community_of: &mut HashMap<String, usize>,
272    m: f64,
273) -> bool {
274    let mut community_members: HashMap<usize, HashSet<String>> = HashMap::new();
275    for (node, &cid) in community_of.iter() {
276        community_members
277            .entry(cid)
278            .or_default()
279            .insert(node.clone());
280    }
281
282    let max_iterations = 50;
283    let mut any_changed = false;
284
285    for _iteration in 0..max_iterations {
286        let mut improved = false;
287
288        for node in node_ids {
289            let current_community = community_of[node];
290            let ki = node_strength(adj, node);
291
292            // Get neighboring communities
293            let mut neighbor_communities: HashSet<usize> = HashSet::new();
294            if let Some(neighbors) = adj.get(node.as_str()) {
295                for (n, _) in neighbors {
296                    neighbor_communities.insert(community_of[n]);
297                }
298            }
299            neighbor_communities.insert(current_community);
300
301            let mut best_community = current_community;
302            let mut best_gain = 0.0f64;
303
304            for &target_community in &neighbor_communities {
305                if target_community == current_community {
306                    continue;
307                }
308
309                let members_ref: HashSet<&str> = community_members
310                    .get(&target_community)
311                    .map(|s| s.iter().map(|x| x.as_str()).collect())
312                    .unwrap_or_default();
313
314                let current_members_ref: HashSet<&str> = community_members
315                    .get(&current_community)
316                    .map(|s| {
317                        s.iter()
318                            .filter(|x| x.as_str() != node.as_str())
319                            .map(|x| x.as_str())
320                            .collect()
321                    })
322                    .unwrap_or_default();
323
324                let ki_in_target = edges_to_community(adj, node, &members_ref);
325                let ki_in_current = edges_to_community(adj, node, &current_members_ref);
326                let sigma_target = community_strength(adj, &members_ref);
327                let sigma_current = community_strength(adj, &current_members_ref);
328
329                let gain = (ki_in_target - ki_in_current) / m
330                    - RESOLUTION * ki * (sigma_target - sigma_current) / (2.0 * m * m);
331
332                if gain > best_gain {
333                    best_gain = gain;
334                    best_community = target_community;
335                }
336            }
337
338            if best_community != current_community {
339                community_members
340                    .get_mut(&current_community)
341                    .unwrap()
342                    .remove(node);
343                community_members
344                    .entry(best_community)
345                    .or_default()
346                    .insert(node.clone());
347                community_of.insert(node.clone(), best_community);
348                improved = true;
349                any_changed = true;
350            }
351        }
352
353        if !improved {
354            break;
355        }
356    }
357
358    any_changed
359}
360
361/// Phase 2: Leiden refinement.
362///
363/// For each community, find its connected components. If a community is
364/// internally disconnected, split it — move each disconnected sub-component
365/// to whichever neighboring community maximizes modularity gain (or keep it
366/// as a new community). This guarantees all resulting communities are
367/// internally connected.
368fn refinement_phase(
369    adj: &HashMap<String, Vec<(String, f64)>>,
370    community_of: &mut HashMap<String, usize>,
371    m: f64,
372) -> bool {
373    // Group nodes by community
374    let mut community_members: HashMap<usize, Vec<String>> = HashMap::new();
375    for (node, &cid) in community_of.iter() {
376        community_members.entry(cid).or_default().push(node.clone());
377    }
378
379    let mut any_refined = false;
380    let mut next_cid = community_members.keys().copied().max().unwrap_or(0) + 1;
381
382    let community_ids: Vec<usize> = community_members.keys().copied().collect();
383    for cid in community_ids {
384        let members = match community_members.get(&cid) {
385            Some(m) if m.len() > 1 => m.clone(),
386            _ => continue,
387        };
388
389        // Find connected components within this community
390        let components = connected_components_within(adj, &members);
391        if components.len() <= 1 {
392            continue; // Already well-connected
393        }
394
395        debug!(
396            "Leiden refinement: community {} has {} disconnected components, splitting",
397            cid,
398            components.len()
399        );
400
401        // Keep the largest component in the original community,
402        // assign each smaller component to the best neighboring community
403        // or a new community.
404        let mut sorted_components = components;
405        sorted_components.sort_by_key(|c| std::cmp::Reverse(c.len()));
406
407        // Largest component stays
408        for component in sorted_components.iter().skip(1) {
409            // For this sub-component, find the best neighboring community
410            let mut neighbor_cid_edges: HashMap<usize, f64> = HashMap::new();
411            for node in component {
412                if let Some(neighbors) = adj.get(node.as_str()) {
413                    for (nbr, w) in neighbors {
414                        let nbr_cid = community_of[nbr];
415                        if nbr_cid != cid {
416                            *neighbor_cid_edges.entry(nbr_cid).or_default() += w;
417                        }
418                    }
419                }
420            }
421
422            // Pick the neighbor community with the strongest connection,
423            // or create a new community if no neighbor exists
424            let target_cid = if let Some((&best_cid, _)) = neighbor_cid_edges
425                .iter()
426                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
427            {
428                // Only merge if modularity gain is positive
429                let _component_set: HashSet<&str> = component.iter().map(|s| s.as_str()).collect();
430                let target_members: HashSet<&str> = community_members
431                    .get(&best_cid)
432                    .map(|s| s.iter().map(|x| x.as_str()).collect())
433                    .unwrap_or_default();
434
435                let ki_sum: f64 = component.iter().map(|n| node_strength(adj, n)).sum();
436                let ki_in = component
437                    .iter()
438                    .map(|n| edges_to_community(adj, n, &target_members))
439                    .sum::<f64>();
440                let sigma_t = community_strength(adj, &target_members);
441
442                let gain = ki_in / m - ki_sum * sigma_t / (2.0 * m * m);
443                if gain > 0.0 {
444                    best_cid
445                } else {
446                    let new_cid = next_cid;
447                    next_cid += 1;
448                    new_cid
449                }
450            } else {
451                let new_cid = next_cid;
452                next_cid += 1;
453                new_cid
454            };
455
456            // Move all nodes in this component
457            for node in component {
458                community_of.insert(node.clone(), target_cid);
459                community_members
460                    .entry(target_cid)
461                    .or_default()
462                    .push(node.clone());
463            }
464            any_refined = true;
465        }
466
467        // Update original community to only keep largest component
468        if any_refined {
469            community_members.insert(cid, sorted_components.into_iter().next().unwrap());
470        }
471    }
472
473    any_refined
474}
475
/// Split `members` into connected components with a breadth-first search.
///
/// Only edges whose both endpoints are in `members` are followed. Components
/// are returned in the order their first node appears in `members`, and the
/// nodes of a component are listed in BFS visiting order.
fn connected_components_within(
    adj: &HashMap<String, Vec<(String, f64)>>,
    members: &[String],
) -> Vec<Vec<String>> {
    let allowed: HashSet<&str> = members.iter().map(String::as_str).collect();
    let mut seen: HashSet<&str> = HashSet::new();
    let mut components: Vec<Vec<String>> = Vec::new();

    for start in members {
        // `insert` returns false when the node was already reached earlier.
        if !seen.insert(start.as_str()) {
            continue;
        }

        let mut frontier: VecDeque<&str> = VecDeque::new();
        frontier.push_back(start.as_str());
        let mut component: Vec<String> = Vec::new();

        while let Some(current) = frontier.pop_front() {
            component.push(current.to_owned());
            // A node without an adjacency entry simply has no neighbors.
            for (nbr, _) in adj.get(current).into_iter().flatten() {
                let nbr = nbr.as_str();
                if allowed.contains(nbr) && seen.insert(nbr) {
                    frontier.push_back(nbr);
                }
            }
        }

        components.push(component);
    }

    components
}
512
/// Renumber community ids so they form the contiguous range `0..k`.
///
/// The relative order of ids is preserved: the smallest id in use becomes 0,
/// the next one 1, and so on.
fn compact_ids(community_of: &mut HashMap<String, usize>) {
    // Sorted list of the distinct ids currently in use.
    let mut used: Vec<usize> = community_of.values().copied().collect();
    used.sort_unstable();
    used.dedup();

    // An id's new value is its position in that list.
    for cid in community_of.values_mut() {
        *cid = used
            .binary_search(cid)
            .expect("every id was collected from this map");
    }
}
531
532/// Merge communities smaller than `MIN_COMMUNITY_SIZE` into their
533/// most-connected neighboring community.
534fn merge_small_communities(
535    communities: &mut HashMap<usize, Vec<String>>,
536    adj: &HashMap<String, Vec<(String, f64)>>,
537) {
538    loop {
539        // Build node → community reverse map each iteration
540        let node_to_cid: HashMap<String, usize> = communities
541            .iter()
542            .flat_map(|(&cid, nodes)| nodes.iter().map(move |n| (n.clone(), cid)))
543            .collect();
544
545        // Find one small community to merge
546        let merge = communities
547            .iter()
548            .filter(|(_, nodes)| nodes.len() < MIN_COMMUNITY_SIZE)
549            .find_map(|(&small_cid, nodes)| {
550                // Count edges to each neighboring community
551                let mut neighbor_edges: HashMap<usize, f64> = HashMap::new();
552                for node in nodes {
553                    if let Some(neighbors) = adj.get(node.as_str()) {
554                        for (neighbor, weight) in neighbors {
555                            if let Some(&ncid) = node_to_cid.get(neighbor.as_str())
556                                && ncid != small_cid
557                            {
558                                *neighbor_edges.entry(ncid).or_default() += weight;
559                            }
560                        }
561                    }
562                }
563                // Best target
564                neighbor_edges
565                    .iter()
566                    .max_by(|a, b| a.1.total_cmp(b.1))
567                    .map(|(&best_cid, _)| (small_cid, best_cid))
568            });
569
570        match merge {
571            Some((small_cid, best_cid)) => {
572                let nodes = communities.remove(&small_cid).unwrap_or_default();
573                communities.entry(best_cid).or_default().extend(nodes);
574            }
575            None => break, // No more small communities to merge
576        }
577    }
578}
579
580/// Try to split an oversized community by running partition on its subgraph.
581fn split_community(graph: &KnowledgeGraph, nodes: &[String]) -> Vec<Vec<String>> {
582    if nodes.len() < MIN_SPLIT_SIZE {
583        return vec![nodes.to_vec()];
584    }
585
586    let node_set: HashSet<&str> = nodes.iter().map(|s| s.as_str()).collect();
587
588    // Build sub-adjacency list
589    let mut sub_adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
590    for node in nodes {
591        sub_adj.entry(node.clone()).or_default();
592    }
593    for (src, tgt, edge) in graph.edges_with_endpoints() {
594        if node_set.contains(src) && node_set.contains(tgt) {
595            sub_adj
596                .entry(src.to_string())
597                .or_default()
598                .push((tgt.to_string(), edge.weight));
599            sub_adj
600                .entry(tgt.to_string())
601                .or_default()
602                .push((src.to_string(), edge.weight));
603        }
604    }
605
606    let m = total_weight(&sub_adj);
607    if m == 0.0 {
608        return nodes.iter().map(|n| vec![n.clone()]).collect();
609    }
610
611    // Run Louvain + refinement on the subgraph
612    let mut community_of: HashMap<String, usize> = nodes
613        .iter()
614        .enumerate()
615        .map(|(i, id)| (id.clone(), i))
616        .collect();
617
618    let node_list: Vec<String> = nodes.to_vec();
619    for _ in 0..5 {
620        let changed = louvain_phase(&sub_adj, &node_list, &mut community_of, m);
621        let refined = refinement_phase(&sub_adj, &mut community_of, m);
622        if !changed && !refined {
623            break;
624        }
625    }
626
627    // Group results
628    let mut groups: HashMap<usize, Vec<String>> = HashMap::new();
629    for (node, cid) in &community_of {
630        groups.entry(*cid).or_default().push(node.clone());
631    }
632
633    let result: Vec<Vec<String>> = groups.into_values().filter(|s| !s.is_empty()).collect();
634
635    if result.len() <= 1 {
636        debug!("could not split community of {} nodes further", nodes.len());
637        return vec![nodes.to_vec()];
638    }
639
640    result
641}
642
643// ---------------------------------------------------------------------------
644// Tests
645// ---------------------------------------------------------------------------
646
#[cfg(test)]
mod tests {
    //! Unit tests for the public clustering API and for the pure helper
    //! `connected_components_within`. Graph fixtures use unit edge weights.

    use super::*;
    use graphify_core::confidence::Confidence;
    use graphify_core::graph::KnowledgeGraph;
    use graphify_core::model::{GraphEdge, GraphNode, NodeType};
    use std::collections::HashMap as StdMap;

    /// Build a minimal node whose label equals its id.
    fn make_node(id: &str) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: id.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: StdMap::new(),
        }
    }

    /// Build a `calls` edge from `src` to `tgt` with weight 1.0.
    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.into(),
            target: tgt.into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: StdMap::new(),
        }
    }

    /// Assemble a graph from node ids and `(source, target)` pairs.
    /// Panics if the graph rejects a node or an edge.
    fn build_graph(nodes: &[&str], edges: &[(&str, &str)]) -> KnowledgeGraph {
        let mut g = KnowledgeGraph::new();
        for &id in nodes {
            g.add_node(make_node(id)).unwrap();
        }
        for &(s, t) in edges {
            g.add_edge(make_edge(s, t)).unwrap();
        }
        g
    }

    // An empty graph produces no communities.
    #[test]
    fn cluster_empty_graph() {
        let g = KnowledgeGraph::new();
        let result = cluster(&g);
        assert!(result.is_empty());
    }

    // Without edges every node becomes its own singleton community.
    #[test]
    fn cluster_no_edges() {
        let g = build_graph(&["a", "b", "c"], &[]);
        let result = cluster(&g);
        assert_eq!(result.len(), 3);
        for nodes in result.values() {
            assert_eq!(nodes.len(), 1);
        }
    }

    // A triangle: every node must end up in exactly one community.
    #[test]
    fn cluster_single_clique() {
        let g = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c"), ("a", "c")]);
        let result = cluster(&g);
        let total_nodes: usize = result.values().map(|v| v.len()).sum();
        assert_eq!(total_nodes, 3);
        // NOTE(review): this bound holds for any partition of three nodes, so
        // it does not check that the clique stays together.
        assert!(result.len() <= 3);
    }

    // Two triangles joined by one bridge edge: no node is lost or duplicated.
    #[test]
    fn cluster_two_cliques() {
        let g = build_graph(
            &["a1", "a2", "a3", "b1", "b2", "b3"],
            &[
                ("a1", "a2"),
                ("a2", "a3"),
                ("a1", "a3"),
                ("b1", "b2"),
                ("b2", "b3"),
                ("b1", "b3"),
                ("a3", "b1"), // bridge
            ],
        );
        let result = cluster(&g);
        let total_nodes: usize = result.values().map(|v| v.len()).sum();
        assert_eq!(total_nodes, 6);
        // NOTE(review): the number of communities is not pinned here.
        assert!(!result.is_empty());
    }

    // A single node counts as perfectly cohesive.
    #[test]
    fn cohesion_score_single_node() {
        let g = build_graph(&["a"], &[]);
        let score = cohesion_score(&g, &["a".to_string()]);
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    // All three possible edges are present: score 1.0.
    #[test]
    fn cohesion_score_complete_graph() {
        let g = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c"), ("a", "c")]);
        let score = cohesion_score(&g, &["a".to_string(), "b".to_string(), "c".to_string()]);
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    // None of the three possible edges is present: score 0.0.
    #[test]
    fn cohesion_score_no_edges() {
        let g = build_graph(&["a", "b", "c"], &[]);
        let score = cohesion_score(&g, &["a".to_string(), "b".to_string(), "c".to_string()]);
        assert!((score - 0.0).abs() < f64::EPSILON);
    }

    // One of three possible edges: 1/3, rounded to two decimals.
    #[test]
    fn cohesion_score_partial() {
        let g = build_graph(&["a", "b", "c"], &[("a", "b")]);
        let score = cohesion_score(&g, &["a".to_string(), "b".to_string(), "c".to_string()]);
        assert!((score - 0.33).abs() < 0.01);
    }

    // `score_all` returns one score per community id.
    #[test]
    fn score_all_works() {
        let g = build_graph(&["a", "b"], &[("a", "b")]);
        let mut communities = HashMap::new();
        communities.insert(0, vec!["a".to_string(), "b".to_string()]);
        let scores = score_all(&g, &communities);
        assert_eq!(scores.len(), 1);
        assert!((scores[&0] - 1.0).abs() < f64::EPSILON);
    }

    // `cluster_graph` must also populate `graph.communities`.
    #[test]
    fn cluster_graph_mutates_communities() {
        let mut g = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c"), ("a", "c")]);
        let result = cluster_graph(&mut g);
        assert!(!result.is_empty());
        assert!(!g.communities.is_empty());
    }

    // ── Leiden-specific tests ──

    #[test]
    fn leiden_splits_disconnected_community() {
        // Two disconnected cliques — Leiden should guarantee they end up
        // in separate communities (Louvain might not).
        let g = build_graph(
            &["a1", "a2", "a3", "b1", "b2", "b3"],
            &[
                ("a1", "a2"),
                ("a2", "a3"),
                ("a1", "a3"),
                ("b1", "b2"),
                ("b2", "b3"),
                ("b1", "b3"),
                // No bridge between the two cliques
            ],
        );
        let result = cluster(&g);
        // Must have exactly 2 communities
        assert_eq!(
            result.len(),
            2,
            "disconnected cliques should form 2 communities"
        );
        // Each community should have 3 nodes
        for nodes in result.values() {
            assert_eq!(nodes.len(), 3);
        }
    }

    // Two separate edges inside one member list yield two components.
    #[test]
    fn leiden_connected_components_within() {
        let mut adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
        for id in &["a", "b", "c", "d"] {
            adj.insert(id.to_string(), Vec::new());
        }
        // a-b connected, c-d connected, no bridge
        adj.get_mut("a").unwrap().push(("b".into(), 1.0));
        adj.get_mut("b").unwrap().push(("a".into(), 1.0));
        adj.get_mut("c").unwrap().push(("d".into(), 1.0));
        adj.get_mut("d").unwrap().push(("c".into(), 1.0));

        let members: Vec<String> = vec!["a", "b", "c", "d"]
            .into_iter()
            .map(String::from)
            .collect();
        let components = connected_components_within(&adj, &members);
        assert_eq!(components.len(), 2);
    }

    // A path a-b-c is a single component.
    #[test]
    fn leiden_single_component() {
        let mut adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
        for id in &["a", "b", "c"] {
            adj.insert(id.to_string(), Vec::new());
        }
        adj.get_mut("a").unwrap().push(("b".into(), 1.0));
        adj.get_mut("b").unwrap().push(("a".into(), 1.0));
        adj.get_mut("b").unwrap().push(("c".into(), 1.0));
        adj.get_mut("c").unwrap().push(("b".into(), 1.0));

        let members: Vec<String> = vec!["a", "b", "c"].into_iter().map(String::from).collect();
        let components = connected_components_within(&adj, &members);
        assert_eq!(components.len(), 1);
    }
}
850}