1use std::collections::{HashMap, HashSet, VecDeque};
9
10use tracing::debug;
11
12use graphify_core::graph::KnowledgeGraph;
13use graphify_core::model::CommunityInfo;
14
/// Largest share of all nodes a single community may hold before `cluster` tries to split
/// it; the effective size cap is never lower than `MIN_SPLIT_SIZE` nodes.
const MAX_COMMUNITY_FRACTION: f64 = 0.25;

/// Split threshold: `cluster` uses it as the lower bound of the size cap, and
/// `split_community` returns communities with fewer nodes than this unchanged.
const MIN_SPLIT_SIZE: usize = 10;

/// Communities with fewer nodes than this are merged into the neighbouring community they
/// share the most edge weight with (see `merge_small_communities`).
const MIN_COMMUNITY_SIZE: usize = 5;

/// Resolution parameter (γ) scaling the degree-penalty term of the modularity gain in
/// `louvain_phase`; smaller values make moving into a neighbouring community easier.
const RESOLUTION: f64 = 0.3;
29
30pub fn cluster(graph: &KnowledgeGraph) -> HashMap<usize, Vec<String>> {
40 let node_count = graph.node_count();
41 if node_count == 0 {
42 return HashMap::new();
43 }
44
45 if graph.edge_count() == 0 {
47 return graph
48 .node_ids()
49 .into_iter()
50 .enumerate()
51 .map(|(i, id)| (i, vec![id]))
52 .collect();
53 }
54
55 let partition = leiden_partition(graph);
56
57 let mut communities: HashMap<usize, Vec<String>> = HashMap::new();
59 for (node_id, cid) in &partition {
60 communities.entry(*cid).or_default().push(node_id.clone());
61 }
62
63 let adj = build_adjacency(graph);
65 merge_small_communities(&mut communities, &adj);
66
67 let max_size = std::cmp::max(
69 MIN_SPLIT_SIZE,
70 (node_count as f64 * MAX_COMMUNITY_FRACTION) as usize,
71 );
72 let mut final_communities: Vec<Vec<String>> = Vec::new();
73 for nodes in communities.values() {
74 if nodes.len() > max_size {
75 final_communities.extend(split_community(graph, nodes));
76 } else {
77 final_communities.push(nodes.clone());
78 }
79 }
80
81 final_communities.sort_by_key(|b| std::cmp::Reverse(b.len()));
83 final_communities
84 .into_iter()
85 .enumerate()
86 .map(|(i, mut nodes)| {
87 nodes.sort();
88 (i, nodes)
89 })
90 .collect()
91}
92
93pub fn cluster_graph(graph: &mut KnowledgeGraph) -> HashMap<usize, Vec<String>> {
95 let communities = cluster(graph);
96
97 let scores = score_all(graph, &communities);
99 let mut infos: Vec<CommunityInfo> = communities
100 .iter()
101 .map(|(&cid, nodes)| CommunityInfo {
102 id: cid,
103 nodes: nodes.clone(),
104 cohesion: scores.get(&cid).copied().unwrap_or(0.0),
105 label: None,
106 })
107 .collect();
108 infos.sort_by_key(|c| c.id);
109 graph.communities = infos;
110
111 communities
112}
113
114pub fn cohesion_score(graph: &KnowledgeGraph, community_nodes: &[String]) -> f64 {
118 let n = community_nodes.len();
119 if n <= 1 {
120 return 1.0;
121 }
122
123 let node_set: HashSet<&str> = community_nodes.iter().map(|s| s.as_str()).collect();
124 let mut actual_edges = 0usize;
125
126 for node_id in community_nodes {
128 let neighbors = graph.get_neighbors(node_id);
129 for neighbor in &neighbors {
130 if node_set.contains(neighbor.id.as_str()) {
131 actual_edges += 1;
132 }
133 }
134 }
135 actual_edges /= 2;
137
138 let possible = n * (n - 1) / 2;
139 if possible == 0 {
140 return 0.0;
141 }
142 ((actual_edges as f64 / possible as f64) * 100.0).round() / 100.0
143}
144
145pub fn score_all(
147 graph: &KnowledgeGraph,
148 communities: &HashMap<usize, Vec<String>>,
149) -> HashMap<usize, f64> {
150 communities
151 .iter()
152 .map(|(&cid, nodes)| (cid, cohesion_score(graph, nodes)))
153 .collect()
154}
155
156fn build_adjacency(graph: &KnowledgeGraph) -> HashMap<String, Vec<(String, f64)>> {
162 let mut adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
163 for id in graph.node_ids() {
164 adj.entry(id).or_default();
165 }
166 for (src, tgt, edge) in graph.edges_with_endpoints() {
167 adj.entry(src.to_string())
168 .or_default()
169 .push((tgt.to_string(), edge.weight));
170 adj.entry(tgt.to_string())
171 .or_default()
172 .push((src.to_string(), edge.weight));
173 }
174 adj
175}
176
/// Total edge weight `m` of an undirected adjacency map.
///
/// Every undirected edge appears in the lists of both endpoints, so the raw sum of all
/// adjacency weights is halved.
fn total_weight(adj: &HashMap<String, Vec<(String, f64)>>) -> f64 {
    let doubled = adj
        .values()
        .flat_map(|neighbors| neighbors.iter())
        .fold(0.0_f64, |acc, (_, weight)| acc + weight);
    doubled / 2.0
}
187
/// Weighted degree `k_i` of `node`: the sum of the weights in its adjacency list, or `0.0`
/// when the node has no entry in `adj`.
fn node_strength(adj: &HashMap<String, Vec<(String, f64)>>, node: &str) -> f64 {
    match adj.get(node) {
        Some(neighbors) => neighbors.iter().map(|&(_, weight)| weight).sum(),
        None => 0.0,
    }
}
194
/// Total weight of the adjacency entries of `node` whose other endpoint belongs to
/// `community` (`k_{i,C}` in modularity terms). Returns `0.0` when `node` has no entry in
/// `adj`.
fn edges_to_community(
    adj: &HashMap<String, Vec<(String, f64)>>,
    node: &str,
    community: &HashSet<&str>,
) -> f64 {
    let Some(neighbors) = adj.get(node) else {
        return 0.0;
    };
    neighbors
        .iter()
        .filter_map(|(other, weight)| {
            if community.contains(other.as_str()) {
                Some(*weight)
            } else {
                None
            }
        })
        .sum()
}
211
212fn community_strength(adj: &HashMap<String, Vec<(String, f64)>>, members: &HashSet<&str>) -> f64 {
214 members.iter().map(|n| node_strength(adj, n)).sum()
215}
216
217fn leiden_partition(graph: &KnowledgeGraph) -> HashMap<String, usize> {
226 let adj = build_adjacency(graph);
227 let m = total_weight(&adj);
228 if m == 0.0 {
229 return graph
230 .node_ids()
231 .into_iter()
232 .enumerate()
233 .map(|(i, id)| (id, i))
234 .collect();
235 }
236
237 let node_ids = graph.node_ids();
238
239 let mut community_of: HashMap<String, usize> = node_ids
241 .iter()
242 .enumerate()
243 .map(|(i, id)| (id.clone(), i))
244 .collect();
245
246 let max_outer_iterations = 10;
247 for _outer in 0..max_outer_iterations {
248 let changed = louvain_phase(&adj, &node_ids, &mut community_of, m);
250
251 let refined = refinement_phase(&adj, &mut community_of, m);
253
254 if !changed && !refined {
255 break;
256 }
257 }
258
259 compact_ids(&mut community_of);
261 community_of
262}
263
264fn louvain_phase(
269 adj: &HashMap<String, Vec<(String, f64)>>,
270 node_ids: &[String],
271 community_of: &mut HashMap<String, usize>,
272 m: f64,
273) -> bool {
274 let mut community_members: HashMap<usize, HashSet<String>> = HashMap::new();
275 for (node, &cid) in community_of.iter() {
276 community_members
277 .entry(cid)
278 .or_default()
279 .insert(node.clone());
280 }
281
282 let ki_cache: HashMap<&str, f64> = adj
284 .keys()
285 .map(|n| (n.as_str(), node_strength(adj, n)))
286 .collect();
287
288 let mut sigma_c: HashMap<usize, f64> = HashMap::new();
290 for (&cid, members) in &community_members {
291 let sum: f64 = members
292 .iter()
293 .map(|n| ki_cache.get(n.as_str()).copied().unwrap_or(0.0))
294 .sum();
295 sigma_c.insert(cid, sum);
296 }
297
298 let max_iterations = 50;
299 let mut any_changed = false;
300
301 for _iteration in 0..max_iterations {
302 let mut improved = false;
303
304 for node in node_ids {
305 let current_community = community_of[node];
306 let ki = ki_cache.get(node.as_str()).copied().unwrap_or(0.0);
307
308 let mut ki_to: HashMap<usize, f64> = HashMap::new();
310 if let Some(neighbors) = adj.get(node.as_str()) {
311 for (nbr, w) in neighbors {
312 let nbr_cid = community_of[nbr];
313 *ki_to.entry(nbr_cid).or_default() += w;
314 }
315 }
316
317 let mut best_community = current_community;
318 let mut best_gain = 0.0f64;
319
320 let ki_in_current = ki_to.get(¤t_community).copied().unwrap_or(0.0);
322 let sigma_current = sigma_c.get(¤t_community).copied().unwrap_or(0.0) - ki;
323
324 for (&target_community, &ki_in_target) in &ki_to {
325 if target_community == current_community {
326 continue;
327 }
328
329 let sigma_target = sigma_c.get(&target_community).copied().unwrap_or(0.0);
330
331 let gain = (ki_in_target - ki_in_current) / m
332 - RESOLUTION * ki * (sigma_target - sigma_current) / (2.0 * m * m);
333
334 if gain > best_gain {
335 best_gain = gain;
336 best_community = target_community;
337 }
338 }
339
340 if best_community != current_community {
341 community_members
343 .get_mut(¤t_community)
344 .unwrap()
345 .remove(node);
346 community_members
347 .entry(best_community)
348 .or_default()
349 .insert(node.clone());
350
351 *sigma_c.entry(current_community).or_default() -= ki;
353 *sigma_c.entry(best_community).or_default() += ki;
354
355 community_of.insert(node.clone(), best_community);
356 improved = true;
357 any_changed = true;
358 }
359 }
360
361 if !improved {
362 break;
363 }
364 }
365
366 any_changed
367}
368
369fn refinement_phase(
377 adj: &HashMap<String, Vec<(String, f64)>>,
378 community_of: &mut HashMap<String, usize>,
379 m: f64,
380) -> bool {
381 let mut community_members: HashMap<usize, Vec<String>> = HashMap::new();
383 for (node, &cid) in community_of.iter() {
384 community_members.entry(cid).or_default().push(node.clone());
385 }
386
387 let mut any_refined = false;
388 let mut next_cid = community_members.keys().copied().max().unwrap_or(0) + 1;
389
390 let community_ids: Vec<usize> = community_members.keys().copied().collect();
391 for cid in community_ids {
392 let members = match community_members.get(&cid) {
393 Some(m) if m.len() > 1 => m.clone(),
394 _ => continue,
395 };
396
397 let components = connected_components_within(adj, &members);
399 if components.len() <= 1 {
400 continue; }
402
403 debug!(
404 "Leiden refinement: community {} has {} disconnected components, splitting",
405 cid,
406 components.len()
407 );
408
409 let mut sorted_components = components;
413 sorted_components.sort_by_key(|c| std::cmp::Reverse(c.len()));
414
415 for component in sorted_components.iter().skip(1) {
417 let mut neighbor_cid_edges: HashMap<usize, f64> = HashMap::new();
419 for node in component {
420 if let Some(neighbors) = adj.get(node.as_str()) {
421 for (nbr, w) in neighbors {
422 let nbr_cid = community_of[nbr];
423 if nbr_cid != cid {
424 *neighbor_cid_edges.entry(nbr_cid).or_default() += w;
425 }
426 }
427 }
428 }
429
430 let target_cid = if let Some((&best_cid, _)) = neighbor_cid_edges
433 .iter()
434 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
435 {
436 let _component_set: HashSet<&str> = component.iter().map(|s| s.as_str()).collect();
438 let target_members: HashSet<&str> = community_members
439 .get(&best_cid)
440 .map(|s| s.iter().map(|x| x.as_str()).collect())
441 .unwrap_or_default();
442
443 let ki_sum: f64 = component.iter().map(|n| node_strength(adj, n)).sum();
444 let ki_in = component
445 .iter()
446 .map(|n| edges_to_community(adj, n, &target_members))
447 .sum::<f64>();
448 let sigma_t = community_strength(adj, &target_members);
449
450 let gain = ki_in / m - ki_sum * sigma_t / (2.0 * m * m);
451 if gain > 0.0 {
452 best_cid
453 } else {
454 let new_cid = next_cid;
455 next_cid += 1;
456 new_cid
457 }
458 } else {
459 let new_cid = next_cid;
460 next_cid += 1;
461 new_cid
462 };
463
464 for node in component {
466 community_of.insert(node.clone(), target_cid);
467 community_members
468 .entry(target_cid)
469 .or_default()
470 .push(node.clone());
471 }
472 any_refined = true;
473 }
474
475 if any_refined {
477 community_members.insert(cid, sorted_components.into_iter().next().unwrap());
478 }
479 }
480
481 any_refined
482}
483
/// Connected components of the subgraph of `adj` induced by `members`.
///
/// A breadth-first search starts from each not-yet-visited member, in `members` order.
/// Within a component, nodes appear in BFS visiting order. Edges to nodes outside `members`
/// are ignored, members without an entry in `adj` form singleton components, and duplicate
/// members are visited once.
fn connected_components_within(
    adj: &HashMap<String, Vec<(String, f64)>>,
    members: &[String],
) -> Vec<Vec<String>> {
    let inside: HashSet<&str> = members.iter().map(String::as_str).collect();
    let mut seen: HashSet<&str> = HashSet::new();
    let mut result: Vec<Vec<String>> = Vec::new();

    for start in members.iter().map(String::as_str) {
        // `insert` returns false when the node was already reached from an earlier start.
        if !seen.insert(start) {
            continue;
        }

        let mut frontier: VecDeque<&str> = VecDeque::new();
        frontier.push_back(start);
        let mut component: Vec<String> = Vec::new();

        while let Some(current) = frontier.pop_front() {
            component.push(current.to_string());
            let neighbors = match adj.get(current) {
                Some(list) => list.as_slice(),
                None => &[],
            };
            for (id, _) in neighbors {
                let nbr = id.as_str();
                if inside.contains(nbr) && seen.insert(nbr) {
                    frontier.push_back(nbr);
                }
            }
        }

        result.push(component);
    }

    result
}
520
/// Renumber community ids in place to the dense range `0..k`, preserving their relative
/// order: the smallest old id becomes `0`, the next distinct one `1`, and so on.
fn compact_ids(community_of: &mut HashMap<String, usize>) {
    let mut distinct: Vec<usize> = community_of.values().copied().collect();
    distinct.sort_unstable();
    distinct.dedup();

    for cid in community_of.values_mut() {
        // `distinct` is sorted and deduplicated, so the search index is the dense rank.
        let rank = distinct
            .binary_search(&*cid)
            .expect("every community id was collected into `distinct`");
        *cid = rank;
    }
}
539
540fn merge_small_communities(
543 communities: &mut HashMap<usize, Vec<String>>,
544 adj: &HashMap<String, Vec<(String, f64)>>,
545) {
546 let mut node_to_cid: HashMap<String, usize> = communities
548 .iter()
549 .flat_map(|(&cid, nodes)| nodes.iter().map(move |n| (n.clone(), cid)))
550 .collect();
551
552 loop {
553 let merge = communities
555 .iter()
556 .filter(|(_, nodes)| nodes.len() < MIN_COMMUNITY_SIZE)
557 .find_map(|(&small_cid, nodes)| {
558 let mut neighbor_edges: HashMap<usize, f64> = HashMap::new();
560 for node in nodes {
561 if let Some(neighbors) = adj.get(node.as_str()) {
562 for (neighbor, weight) in neighbors {
563 if let Some(&ncid) = node_to_cid.get(neighbor.as_str())
564 && ncid != small_cid
565 {
566 *neighbor_edges.entry(ncid).or_default() += weight;
567 }
568 }
569 }
570 }
571 neighbor_edges
573 .iter()
574 .max_by(|a, b| a.1.total_cmp(b.1))
575 .map(|(&best_cid, _)| (small_cid, best_cid))
576 });
577
578 match merge {
579 Some((small_cid, best_cid)) => {
580 let nodes = communities.remove(&small_cid).unwrap_or_default();
581 for node in &nodes {
583 node_to_cid.insert(node.clone(), best_cid);
584 }
585 communities.entry(best_cid).or_default().extend(nodes);
586 }
587 None => break, }
589 }
590}
591
592fn split_community(graph: &KnowledgeGraph, nodes: &[String]) -> Vec<Vec<String>> {
594 if nodes.len() < MIN_SPLIT_SIZE {
595 return vec![nodes.to_vec()];
596 }
597
598 let node_set: HashSet<&str> = nodes.iter().map(|s| s.as_str()).collect();
599
600 let mut sub_adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
602 for node in nodes {
603 sub_adj.entry(node.clone()).or_default();
604 }
605 for (src, tgt, edge) in graph.edges_with_endpoints() {
606 if node_set.contains(src) && node_set.contains(tgt) {
607 sub_adj
608 .entry(src.to_string())
609 .or_default()
610 .push((tgt.to_string(), edge.weight));
611 sub_adj
612 .entry(tgt.to_string())
613 .or_default()
614 .push((src.to_string(), edge.weight));
615 }
616 }
617
618 let m = total_weight(&sub_adj);
619 if m == 0.0 {
620 return nodes.iter().map(|n| vec![n.clone()]).collect();
621 }
622
623 let mut community_of: HashMap<String, usize> = nodes
625 .iter()
626 .enumerate()
627 .map(|(i, id)| (id.clone(), i))
628 .collect();
629
630 let node_list: Vec<String> = nodes.to_vec();
631 for _ in 0..5 {
632 let changed = louvain_phase(&sub_adj, &node_list, &mut community_of, m);
633 let refined = refinement_phase(&sub_adj, &mut community_of, m);
634 if !changed && !refined {
635 break;
636 }
637 }
638
639 let mut groups: HashMap<usize, Vec<String>> = HashMap::new();
641 for (node, cid) in &community_of {
642 groups.entry(*cid).or_default().push(node.clone());
643 }
644
645 let result: Vec<Vec<String>> = groups.into_values().filter(|s| !s.is_empty()).collect();
646
647 if result.len() <= 1 {
648 debug!("could not split community of {} nodes further", nodes.len());
649 return vec![nodes.to_vec()];
650 }
651
652 result
653}
654
655pub fn cluster_incremental(
669 graph: &KnowledgeGraph,
670 prev_communities: &HashMap<usize, Vec<String>>,
671 changed_files: &[String],
672) -> HashMap<usize, Vec<String>> {
673 if prev_communities.is_empty() || changed_files.is_empty() {
674 return cluster(graph);
675 }
676
677 let changed_set: HashSet<&str> = changed_files.iter().map(|s| s.as_str()).collect();
678
679 let affected_nodes: HashSet<String> = graph
681 .nodes()
682 .iter()
683 .filter(|n| changed_set.contains(n.source_file.as_str()))
684 .map(|n| n.id.clone())
685 .collect();
686
687 if affected_nodes.is_empty() {
688 return prev_communities.clone();
689 }
690
691 let node_to_cid: HashMap<&str, usize> = prev_communities
693 .iter()
694 .flat_map(|(&cid, nodes)| nodes.iter().map(move |n| (n.as_str(), cid)))
695 .collect();
696
697 let mut affected_cids: HashSet<usize> = HashSet::new();
699 for node_id in &affected_nodes {
700 if let Some(&cid) = node_to_cid.get(node_id.as_str()) {
701 affected_cids.insert(cid);
702 }
703 for neighbor in graph.get_neighbors(node_id) {
705 if let Some(&cid) = node_to_cid.get(neighbor.id.as_str()) {
706 affected_cids.insert(cid);
707 }
708 }
709 }
710
711 if affected_cids.len() * 2 > prev_communities.len() {
713 debug!(
714 "incremental: {}% communities affected, falling back to full cluster",
715 affected_cids.len() * 100 / prev_communities.len().max(1)
716 );
717 return cluster(graph);
718 }
719
720 debug!(
721 "incremental: re-clustering {} of {} communities ({} affected nodes)",
722 affected_cids.len(),
723 prev_communities.len(),
724 affected_nodes.len()
725 );
726
727 let affected_community_nodes: Vec<String> = affected_cids
729 .iter()
730 .flat_map(|cid| prev_communities.get(cid).cloned().unwrap_or_default())
731 .collect();
732
733 let all_prev_nodes: HashSet<&str> = prev_communities
735 .values()
736 .flat_map(|v| v.iter().map(|s| s.as_str()))
737 .collect();
738 let new_nodes: Vec<String> = graph
739 .node_ids()
740 .into_iter()
741 .filter(|id| !all_prev_nodes.contains(id.as_str()))
742 .collect();
743
744 let mut recluster_nodes: Vec<String> = affected_community_nodes;
745 recluster_nodes.extend(new_nodes);
746
747 let sub_communities = split_community(graph, &recluster_nodes);
749
750 let mut result: HashMap<usize, Vec<String>> = HashMap::new();
752 let mut next_cid = 0usize;
753
754 for (&cid, nodes) in prev_communities {
756 if !affected_cids.contains(&cid) {
757 result.insert(next_cid, nodes.clone());
758 next_cid += 1;
759 }
760 }
761
762 for nodes in sub_communities {
764 if !nodes.is_empty() {
765 result.insert(next_cid, nodes);
766 next_cid += 1;
767 }
768 }
769
770 let mut final_vec: Vec<Vec<String>> = result.into_values().collect();
772 final_vec.sort_by_key(|b| std::cmp::Reverse(b.len()));
773 final_vec
774 .into_iter()
775 .enumerate()
776 .map(|(i, mut nodes)| {
777 nodes.sort();
778 (i, nodes)
779 })
780 .collect()
781}
782
#[cfg(test)]
mod tests {
    use super::*;
    use graphify_core::confidence::Confidence;
    use graphify_core::graph::KnowledgeGraph;
    use graphify_core::model::{GraphEdge, GraphNode, NodeType};
    use std::collections::HashMap as StdMap;

    /// Minimal node fixture; every node is attributed to `test.rs`.
    fn make_node(id: &str) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: id.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: StdMap::new(),
        }
    }

    /// Unit-weight `calls` edge fixture from `src` to `tgt`.
    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.into(),
            target: tgt.into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: StdMap::new(),
        }
    }

    /// Build a graph from node ids and `(source, target)` edge pairs.
    fn build_graph(nodes: &[&str], edges: &[(&str, &str)]) -> KnowledgeGraph {
        let mut g = KnowledgeGraph::new();
        for &id in nodes {
            g.add_node(make_node(id)).unwrap();
        }
        for &(s, t) in edges {
            g.add_edge(make_edge(s, t)).unwrap();
        }
        g
    }

    #[test]
    fn cluster_empty_graph() {
        let g = KnowledgeGraph::new();
        let result = cluster(&g);
        assert!(result.is_empty());
    }

    #[test]
    fn cluster_no_edges() {
        let g = build_graph(&["a", "b", "c"], &[]);
        let result = cluster(&g);
        assert_eq!(result.len(), 3);
        for nodes in result.values() {
            assert_eq!(nodes.len(), 1);
        }
    }

    #[test]
    fn cluster_single_clique() {
        let g = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c"), ("a", "c")]);
        let result = cluster(&g);
        let total_nodes: usize = result.values().map(|v| v.len()).sum();
        assert_eq!(total_nodes, 3);
        assert!(result.len() <= 3);
    }

    #[test]
    fn cluster_two_cliques() {
        let g = build_graph(
            &["a1", "a2", "a3", "b1", "b2", "b3"],
            &[
                ("a1", "a2"),
                ("a2", "a3"),
                ("a1", "a3"),
                ("b1", "b2"),
                ("b2", "b3"),
                ("b1", "b3"),
                ("a3", "b1"), // single edge linking the two cliques
            ],
        );
        let result = cluster(&g);
        let total_nodes: usize = result.values().map(|v| v.len()).sum();
        assert_eq!(total_nodes, 6);
        assert!(!result.is_empty());
    }

    #[test]
    fn cohesion_score_single_node() {
        let g = build_graph(&["a"], &[]);
        let score = cohesion_score(&g, &["a".to_string()]);
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cohesion_score_complete_graph() {
        let g = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c"), ("a", "c")]);
        let score = cohesion_score(&g, &["a".to_string(), "b".to_string(), "c".to_string()]);
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cohesion_score_no_edges() {
        let g = build_graph(&["a", "b", "c"], &[]);
        let score = cohesion_score(&g, &["a".to_string(), "b".to_string(), "c".to_string()]);
        assert!((score - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cohesion_score_partial() {
        let g = build_graph(&["a", "b", "c"], &[("a", "b")]);
        let score = cohesion_score(&g, &["a".to_string(), "b".to_string(), "c".to_string()]);
        assert!((score - 0.33).abs() < 0.01);
    }

    #[test]
    fn score_all_works() {
        let g = build_graph(&["a", "b"], &[("a", "b")]);
        let mut communities = HashMap::new();
        communities.insert(0, vec!["a".to_string(), "b".to_string()]);
        let scores = score_all(&g, &communities);
        assert_eq!(scores.len(), 1);
        assert!((scores[&0] - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cluster_graph_mutates_communities() {
        let mut g = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c"), ("a", "c")]);
        let result = cluster_graph(&mut g);
        assert!(!result.is_empty());
        assert!(!g.communities.is_empty());
    }

    #[test]
    fn leiden_splits_disconnected_community() {
        // Two cliques with no edge between them must never share a community.
        let g = build_graph(
            &["a1", "a2", "a3", "b1", "b2", "b3"],
            &[
                ("a1", "a2"),
                ("a2", "a3"),
                ("a1", "a3"),
                ("b1", "b2"),
                ("b2", "b3"),
                ("b1", "b3"),
            ],
        );
        let result = cluster(&g);
        assert_eq!(
            result.len(),
            2,
            "disconnected cliques should form 2 communities"
        );
        for nodes in result.values() {
            assert_eq!(nodes.len(), 3);
        }
    }

    #[test]
    fn leiden_connected_components_within() {
        let mut adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
        for id in &["a", "b", "c", "d"] {
            adj.insert(id.to_string(), Vec::new());
        }
        // a-b and c-d, with no link between the pairs.
        adj.get_mut("a").unwrap().push(("b".into(), 1.0));
        adj.get_mut("b").unwrap().push(("a".into(), 1.0));
        adj.get_mut("c").unwrap().push(("d".into(), 1.0));
        adj.get_mut("d").unwrap().push(("c".into(), 1.0));

        let members: Vec<String> = vec!["a", "b", "c", "d"]
            .into_iter()
            .map(String::from)
            .collect();
        let components = connected_components_within(&adj, &members);
        assert_eq!(components.len(), 2);
    }

    #[test]
    fn leiden_single_component() {
        let mut adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
        for id in &["a", "b", "c"] {
            adj.insert(id.to_string(), Vec::new());
        }
        adj.get_mut("a").unwrap().push(("b".into(), 1.0));
        adj.get_mut("b").unwrap().push(("a".into(), 1.0));
        adj.get_mut("b").unwrap().push(("c".into(), 1.0));
        adj.get_mut("c").unwrap().push(("b".into(), 1.0));

        let members: Vec<String> = vec!["a", "b", "c"].into_iter().map(String::from).collect();
        let components = connected_components_within(&adj, &members);
        assert_eq!(components.len(), 1);
    }

    #[test]
    fn total_weight_counts_each_edge_once() {
        let g = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c")]);
        let adj = build_adjacency(&g);
        assert!((total_weight(&adj) - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn compact_ids_renumbers_densely_in_order() {
        let mut community_of: HashMap<String, usize> = HashMap::new();
        community_of.insert("a".into(), 7);
        community_of.insert("b".into(), 3);
        community_of.insert("c".into(), 7);
        compact_ids(&mut community_of);
        assert_eq!(community_of["a"], 1);
        assert_eq!(community_of["b"], 0);
        assert_eq!(community_of["c"], 1);
    }

    #[test]
    fn merge_small_communities_absorbs_into_neighbor() {
        // A path of 5 nodes (not small) plus one singleton attached to it.
        let big = ["a1", "a2", "a3", "a4", "a5"];
        let mut edges: Vec<(&str, &str)> = big.windows(2).map(|w| (w[0], w[1])).collect();
        edges.push(("a1", "s"));
        let mut nodes: Vec<&str> = big.to_vec();
        nodes.push("s");
        let g = build_graph(&nodes, &edges);
        let adj = build_adjacency(&g);

        let mut communities: HashMap<usize, Vec<String>> = HashMap::new();
        communities.insert(0, big.iter().map(|s| s.to_string()).collect());
        communities.insert(1, vec!["s".to_string()]);
        merge_small_communities(&mut communities, &adj);

        assert_eq!(communities.len(), 1);
        assert_eq!(communities[&0].len(), 6);
        assert!(communities[&0].contains(&"s".to_string()));
    }

    #[test]
    fn cluster_incremental_unrelated_change_keeps_previous() {
        let g = build_graph(&["a", "b", "c"], &[("a", "b")]);
        let mut prev: HashMap<usize, Vec<String>> = HashMap::new();
        prev.insert(0, vec!["a".to_string(), "b".to_string()]);
        prev.insert(1, vec!["c".to_string()]);
        let result = cluster_incremental(&g, &prev, &["other.rs".to_string()]);
        assert_eq!(result, prev);
    }
}