Skip to main content

graphify_cluster/
lib.rs

//! Community detection (Leiden algorithm) for graphify.
//!
//! Partitions the knowledge graph into communities using the Leiden algorithm,
//! which improves upon Louvain by adding a refinement phase that guarantees
//! internally connected communities. When refinement yields no improvement,
//! the result is the plain greedy-modularity (Louvain) partition.
7
use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};

use tracing::debug;

use graphify_core::graph::KnowledgeGraph;
use graphify_core::model::CommunityInfo;
14
/// Maximum fraction of total nodes a single community may contain before
/// `cluster` hands it to `split_community` for a further split.
const MAX_COMMUNITY_FRACTION: f64 = 0.25;

/// Minimum community size below which we never attempt a split.
///
/// Acts both as the floor of the size threshold computed in `cluster` and as
/// the early-return guard in `split_community`.
const MIN_SPLIT_SIZE: usize = 10;
21
// ---------------------------------------------------------------------------
// Public API
// ---------------------------------------------------------------------------
25
26/// Run community detection on the graph. Returns `{community_id: [node_ids]}`.
27///
28/// Uses the Leiden algorithm: greedy modularity optimization (Louvain phase)
29/// followed by a refinement phase that ensures communities are internally
30/// well-connected.
31pub fn cluster(graph: &KnowledgeGraph) -> HashMap<usize, Vec<String>> {
32    let node_count = graph.node_count();
33    if node_count == 0 {
34        return HashMap::new();
35    }
36
37    // If no edges, each node is its own community
38    if graph.edge_count() == 0 {
39        return graph
40            .node_ids()
41            .into_iter()
42            .enumerate()
43            .map(|(i, id)| (i, vec![id]))
44            .collect();
45    }
46
47    let partition = leiden_partition(graph);
48
49    // Group by community
50    let mut communities: HashMap<usize, Vec<String>> = HashMap::new();
51    for (node_id, cid) in &partition {
52        communities.entry(*cid).or_default().push(node_id.clone());
53    }
54
55    // Split oversized communities
56    let max_size = std::cmp::max(
57        MIN_SPLIT_SIZE,
58        (node_count as f64 * MAX_COMMUNITY_FRACTION) as usize,
59    );
60    let mut final_communities: Vec<Vec<String>> = Vec::new();
61    for nodes in communities.values() {
62        if nodes.len() > max_size {
63            final_communities.extend(split_community(graph, nodes));
64        } else {
65            final_communities.push(nodes.clone());
66        }
67    }
68
69    // Re-index by size descending
70    final_communities.sort_by_key(|b| std::cmp::Reverse(b.len()));
71    final_communities
72        .into_iter()
73        .enumerate()
74        .map(|(i, mut nodes)| {
75            nodes.sort();
76            (i, nodes)
77        })
78        .collect()
79}
80
81/// Run community detection and mutate graph in-place, storing community info.
82pub fn cluster_graph(graph: &mut KnowledgeGraph) -> HashMap<usize, Vec<String>> {
83    let communities = cluster(graph);
84
85    // Build CommunityInfo entries
86    let scores = score_all(graph, &communities);
87    let mut infos: Vec<CommunityInfo> = communities
88        .iter()
89        .map(|(&cid, nodes)| CommunityInfo {
90            id: cid,
91            nodes: nodes.clone(),
92            cohesion: scores.get(&cid).copied().unwrap_or(0.0),
93            label: None,
94        })
95        .collect();
96    infos.sort_by_key(|c| c.id);
97    graph.communities = infos;
98
99    communities
100}
101
102/// Cohesion score: ratio of actual intra-community edges to maximum possible.
103///
104/// Returns a value in `[0.0, 1.0]` rounded to two decimal places.
105pub fn cohesion_score(graph: &KnowledgeGraph, community_nodes: &[String]) -> f64 {
106    let n = community_nodes.len();
107    if n <= 1 {
108        return 1.0;
109    }
110
111    let node_set: HashSet<&str> = community_nodes.iter().map(|s| s.as_str()).collect();
112    let mut actual_edges = 0usize;
113
114    // Count edges where both endpoints are in the community
115    for node_id in community_nodes {
116        let neighbors = graph.get_neighbors(node_id);
117        for neighbor in &neighbors {
118            if node_set.contains(neighbor.id.as_str()) {
119                actual_edges += 1;
120            }
121        }
122    }
123    // Each edge is counted twice (once from each endpoint)
124    actual_edges /= 2;
125
126    let possible = n * (n - 1) / 2;
127    if possible == 0 {
128        return 0.0;
129    }
130    ((actual_edges as f64 / possible as f64) * 100.0).round() / 100.0
131}
132
133/// Compute cohesion scores for all communities.
134pub fn score_all(
135    graph: &KnowledgeGraph,
136    communities: &HashMap<usize, Vec<String>>,
137) -> HashMap<usize, f64> {
138    communities
139        .iter()
140        .map(|(&cid, nodes)| (cid, cohesion_score(graph, nodes)))
141        .collect()
142}
143
// ---------------------------------------------------------------------------
// Adjacency helpers
// ---------------------------------------------------------------------------
147
148/// Build an adjacency list from the KnowledgeGraph for efficient lookups.
149fn build_adjacency(graph: &KnowledgeGraph) -> HashMap<String, Vec<(String, f64)>> {
150    let mut adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
151    for id in graph.node_ids() {
152        adj.entry(id).or_default();
153    }
154    for (src, tgt, edge) in graph.edges_with_endpoints() {
155        adj.entry(src.to_string())
156            .or_default()
157            .push((tgt.to_string(), edge.weight));
158        adj.entry(tgt.to_string())
159            .or_default()
160            .push((src.to_string(), edge.weight));
161    }
162    adj
163}
164
/// Total edge weight of the graph described by `adj`.
///
/// The adjacency list stores each edge under both endpoints, so the sum over
/// all entries is halved.
fn total_weight(adj: &HashMap<String, Vec<(String, f64)>>) -> f64 {
    let doubled = adj
        .values()
        .flatten()
        .fold(0.0, |acc, (_, weight)| acc + weight);
    doubled / 2.0
}
175
/// Strength of `node`: the summed weight of all adjacency entries stored for
/// it. Nodes absent from `adj` have strength `0.0`.
fn node_strength(adj: &HashMap<String, Vec<(String, f64)>>, node: &str) -> f64 {
    match adj.get(node) {
        Some(incident) => incident.iter().map(|(_, weight)| weight).sum(),
        None => 0.0,
    }
}
182
/// Summed weight of the adjacency entries of `node` whose other endpoint is a
/// member of `community`. Nodes absent from `adj` contribute `0.0`.
fn edges_to_community(
    adj: &HashMap<String, Vec<(String, f64)>>,
    node: &str,
    community: &HashSet<&str>,
) -> f64 {
    match adj.get(node) {
        None => 0.0,
        Some(incident) => incident
            .iter()
            .filter_map(|(other, weight)| {
                if community.contains(other.as_str()) {
                    Some(weight)
                } else {
                    None
                }
            })
            .sum(),
    }
}
199
200/// Sum of strengths of all nodes in a community.
201fn community_strength(adj: &HashMap<String, Vec<(String, f64)>>, members: &HashSet<&str>) -> f64 {
202    members.iter().map(|n| node_strength(adj, n)).sum()
203}
204
// ---------------------------------------------------------------------------
// Leiden algorithm
// ---------------------------------------------------------------------------
208
209/// Leiden algorithm: Louvain phase + refinement phase, iterated until stable.
210///
211/// Reference: Traag, Waltman & van Eck (2019) "From Louvain to Leiden:
212/// guaranteeing well-connected communities"
213fn leiden_partition(graph: &KnowledgeGraph) -> HashMap<String, usize> {
214    let adj = build_adjacency(graph);
215    let m = total_weight(&adj);
216    if m == 0.0 {
217        return graph
218            .node_ids()
219            .into_iter()
220            .enumerate()
221            .map(|(i, id)| (id, i))
222            .collect();
223    }
224
225    let node_ids = graph.node_ids();
226
227    // Initialize: each node in its own community
228    let mut community_of: HashMap<String, usize> = node_ids
229        .iter()
230        .enumerate()
231        .map(|(i, id)| (id.clone(), i))
232        .collect();
233
234    let max_outer_iterations = 10;
235    for _outer in 0..max_outer_iterations {
236        // ── Phase 1: Louvain (greedy modularity move) ──
237        let changed = louvain_phase(&adj, &node_ids, &mut community_of, m);
238
239        // ── Phase 2: Refinement (ensure well-connected communities) ──
240        let refined = refinement_phase(&adj, &mut community_of, m);
241
242        if !changed && !refined {
243            break;
244        }
245    }
246
247    // Compact community IDs
248    compact_ids(&mut community_of);
249    community_of
250}
251
/// Phase 1: Greedy modularity optimization (Louvain move phase).
///
/// Sweeps over `node_ids` (at most 50 times) and moves each node to the
/// neighbouring community with the largest strictly positive modularity gain.
/// A sweep without any move ends the phase early.
///
/// * `adj` - undirected adjacency list; each edge is stored under both endpoints.
/// * `node_ids` - nodes to visit, in visiting order.
/// * `community_of` - current assignment, updated in place. Must contain every
///   id in `node_ids` and every neighbour listed for them in `adj` (indexing
///   panics otherwise).
/// * `m` - total edge weight of the graph; callers pass a non-zero value.
///
/// Returns `true` if at least one node changed community.
///
/// Node strengths are computed once and per-community strength totals are
/// updated incrementally, so evaluating a node costs time proportional to its
/// degree rather than to the size of the communities involved. Candidates are
/// examined in ascending community-id order, so equal gains resolve to the
/// lowest id instead of depending on hash iteration order.
fn louvain_phase(
    adj: &HashMap<String, Vec<(String, f64)>>,
    node_ids: &[String],
    community_of: &mut HashMap<String, usize>,
    m: f64,
) -> bool {
    const MAX_SWEEPS: usize = 50;

    // Strength (summed incident weight) of every node, computed once.
    let strength: HashMap<&str, f64> = adj
        .iter()
        .map(|(id, incident)| {
            let total: f64 = incident.iter().map(|(_, w)| *w).sum();
            (id.as_str(), total)
        })
        .collect();
    let strength_of = |id: &str| strength.get(id).copied().unwrap_or(0.0);

    // Total strength per community, kept in sync with every move below.
    let mut sigma: HashMap<usize, f64> = HashMap::new();
    for (node, &cid) in community_of.iter() {
        *sigma.entry(cid).or_default() += strength_of(node);
    }

    let mut any_changed = false;

    for _sweep in 0..MAX_SWEEPS {
        let mut improved = false;

        for node in node_ids {
            let current_community = community_of[node];
            let ki = strength_of(node);

            // Weight from this node to each adjacent community. Self-loops are
            // skipped: they never change when the node moves.
            let mut weight_to: BTreeMap<usize, f64> = BTreeMap::new();
            if let Some(incident) = adj.get(node.as_str()) {
                for (nbr, w) in incident {
                    if nbr != node {
                        *weight_to.entry(community_of[nbr]).or_default() += w;
                    }
                }
            }

            // Both "current" terms describe the community without this node.
            let ki_in_current = weight_to.get(&current_community).copied().unwrap_or(0.0);
            let sigma_current = sigma.get(&current_community).copied().unwrap_or(0.0) - ki;

            let mut best_community = current_community;
            let mut best_gain = 0.0f64;

            for (&target_community, &ki_in_target) in &weight_to {
                if target_community == current_community {
                    continue;
                }
                let sigma_target = sigma.get(&target_community).copied().unwrap_or(0.0);

                // Modularity change of moving the node from its current
                // community (minus itself) into the target community.
                let gain = (ki_in_target - ki_in_current) / m
                    - ki * (sigma_target - sigma_current) / (2.0 * m * m);

                if gain > best_gain {
                    best_gain = gain;
                    best_community = target_community;
                }
            }

            if best_community != current_community {
                *sigma.entry(current_community).or_default() -= ki;
                *sigma.entry(best_community).or_default() += ki;
                if let Some(slot) = community_of.get_mut(node.as_str()) {
                    *slot = best_community;
                }
                improved = true;
                any_changed = true;
            }
        }

        if !improved {
            break;
        }
    }

    any_changed
}
348
349/// Phase 2: Leiden refinement.
350///
351/// For each community, find its connected components. If a community is
352/// internally disconnected, split it — move each disconnected sub-component
353/// to whichever neighboring community maximizes modularity gain (or keep it
354/// as a new community). This guarantees all resulting communities are
355/// internally connected.
356fn refinement_phase(
357    adj: &HashMap<String, Vec<(String, f64)>>,
358    community_of: &mut HashMap<String, usize>,
359    m: f64,
360) -> bool {
361    // Group nodes by community
362    let mut community_members: HashMap<usize, Vec<String>> = HashMap::new();
363    for (node, &cid) in community_of.iter() {
364        community_members.entry(cid).or_default().push(node.clone());
365    }
366
367    let mut any_refined = false;
368    let mut next_cid = community_members.keys().copied().max().unwrap_or(0) + 1;
369
370    let community_ids: Vec<usize> = community_members.keys().copied().collect();
371    for cid in community_ids {
372        let members = match community_members.get(&cid) {
373            Some(m) if m.len() > 1 => m.clone(),
374            _ => continue,
375        };
376
377        // Find connected components within this community
378        let components = connected_components_within(adj, &members);
379        if components.len() <= 1 {
380            continue; // Already well-connected
381        }
382
383        debug!(
384            "Leiden refinement: community {} has {} disconnected components, splitting",
385            cid,
386            components.len()
387        );
388
389        // Keep the largest component in the original community,
390        // assign each smaller component to the best neighboring community
391        // or a new community.
392        let mut sorted_components = components;
393        sorted_components.sort_by_key(|c| std::cmp::Reverse(c.len()));
394
395        // Largest component stays
396        for component in sorted_components.iter().skip(1) {
397            // For this sub-component, find the best neighboring community
398            let mut neighbor_cid_edges: HashMap<usize, f64> = HashMap::new();
399            for node in component {
400                if let Some(neighbors) = adj.get(node.as_str()) {
401                    for (nbr, w) in neighbors {
402                        let nbr_cid = community_of[nbr];
403                        if nbr_cid != cid {
404                            *neighbor_cid_edges.entry(nbr_cid).or_default() += w;
405                        }
406                    }
407                }
408            }
409
410            // Pick the neighbor community with the strongest connection,
411            // or create a new community if no neighbor exists
412            let target_cid = if let Some((&best_cid, _)) = neighbor_cid_edges
413                .iter()
414                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
415            {
416                // Only merge if modularity gain is positive
417                let _component_set: HashSet<&str> = component.iter().map(|s| s.as_str()).collect();
418                let target_members: HashSet<&str> = community_members
419                    .get(&best_cid)
420                    .map(|s| s.iter().map(|x| x.as_str()).collect())
421                    .unwrap_or_default();
422
423                let ki_sum: f64 = component.iter().map(|n| node_strength(adj, n)).sum();
424                let ki_in = component
425                    .iter()
426                    .map(|n| edges_to_community(adj, n, &target_members))
427                    .sum::<f64>();
428                let sigma_t = community_strength(adj, &target_members);
429
430                let gain = ki_in / m - ki_sum * sigma_t / (2.0 * m * m);
431                if gain > 0.0 {
432                    best_cid
433                } else {
434                    let new_cid = next_cid;
435                    next_cid += 1;
436                    new_cid
437                }
438            } else {
439                let new_cid = next_cid;
440                next_cid += 1;
441                new_cid
442            };
443
444            // Move all nodes in this component
445            for node in component {
446                community_of.insert(node.clone(), target_cid);
447                community_members
448                    .entry(target_cid)
449                    .or_default()
450                    .push(node.clone());
451            }
452            any_refined = true;
453        }
454
455        // Update original community to only keep largest component
456        if any_refined {
457            community_members.insert(cid, sorted_components.into_iter().next().unwrap());
458        }
459    }
460
461    any_refined
462}
463
/// Find connected components within a subset of nodes using BFS.
///
/// Only edges whose both endpoints are in `members` are followed. Components
/// are returned in the order their first node appears in `members`; within a
/// component, nodes are listed in breadth-first visiting order.
fn connected_components_within(
    adj: &HashMap<String, Vec<(String, f64)>>,
    members: &[String],
) -> Vec<Vec<String>> {
    let allowed: HashSet<&str> = members.iter().map(|s| s.as_str()).collect();
    let mut seen: HashSet<&str> = HashSet::new();
    let mut result: Vec<Vec<String>> = Vec::new();

    for start in members.iter().map(|s| s.as_str()) {
        // `insert` returns false when the node was already reached earlier.
        if !seen.insert(start) {
            continue;
        }

        let mut group: Vec<String> = Vec::new();
        let mut frontier: VecDeque<&str> = VecDeque::new();
        frontier.push_back(start);

        while let Some(current) = frontier.pop_front() {
            group.push(current.to_string());

            for (nbr, _) in adj.get(current).into_iter().flatten() {
                let nbr = nbr.as_str();
                if allowed.contains(nbr) && seen.insert(nbr) {
                    frontier.push_back(nbr);
                }
            }
        }

        result.push(group);
    }

    result
}
500
/// Compact community IDs to be contiguous starting from 0.
///
/// Each id is replaced by its rank among the distinct ids in use, so the
/// relative order of ids is preserved.
fn compact_ids(community_of: &mut HashMap<String, usize>) {
    // Sorted list of the distinct ids; an id's position is its new value.
    let mut ranked: Vec<usize> = community_of.values().copied().collect();
    ranked.sort_unstable();
    ranked.dedup();

    for cid in community_of.values_mut() {
        // Every value was collected above, so the search always succeeds.
        if let Ok(rank) = ranked.binary_search(cid) {
            *cid = rank;
        }
    }
}
519
520/// Try to split an oversized community by running partition on its subgraph.
521fn split_community(graph: &KnowledgeGraph, nodes: &[String]) -> Vec<Vec<String>> {
522    if nodes.len() < MIN_SPLIT_SIZE {
523        return vec![nodes.to_vec()];
524    }
525
526    let node_set: HashSet<&str> = nodes.iter().map(|s| s.as_str()).collect();
527
528    // Build sub-adjacency list
529    let mut sub_adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
530    for node in nodes {
531        sub_adj.entry(node.clone()).or_default();
532    }
533    for (src, tgt, edge) in graph.edges_with_endpoints() {
534        if node_set.contains(src) && node_set.contains(tgt) {
535            sub_adj
536                .entry(src.to_string())
537                .or_default()
538                .push((tgt.to_string(), edge.weight));
539            sub_adj
540                .entry(tgt.to_string())
541                .or_default()
542                .push((src.to_string(), edge.weight));
543        }
544    }
545
546    let m = total_weight(&sub_adj);
547    if m == 0.0 {
548        return nodes.iter().map(|n| vec![n.clone()]).collect();
549    }
550
551    // Run Louvain + refinement on the subgraph
552    let mut community_of: HashMap<String, usize> = nodes
553        .iter()
554        .enumerate()
555        .map(|(i, id)| (id.clone(), i))
556        .collect();
557
558    let node_list: Vec<String> = nodes.to_vec();
559    for _ in 0..5 {
560        let changed = louvain_phase(&sub_adj, &node_list, &mut community_of, m);
561        let refined = refinement_phase(&sub_adj, &mut community_of, m);
562        if !changed && !refined {
563            break;
564        }
565    }
566
567    // Group results
568    let mut groups: HashMap<usize, Vec<String>> = HashMap::new();
569    for (node, cid) in &community_of {
570        groups.entry(*cid).or_default().push(node.clone());
571    }
572
573    let result: Vec<Vec<String>> = groups.into_values().filter(|s| !s.is_empty()).collect();
574
575    if result.len() <= 1 {
576        debug!("could not split community of {} nodes further", nodes.len());
577        return vec![nodes.to_vec()];
578    }
579
580    result
581}
582
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
586
#[cfg(test)]
mod tests {
    use super::*;
    use graphify_core::confidence::Confidence;
    use graphify_core::graph::KnowledgeGraph;
    use graphify_core::model::{GraphEdge, GraphNode, NodeType};
    use std::collections::HashMap as StdMap;

    /// Adjacency-list shape consumed by the Leiden helpers.
    type Adjacency = HashMap<String, Vec<(String, f64)>>;

    /// Two triangles, `a*` and `b*`, without any edge between them.
    const TRIANGLE_NODES: [&str; 6] = ["a1", "a2", "a3", "b1", "b2", "b3"];
    const TRIANGLE_EDGES: [(&str, &str); 6] = [
        ("a1", "a2"),
        ("a2", "a3"),
        ("a1", "a3"),
        ("b1", "b2"),
        ("b2", "b3"),
        ("b1", "b3"),
    ];

    // ── Fixtures ──

    fn make_node(id: &str) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: id.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: StdMap::new(),
        }
    }

    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.into(),
            target: tgt.into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: StdMap::new(),
        }
    }

    fn build_graph(nodes: &[&str], edges: &[(&str, &str)]) -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();
        for id in nodes.iter().copied() {
            graph.add_node(make_node(id)).unwrap();
        }
        for (src, tgt) in edges.iter().copied() {
            graph.add_edge(make_edge(src, tgt)).unwrap();
        }
        graph
    }

    /// Triangle over `a`, `b`, `c`.
    fn triangle() -> KnowledgeGraph {
        build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c"), ("a", "c")])
    }

    /// Owned copies of the given ids.
    fn owned(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|id| id.to_string()).collect()
    }

    /// Undirected adjacency list with unit weights.
    fn unit_adjacency(nodes: &[&str], edges: &[(&str, &str)]) -> Adjacency {
        let mut adj: Adjacency = HashMap::new();
        for id in nodes {
            adj.insert(id.to_string(), Vec::new());
        }
        for &(a, b) in edges {
            adj.get_mut(a).unwrap().push((b.into(), 1.0));
            adj.get_mut(b).unwrap().push((a.into(), 1.0));
        }
        adj
    }

    /// Total number of nodes across all communities.
    fn covered(result: &HashMap<usize, Vec<String>>) -> usize {
        result.values().map(|members| members.len()).sum()
    }

    // ── cluster / cluster_graph ──

    #[test]
    fn cluster_empty_graph() {
        let result = cluster(&KnowledgeGraph::new());
        assert!(result.is_empty());
    }

    #[test]
    fn cluster_no_edges() {
        let graph = build_graph(&["a", "b", "c"], &[]);
        let result = cluster(&graph);
        assert_eq!(result.len(), 3);
        for members in result.values() {
            assert_eq!(members.len(), 1);
        }
    }

    #[test]
    fn cluster_single_clique() {
        let result = cluster(&triangle());
        assert_eq!(covered(&result), 3);
        assert!(result.len() <= 3);
    }

    #[test]
    fn cluster_two_cliques() {
        let mut edges = TRIANGLE_EDGES.to_vec();
        edges.push(("a3", "b1")); // bridge
        let graph = build_graph(&TRIANGLE_NODES, &edges);

        let result = cluster(&graph);
        assert_eq!(covered(&result), 6);
        assert!(!result.is_empty());
    }

    #[test]
    fn cluster_graph_mutates_communities() {
        let mut graph = triangle();
        let result = cluster_graph(&mut graph);
        assert!(!result.is_empty());
        assert!(!graph.communities.is_empty());
    }

    // ── cohesion ──

    #[test]
    fn cohesion_score_single_node() {
        let graph = build_graph(&["a"], &[]);
        let score = cohesion_score(&graph, &owned(&["a"]));
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cohesion_score_complete_graph() {
        let score = cohesion_score(&triangle(), &owned(&["a", "b", "c"]));
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cohesion_score_no_edges() {
        let graph = build_graph(&["a", "b", "c"], &[]);
        let score = cohesion_score(&graph, &owned(&["a", "b", "c"]));
        assert!((score - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cohesion_score_partial() {
        let graph = build_graph(&["a", "b", "c"], &[("a", "b")]);
        let score = cohesion_score(&graph, &owned(&["a", "b", "c"]));
        assert!((score - 0.33).abs() < 0.01);
    }

    #[test]
    fn score_all_works() {
        let graph = build_graph(&["a", "b"], &[("a", "b")]);
        let mut communities = HashMap::new();
        communities.insert(0, owned(&["a", "b"]));

        let scores = score_all(&graph, &communities);
        assert_eq!(scores.len(), 1);
        assert!((scores[&0] - 1.0).abs() < f64::EPSILON);
    }

    // ── Leiden-specific tests ──

    #[test]
    fn leiden_splits_disconnected_community() {
        // Two disconnected cliques — Leiden should guarantee they end up
        // in separate communities (Louvain might not).
        let graph = build_graph(&TRIANGLE_NODES, &TRIANGLE_EDGES);
        let result = cluster(&graph);

        assert_eq!(
            result.len(),
            2,
            "disconnected cliques should form 2 communities"
        );
        for members in result.values() {
            assert_eq!(members.len(), 3);
        }
    }

    #[test]
    fn leiden_connected_components_within() {
        // a-b connected, c-d connected, no bridge
        let adj = unit_adjacency(&["a", "b", "c", "d"], &[("a", "b"), ("c", "d")]);
        let components = connected_components_within(&adj, &owned(&["a", "b", "c", "d"]));
        assert_eq!(components.len(), 2);
    }

    #[test]
    fn leiden_single_component() {
        // a-b-c chain
        let adj = unit_adjacency(&["a", "b", "c"], &[("a", "b"), ("b", "c")]);
        let components = connected_components_within(&adj, &owned(&["a", "b", "c"]));
        assert_eq!(components.len(), 1);
    }
}