1use std::collections::{HashMap, HashSet, VecDeque};
9
10use tracing::debug;
11
12use graphify_core::graph::KnowledgeGraph;
13use graphify_core::model::CommunityInfo;
14
/// Largest share of all nodes a community may hold before it is re-clustered;
/// the effective cap is `max(MIN_SPLIT_SIZE, MAX_COMMUNITY_FRACTION * node_count)`.
const MAX_COMMUNITY_FRACTION: f64 = 0.25;

/// Communities with at most this many members are never re-split, whatever the
/// fraction above works out to.
const MIN_SPLIT_SIZE: usize = 10;

/// Communities with fewer members than this are folded into a neighbouring community.
const MIN_COMMUNITY_SIZE: usize = 5;

/// Resolution factor γ applied to the size penalty of the modularity gain; a smaller
/// value shrinks that penalty and so favours larger communities.
const RESOLUTION: f64 = 0.3;
29
30pub fn cluster(graph: &KnowledgeGraph) -> HashMap<usize, Vec<String>> {
40 let node_count = graph.node_count();
41 if node_count == 0 {
42 return HashMap::new();
43 }
44
45 if graph.edge_count() == 0 {
47 return graph
48 .node_ids()
49 .into_iter()
50 .enumerate()
51 .map(|(i, id)| (i, vec![id]))
52 .collect();
53 }
54
55 let partition = leiden_partition(graph);
56
57 let mut communities: HashMap<usize, Vec<String>> = HashMap::new();
59 for (node_id, cid) in &partition {
60 communities.entry(*cid).or_default().push(node_id.clone());
61 }
62
63 let adj = build_adjacency(graph);
65 merge_small_communities(&mut communities, &adj);
66
67 let max_size = std::cmp::max(
69 MIN_SPLIT_SIZE,
70 (node_count as f64 * MAX_COMMUNITY_FRACTION) as usize,
71 );
72 let mut final_communities: Vec<Vec<String>> = Vec::new();
73 for nodes in communities.values() {
74 if nodes.len() > max_size {
75 final_communities.extend(split_community(graph, nodes));
76 } else {
77 final_communities.push(nodes.clone());
78 }
79 }
80
81 final_communities.sort_by_key(|b| std::cmp::Reverse(b.len()));
83 final_communities
84 .into_iter()
85 .enumerate()
86 .map(|(i, mut nodes)| {
87 nodes.sort();
88 (i, nodes)
89 })
90 .collect()
91}
92
93pub fn cluster_graph(graph: &mut KnowledgeGraph) -> HashMap<usize, Vec<String>> {
95 let communities = cluster(graph);
96
97 let scores = score_all(graph, &communities);
99 let mut infos: Vec<CommunityInfo> = communities
100 .iter()
101 .map(|(&cid, nodes)| CommunityInfo {
102 id: cid,
103 nodes: nodes.clone(),
104 cohesion: scores.get(&cid).copied().unwrap_or(0.0),
105 label: None,
106 })
107 .collect();
108 infos.sort_by_key(|c| c.id);
109 graph.communities = infos;
110
111 communities
112}
113
114pub fn cohesion_score(graph: &KnowledgeGraph, community_nodes: &[String]) -> f64 {
118 let n = community_nodes.len();
119 if n <= 1 {
120 return 1.0;
121 }
122
123 let node_set: HashSet<&str> = community_nodes.iter().map(|s| s.as_str()).collect();
124 let mut actual_edges = 0usize;
125
126 for node_id in community_nodes {
128 let neighbors = graph.get_neighbors(node_id);
129 for neighbor in &neighbors {
130 if node_set.contains(neighbor.id.as_str()) {
131 actual_edges += 1;
132 }
133 }
134 }
135 actual_edges /= 2;
137
138 let possible = n * (n - 1) / 2;
139 if possible == 0 {
140 return 0.0;
141 }
142 ((actual_edges as f64 / possible as f64) * 100.0).round() / 100.0
143}
144
145pub fn score_all(
147 graph: &KnowledgeGraph,
148 communities: &HashMap<usize, Vec<String>>,
149) -> HashMap<usize, f64> {
150 communities
151 .iter()
152 .map(|(&cid, nodes)| (cid, cohesion_score(graph, nodes)))
153 .collect()
154}
155
156fn build_adjacency(graph: &KnowledgeGraph) -> HashMap<String, Vec<(String, f64)>> {
162 let mut adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
163 for id in graph.node_ids() {
164 adj.entry(id).or_default();
165 }
166 for (src, tgt, edge) in graph.edges_with_endpoints() {
167 adj.entry(src.to_string())
168 .or_default()
169 .push((tgt.to_string(), edge.weight));
170 adj.entry(tgt.to_string())
171 .or_default()
172 .push((src.to_string(), edge.weight));
173 }
174 adj
175}
176
/// Total edge weight `m` of an undirected graph.
///
/// Sums every adjacency entry and halves the result. This assumes a symmetric
/// adjacency list (each edge stored under both endpoints) such as
/// `build_adjacency` produces.
fn total_weight(adj: &HashMap<String, Vec<(String, f64)>>) -> f64 {
    let doubled = adj
        .values()
        .flatten()
        .fold(0.0_f64, |acc, (_, w)| acc + w);
    doubled / 2.0
}
187
/// Weighted degree of `node`: the sum of the weights in its adjacency list.
///
/// Returns 0.0 when the node has no entry.
fn node_strength(adj: &HashMap<String, Vec<(String, f64)>>, node: &str) -> f64 {
    let Some(neighbors) = adj.get(node) else {
        return 0.0;
    };
    neighbors.iter().map(|&(_, weight)| weight).sum()
}
194
/// Total weight of the edges from `node` to members of `community`.
///
/// Returns 0.0 when `node` has no adjacency entry.
fn edges_to_community(
    adj: &HashMap<String, Vec<(String, f64)>>,
    node: &str,
    community: &HashSet<&str>,
) -> f64 {
    match adj.get(node) {
        Some(neighbors) => neighbors
            .iter()
            .filter_map(|(nbr, w)| community.contains(nbr.as_str()).then_some(w))
            .sum::<f64>(),
        None => 0.0,
    }
}
211
212fn community_strength(adj: &HashMap<String, Vec<(String, f64)>>, members: &HashSet<&str>) -> f64 {
214 members.iter().map(|n| node_strength(adj, n)).sum()
215}
216
217fn leiden_partition(graph: &KnowledgeGraph) -> HashMap<String, usize> {
226 let adj = build_adjacency(graph);
227 let m = total_weight(&adj);
228 if m == 0.0 {
229 return graph
230 .node_ids()
231 .into_iter()
232 .enumerate()
233 .map(|(i, id)| (id, i))
234 .collect();
235 }
236
237 let node_ids = graph.node_ids();
238
239 let mut community_of: HashMap<String, usize> = node_ids
241 .iter()
242 .enumerate()
243 .map(|(i, id)| (id.clone(), i))
244 .collect();
245
246 let max_outer_iterations = 10;
247 for _outer in 0..max_outer_iterations {
248 let changed = louvain_phase(&adj, &node_ids, &mut community_of, m);
250
251 let refined = refinement_phase(&adj, &mut community_of, m);
253
254 if !changed && !refined {
255 break;
256 }
257 }
258
259 compact_ids(&mut community_of);
261 community_of
262}
263
264fn louvain_phase(
269 adj: &HashMap<String, Vec<(String, f64)>>,
270 node_ids: &[String],
271 community_of: &mut HashMap<String, usize>,
272 m: f64,
273) -> bool {
274 let mut community_members: HashMap<usize, HashSet<String>> = HashMap::new();
275 for (node, &cid) in community_of.iter() {
276 community_members
277 .entry(cid)
278 .or_default()
279 .insert(node.clone());
280 }
281
282 let max_iterations = 50;
283 let mut any_changed = false;
284
285 for _iteration in 0..max_iterations {
286 let mut improved = false;
287
288 for node in node_ids {
289 let current_community = community_of[node];
290 let ki = node_strength(adj, node);
291
292 let mut neighbor_communities: HashSet<usize> = HashSet::new();
294 if let Some(neighbors) = adj.get(node.as_str()) {
295 for (n, _) in neighbors {
296 neighbor_communities.insert(community_of[n]);
297 }
298 }
299 neighbor_communities.insert(current_community);
300
301 let mut best_community = current_community;
302 let mut best_gain = 0.0f64;
303
304 for &target_community in &neighbor_communities {
305 if target_community == current_community {
306 continue;
307 }
308
309 let members_ref: HashSet<&str> = community_members
310 .get(&target_community)
311 .map(|s| s.iter().map(|x| x.as_str()).collect())
312 .unwrap_or_default();
313
314 let current_members_ref: HashSet<&str> = community_members
315 .get(¤t_community)
316 .map(|s| {
317 s.iter()
318 .filter(|x| x.as_str() != node.as_str())
319 .map(|x| x.as_str())
320 .collect()
321 })
322 .unwrap_or_default();
323
324 let ki_in_target = edges_to_community(adj, node, &members_ref);
325 let ki_in_current = edges_to_community(adj, node, ¤t_members_ref);
326 let sigma_target = community_strength(adj, &members_ref);
327 let sigma_current = community_strength(adj, ¤t_members_ref);
328
329 let gain = (ki_in_target - ki_in_current) / m
330 - RESOLUTION * ki * (sigma_target - sigma_current) / (2.0 * m * m);
331
332 if gain > best_gain {
333 best_gain = gain;
334 best_community = target_community;
335 }
336 }
337
338 if best_community != current_community {
339 community_members
340 .get_mut(¤t_community)
341 .unwrap()
342 .remove(node);
343 community_members
344 .entry(best_community)
345 .or_default()
346 .insert(node.clone());
347 community_of.insert(node.clone(), best_community);
348 improved = true;
349 any_changed = true;
350 }
351 }
352
353 if !improved {
354 break;
355 }
356 }
357
358 any_changed
359}
360
361fn refinement_phase(
369 adj: &HashMap<String, Vec<(String, f64)>>,
370 community_of: &mut HashMap<String, usize>,
371 m: f64,
372) -> bool {
373 let mut community_members: HashMap<usize, Vec<String>> = HashMap::new();
375 for (node, &cid) in community_of.iter() {
376 community_members.entry(cid).or_default().push(node.clone());
377 }
378
379 let mut any_refined = false;
380 let mut next_cid = community_members.keys().copied().max().unwrap_or(0) + 1;
381
382 let community_ids: Vec<usize> = community_members.keys().copied().collect();
383 for cid in community_ids {
384 let members = match community_members.get(&cid) {
385 Some(m) if m.len() > 1 => m.clone(),
386 _ => continue,
387 };
388
389 let components = connected_components_within(adj, &members);
391 if components.len() <= 1 {
392 continue; }
394
395 debug!(
396 "Leiden refinement: community {} has {} disconnected components, splitting",
397 cid,
398 components.len()
399 );
400
401 let mut sorted_components = components;
405 sorted_components.sort_by_key(|c| std::cmp::Reverse(c.len()));
406
407 for component in sorted_components.iter().skip(1) {
409 let mut neighbor_cid_edges: HashMap<usize, f64> = HashMap::new();
411 for node in component {
412 if let Some(neighbors) = adj.get(node.as_str()) {
413 for (nbr, w) in neighbors {
414 let nbr_cid = community_of[nbr];
415 if nbr_cid != cid {
416 *neighbor_cid_edges.entry(nbr_cid).or_default() += w;
417 }
418 }
419 }
420 }
421
422 let target_cid = if let Some((&best_cid, _)) = neighbor_cid_edges
425 .iter()
426 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
427 {
428 let _component_set: HashSet<&str> = component.iter().map(|s| s.as_str()).collect();
430 let target_members: HashSet<&str> = community_members
431 .get(&best_cid)
432 .map(|s| s.iter().map(|x| x.as_str()).collect())
433 .unwrap_or_default();
434
435 let ki_sum: f64 = component.iter().map(|n| node_strength(adj, n)).sum();
436 let ki_in = component
437 .iter()
438 .map(|n| edges_to_community(adj, n, &target_members))
439 .sum::<f64>();
440 let sigma_t = community_strength(adj, &target_members);
441
442 let gain = ki_in / m - ki_sum * sigma_t / (2.0 * m * m);
443 if gain > 0.0 {
444 best_cid
445 } else {
446 let new_cid = next_cid;
447 next_cid += 1;
448 new_cid
449 }
450 } else {
451 let new_cid = next_cid;
452 next_cid += 1;
453 new_cid
454 };
455
456 for node in component {
458 community_of.insert(node.clone(), target_cid);
459 community_members
460 .entry(target_cid)
461 .or_default()
462 .push(node.clone());
463 }
464 any_refined = true;
465 }
466
467 if any_refined {
469 community_members.insert(cid, sorted_components.into_iter().next().unwrap());
470 }
471 }
472
473 any_refined
474}
475
/// Connected components of the subgraph induced by `members`.
///
/// Only edges whose both endpoints are in `members` count. Components are
/// discovered by breadth-first search, seeded in `members` order. Each
/// component lists its nodes in visit order.
fn connected_components_within(
    adj: &HashMap<String, Vec<(String, f64)>>,
    members: &[String],
) -> Vec<Vec<String>> {
    let in_scope: HashSet<&str> = members.iter().map(String::as_str).collect();
    let mut seen: HashSet<&str> = HashSet::with_capacity(members.len());
    let mut out: Vec<Vec<String>> = Vec::new();

    for start in members.iter().map(String::as_str) {
        // `insert` returns false when `start` was reached from an earlier seed.
        if !seen.insert(start) {
            continue;
        }

        let mut frontier: VecDeque<&str> = VecDeque::from([start]);
        let mut component: Vec<String> = Vec::new();
        while let Some(here) = frontier.pop_front() {
            component.push(here.to_owned());
            let Some(edges) = adj.get(here) else {
                continue;
            };
            for (next, _) in edges {
                let next = next.as_str();
                if in_scope.contains(next) && seen.insert(next) {
                    frontier.push_back(next);
                }
            }
        }

        out.push(component);
    }

    out
}
512
/// Renumbers community ids to the dense range `0..k`, preserving their relative order.
fn compact_ids(community_of: &mut HashMap<String, usize>) {
    let mut distinct: Vec<usize> = community_of.values().copied().collect();
    distinct.sort_unstable();
    distinct.dedup();

    let dense: HashMap<usize, usize> = distinct
        .into_iter()
        .enumerate()
        .map(|(new_id, old_id)| (old_id, new_id))
        .collect();

    for cid in community_of.values_mut() {
        *cid = dense[&*cid];
    }
}
531
532fn merge_small_communities(
535 communities: &mut HashMap<usize, Vec<String>>,
536 adj: &HashMap<String, Vec<(String, f64)>>,
537) {
538 loop {
539 let node_to_cid: HashMap<String, usize> = communities
541 .iter()
542 .flat_map(|(&cid, nodes)| nodes.iter().map(move |n| (n.clone(), cid)))
543 .collect();
544
545 let merge = communities
547 .iter()
548 .filter(|(_, nodes)| nodes.len() < MIN_COMMUNITY_SIZE)
549 .find_map(|(&small_cid, nodes)| {
550 let mut neighbor_edges: HashMap<usize, f64> = HashMap::new();
552 for node in nodes {
553 if let Some(neighbors) = adj.get(node.as_str()) {
554 for (neighbor, weight) in neighbors {
555 if let Some(&ncid) = node_to_cid.get(neighbor.as_str())
556 && ncid != small_cid
557 {
558 *neighbor_edges.entry(ncid).or_default() += weight;
559 }
560 }
561 }
562 }
563 neighbor_edges
565 .iter()
566 .max_by(|a, b| a.1.total_cmp(b.1))
567 .map(|(&best_cid, _)| (small_cid, best_cid))
568 });
569
570 match merge {
571 Some((small_cid, best_cid)) => {
572 let nodes = communities.remove(&small_cid).unwrap_or_default();
573 communities.entry(best_cid).or_default().extend(nodes);
574 }
575 None => break, }
577 }
578}
579
580fn split_community(graph: &KnowledgeGraph, nodes: &[String]) -> Vec<Vec<String>> {
582 if nodes.len() < MIN_SPLIT_SIZE {
583 return vec![nodes.to_vec()];
584 }
585
586 let node_set: HashSet<&str> = nodes.iter().map(|s| s.as_str()).collect();
587
588 let mut sub_adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
590 for node in nodes {
591 sub_adj.entry(node.clone()).or_default();
592 }
593 for (src, tgt, edge) in graph.edges_with_endpoints() {
594 if node_set.contains(src) && node_set.contains(tgt) {
595 sub_adj
596 .entry(src.to_string())
597 .or_default()
598 .push((tgt.to_string(), edge.weight));
599 sub_adj
600 .entry(tgt.to_string())
601 .or_default()
602 .push((src.to_string(), edge.weight));
603 }
604 }
605
606 let m = total_weight(&sub_adj);
607 if m == 0.0 {
608 return nodes.iter().map(|n| vec![n.clone()]).collect();
609 }
610
611 let mut community_of: HashMap<String, usize> = nodes
613 .iter()
614 .enumerate()
615 .map(|(i, id)| (id.clone(), i))
616 .collect();
617
618 let node_list: Vec<String> = nodes.to_vec();
619 for _ in 0..5 {
620 let changed = louvain_phase(&sub_adj, &node_list, &mut community_of, m);
621 let refined = refinement_phase(&sub_adj, &mut community_of, m);
622 if !changed && !refined {
623 break;
624 }
625 }
626
627 let mut groups: HashMap<usize, Vec<String>> = HashMap::new();
629 for (node, cid) in &community_of {
630 groups.entry(*cid).or_default().push(node.clone());
631 }
632
633 let result: Vec<Vec<String>> = groups.into_values().filter(|s| !s.is_empty()).collect();
634
635 if result.len() <= 1 {
636 debug!("could not split community of {} nodes further", nodes.len());
637 return vec![nodes.to_vec()];
638 }
639
640 result
641}
642
#[cfg(test)]
mod tests {
    use super::*;
    use graphify_core::confidence::Confidence;
    use graphify_core::graph::KnowledgeGraph;
    use graphify_core::model::{GraphEdge, GraphNode, NodeType};

    /// A `Class` node whose id and label are both `id`.
    fn make_node(id: &str) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: id.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: Default::default(),
        }
    }

    /// A unit-weight "calls" edge from `src` to `tgt`.
    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.into(),
            target: tgt.into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: Default::default(),
        }
    }

    fn build_graph(nodes: &[&str], edges: &[(&str, &str)]) -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();
        for id in nodes.iter().copied() {
            graph.add_node(make_node(id)).unwrap();
        }
        for (src, tgt) in edges.iter().copied() {
            graph.add_edge(make_edge(src, tgt)).unwrap();
        }
        graph
    }

    /// Owned copies of `names`, for APIs that take `&[String]`.
    fn ids(names: &[&str]) -> Vec<String> {
        names.iter().map(|s| s.to_string()).collect()
    }

    /// Symmetric unit-weight adjacency list over `nodes`.
    fn adjacency(nodes: &[&str], edges: &[(&str, &str)]) -> HashMap<String, Vec<(String, f64)>> {
        let mut adj: HashMap<String, Vec<(String, f64)>> =
            nodes.iter().map(|n| (n.to_string(), Vec::new())).collect();
        for &(a, b) in edges {
            adj.get_mut(a).unwrap().push((b.to_string(), 1.0));
            adj.get_mut(b).unwrap().push((a.to_string(), 1.0));
        }
        adj
    }

    fn node_total(result: &HashMap<usize, Vec<String>>) -> usize {
        result.values().map(Vec::len).sum()
    }

    #[test]
    fn cluster_empty_graph() {
        assert!(cluster(&KnowledgeGraph::new()).is_empty());
    }

    #[test]
    fn cluster_no_edges() {
        let result = cluster(&build_graph(&["a", "b", "c"], &[]));
        assert_eq!(result.len(), 3);
        assert!(result.values().all(|nodes| nodes.len() == 1));
    }

    #[test]
    fn cluster_single_clique() {
        let graph = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c"), ("a", "c")]);
        let result = cluster(&graph);
        assert_eq!(node_total(&result), 3);
        assert!(result.len() <= 3);
    }

    #[test]
    fn cluster_two_cliques() {
        let graph = build_graph(
            &["a1", "a2", "a3", "b1", "b2", "b3"],
            &[
                ("a1", "a2"),
                ("a2", "a3"),
                ("a1", "a3"),
                ("b1", "b2"),
                ("b2", "b3"),
                ("b1", "b3"),
                // single bridge between the cliques
                ("a3", "b1"),
            ],
        );
        let result = cluster(&graph);
        assert_eq!(node_total(&result), 6);
        assert!(!result.is_empty());
    }

    #[test]
    fn cohesion_score_single_node() {
        let graph = build_graph(&["a"], &[]);
        let score = cohesion_score(&graph, &ids(&["a"]));
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cohesion_score_complete_graph() {
        let graph = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c"), ("a", "c")]);
        let score = cohesion_score(&graph, &ids(&["a", "b", "c"]));
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cohesion_score_no_edges() {
        let graph = build_graph(&["a", "b", "c"], &[]);
        let score = cohesion_score(&graph, &ids(&["a", "b", "c"]));
        assert!(score.abs() < f64::EPSILON);
    }

    #[test]
    fn cohesion_score_partial() {
        let graph = build_graph(&["a", "b", "c"], &[("a", "b")]);
        let score = cohesion_score(&graph, &ids(&["a", "b", "c"]));
        assert!((score - 0.33).abs() < 0.01);
    }

    #[test]
    fn score_all_works() {
        let graph = build_graph(&["a", "b"], &[("a", "b")]);
        let communities: HashMap<usize, Vec<String>> = HashMap::from([(0, ids(&["a", "b"]))]);
        let scores = score_all(&graph, &communities);
        assert_eq!(scores.len(), 1);
        assert!((scores[&0] - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn cluster_graph_mutates_communities() {
        let mut graph = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c"), ("a", "c")]);
        let result = cluster_graph(&mut graph);
        assert!(!result.is_empty());
        assert!(!graph.communities.is_empty());
    }

    #[test]
    fn leiden_splits_disconnected_community() {
        // Two triangles with no edge between them.
        let graph = build_graph(
            &["a1", "a2", "a3", "b1", "b2", "b3"],
            &[
                ("a1", "a2"),
                ("a2", "a3"),
                ("a1", "a3"),
                ("b1", "b2"),
                ("b2", "b3"),
                ("b1", "b3"),
            ],
        );
        let result = cluster(&graph);
        assert_eq!(
            result.len(),
            2,
            "disconnected cliques should form 2 communities"
        );
        for nodes in result.values() {
            assert_eq!(nodes.len(), 3);
        }
    }

    #[test]
    fn leiden_connected_components_within() {
        let adj = adjacency(&["a", "b", "c", "d"], &[("a", "b"), ("c", "d")]);
        let components = connected_components_within(&adj, &ids(&["a", "b", "c", "d"]));
        assert_eq!(components.len(), 2);
    }

    #[test]
    fn leiden_single_component() {
        let adj = adjacency(&["a", "b", "c"], &[("a", "b"), ("b", "c")]);
        let components = connected_components_within(&adj, &ids(&["a", "b", "c"]));
        assert_eq!(components.len(), 1);
    }
}