1use std::collections::{HashMap, HashSet, VecDeque};
9
10use tracing::debug;
11
12use graphify_core::graph::KnowledgeGraph;
13use graphify_core::model::CommunityInfo;
14
/// Fraction of the graph's node count used by `cluster` as the size limit for
/// a single community; communities larger than the resulting limit are handed
/// to `split_community`.
const MAX_COMMUNITY_FRACTION: f64 = 0.25;

/// Lower bound for the size limit computed in `cluster`, and the size below
/// which `split_community` returns its input unsplit.
const MIN_SPLIT_SIZE: usize = 10;
21
22pub fn cluster(graph: &KnowledgeGraph) -> HashMap<usize, Vec<String>> {
32 let node_count = graph.node_count();
33 if node_count == 0 {
34 return HashMap::new();
35 }
36
37 if graph.edge_count() == 0 {
39 return graph
40 .node_ids()
41 .into_iter()
42 .enumerate()
43 .map(|(i, id)| (i, vec![id]))
44 .collect();
45 }
46
47 let partition = leiden_partition(graph);
48
49 let mut communities: HashMap<usize, Vec<String>> = HashMap::new();
51 for (node_id, cid) in &partition {
52 communities.entry(*cid).or_default().push(node_id.clone());
53 }
54
55 let max_size = std::cmp::max(
57 MIN_SPLIT_SIZE,
58 (node_count as f64 * MAX_COMMUNITY_FRACTION) as usize,
59 );
60 let mut final_communities: Vec<Vec<String>> = Vec::new();
61 for nodes in communities.values() {
62 if nodes.len() > max_size {
63 final_communities.extend(split_community(graph, nodes));
64 } else {
65 final_communities.push(nodes.clone());
66 }
67 }
68
69 final_communities.sort_by_key(|b| std::cmp::Reverse(b.len()));
71 final_communities
72 .into_iter()
73 .enumerate()
74 .map(|(i, mut nodes)| {
75 nodes.sort();
76 (i, nodes)
77 })
78 .collect()
79}
80
81pub fn cluster_graph(graph: &mut KnowledgeGraph) -> HashMap<usize, Vec<String>> {
83 let communities = cluster(graph);
84
85 let scores = score_all(graph, &communities);
87 let mut infos: Vec<CommunityInfo> = communities
88 .iter()
89 .map(|(&cid, nodes)| CommunityInfo {
90 id: cid,
91 nodes: nodes.clone(),
92 cohesion: scores.get(&cid).copied().unwrap_or(0.0),
93 label: None,
94 })
95 .collect();
96 infos.sort_by_key(|c| c.id);
97 graph.communities = infos;
98
99 communities
100}
101
102pub fn cohesion_score(graph: &KnowledgeGraph, community_nodes: &[String]) -> f64 {
106 let n = community_nodes.len();
107 if n <= 1 {
108 return 1.0;
109 }
110
111 let node_set: HashSet<&str> = community_nodes.iter().map(|s| s.as_str()).collect();
112 let mut actual_edges = 0usize;
113
114 for node_id in community_nodes {
116 let neighbors = graph.get_neighbors(node_id);
117 for neighbor in &neighbors {
118 if node_set.contains(neighbor.id.as_str()) {
119 actual_edges += 1;
120 }
121 }
122 }
123 actual_edges /= 2;
125
126 let possible = n * (n - 1) / 2;
127 if possible == 0 {
128 return 0.0;
129 }
130 ((actual_edges as f64 / possible as f64) * 100.0).round() / 100.0
131}
132
133pub fn score_all(
135 graph: &KnowledgeGraph,
136 communities: &HashMap<usize, Vec<String>>,
137) -> HashMap<usize, f64> {
138 communities
139 .iter()
140 .map(|(&cid, nodes)| (cid, cohesion_score(graph, nodes)))
141 .collect()
142}
143
144fn build_adjacency(graph: &KnowledgeGraph) -> HashMap<String, Vec<(String, f64)>> {
150 let mut adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
151 for id in graph.node_ids() {
152 adj.entry(id).or_default();
153 }
154 for (src, tgt, edge) in graph.edges_with_endpoints() {
155 adj.entry(src.to_string())
156 .or_default()
157 .push((tgt.to_string(), edge.weight));
158 adj.entry(tgt.to_string())
159 .or_default()
160 .push((src.to_string(), edge.weight));
161 }
162 adj
163}
164
/// Total edge weight of a symmetric adjacency list.
///
/// Sums every stored `(neighbor, weight)` entry and halves the result,
/// because each edge appears under both of its endpoints.
fn total_weight(adj: &HashMap<String, Vec<(String, f64)>>) -> f64 {
    let doubled = adj
        .values()
        .flatten()
        .fold(0.0_f64, |acc, (_, weight)| acc + weight);
    doubled / 2.0
}
175
/// Weighted degree of `node`: the sum of the weights in its adjacency entry.
///
/// A node without an entry in `adj` has strength `0.0`.
fn node_strength(adj: &HashMap<String, Vec<(String, f64)>>, node: &str) -> f64 {
    match adj.get(node) {
        Some(neighbors) => neighbors.iter().map(|&(_, weight)| weight).sum(),
        None => 0.0,
    }
}
182
/// Total weight of the adjacency entries of `node` whose other endpoint is a
/// member of `community`.
///
/// Returns `0.0` when `node` has no entry in `adj`.
fn edges_to_community(
    adj: &HashMap<String, Vec<(String, f64)>>,
    node: &str,
    community: &HashSet<&str>,
) -> f64 {
    match adj.get(node) {
        None => 0.0,
        Some(neighbors) => neighbors
            .iter()
            .filter_map(|(neighbor, weight)| {
                if community.contains(neighbor.as_str()) {
                    Some(weight)
                } else {
                    None
                }
            })
            .sum(),
    }
}
199
200fn community_strength(adj: &HashMap<String, Vec<(String, f64)>>, members: &HashSet<&str>) -> f64 {
202 members.iter().map(|n| node_strength(adj, n)).sum()
203}
204
205fn leiden_partition(graph: &KnowledgeGraph) -> HashMap<String, usize> {
214 let adj = build_adjacency(graph);
215 let m = total_weight(&adj);
216 if m == 0.0 {
217 return graph
218 .node_ids()
219 .into_iter()
220 .enumerate()
221 .map(|(i, id)| (id, i))
222 .collect();
223 }
224
225 let node_ids = graph.node_ids();
226
227 let mut community_of: HashMap<String, usize> = node_ids
229 .iter()
230 .enumerate()
231 .map(|(i, id)| (id.clone(), i))
232 .collect();
233
234 let max_outer_iterations = 10;
235 for _outer in 0..max_outer_iterations {
236 let changed = louvain_phase(&adj, &node_ids, &mut community_of, m);
238
239 let refined = refinement_phase(&adj, &mut community_of, m);
241
242 if !changed && !refined {
243 break;
244 }
245 }
246
247 compact_ids(&mut community_of);
249 community_of
250}
251
/// Local-moving phase: greedily moves nodes into the neighbouring community
/// that yields the largest positive modularity gain.
///
/// * `adj` — symmetric weighted adjacency list; every neighbour it mentions
///   must be a key of `community_of` (indexing panics otherwise).
/// * `node_ids` — nodes to visit, in visiting order.
/// * `community_of` — current assignment, updated in place.
/// * `m` — total edge weight (callers in this file pass `total_weight(adj)`).
///
/// Nodes are swept repeatedly until a full sweep moves nothing, for at most
/// 50 sweeps. Returns `true` if any node changed community.
///
/// For a node `i` with strength `k_i`, moving from its community `C` to a
/// candidate `D` is scored as
/// `(k_i,D - k_i,C) / m - k_i * (Σ_D - Σ_C\{i}) / (2 m²)`,
/// where `k_i,X` is the weight of `i`'s edges into `X` (self-loops excluded)
/// and `Σ_X` the summed strength of `X`'s members.
///
/// `Σ` is kept as a running total per community and `k_i,X` is gathered in a
/// single pass over `i`'s neighbours, so evaluating a node costs O(degree)
/// instead of re-scanning whole communities for every candidate. Candidates
/// are tried in ascending id order, so equal gains resolve deterministically
/// (lowest id wins).
fn louvain_phase(
    adj: &HashMap<String, Vec<(String, f64)>>,
    node_ids: &[String],
    community_of: &mut HashMap<String, usize>,
    m: f64,
) -> bool {
    const MAX_ITERATIONS: usize = 50;

    // Σ_tot per community. Nodes are accumulated in sorted order so that
    // floating-point rounding does not depend on hash-map iteration order.
    let mut sigma_tot: HashMap<usize, f64> = HashMap::new();
    {
        let mut assigned: Vec<(&str, usize)> = community_of
            .iter()
            .map(|(node, &cid)| (node.as_str(), cid))
            .collect();
        assigned.sort_unstable();
        for (node, cid) in assigned {
            let strength = adj
                .get(node)
                .map(|neighbors| neighbors.iter().map(|(_, w)| *w).sum::<f64>())
                .unwrap_or(0.0);
            *sigma_tot.entry(cid).or_insert(0.0) += strength;
        }
    }

    let mut any_changed = false;
    // Scratch buffers reused across nodes.
    let mut link_weight: HashMap<usize, f64> = HashMap::new();
    let mut candidates: Vec<usize> = Vec::new();

    for _ in 0..MAX_ITERATIONS {
        let mut improved = false;

        for node in node_ids {
            let current = community_of[node];

            // One pass over the neighbours: strength of `node` and the weight
            // it sends into each adjacent community.
            link_weight.clear();
            let mut ki = 0.0_f64;
            if let Some(neighbors) = adj.get(node.as_str()) {
                for (nbr, w) in neighbors {
                    ki += *w;
                    // A self-loop adds to the strength but links to no community.
                    if nbr != node {
                        *link_weight.entry(community_of[nbr]).or_insert(0.0) += *w;
                    }
                }
            }

            let ki_in_current = link_weight.get(&current).copied().unwrap_or(0.0);
            // Strength of the current community with `node` taken out.
            let sigma_current = sigma_tot.get(&current).copied().unwrap_or(0.0) - ki;

            candidates.clear();
            candidates.extend(link_weight.keys().copied().filter(|&cid| cid != current));
            candidates.sort_unstable();

            let mut best_community = current;
            let mut best_gain = 0.0_f64;
            for &target in &candidates {
                let ki_in_target = link_weight[&target];
                let sigma_target = sigma_tot.get(&target).copied().unwrap_or(0.0);

                let gain = (ki_in_target - ki_in_current) / m
                    - ki * (sigma_target - sigma_current) / (2.0 * m * m);

                if gain > best_gain {
                    best_gain = gain;
                    best_community = target;
                }
            }

            if best_community != current {
                *sigma_tot.entry(current).or_insert(0.0) -= ki;
                *sigma_tot.entry(best_community).or_insert(0.0) += ki;
                // `node` is known to be a key: it was indexed above.
                if let Some(slot) = community_of.get_mut(node) {
                    *slot = best_community;
                }
                improved = true;
                any_changed = true;
            }
        }

        if !improved {
            break;
        }
    }

    any_changed
}
348
349fn refinement_phase(
357 adj: &HashMap<String, Vec<(String, f64)>>,
358 community_of: &mut HashMap<String, usize>,
359 m: f64,
360) -> bool {
361 let mut community_members: HashMap<usize, Vec<String>> = HashMap::new();
363 for (node, &cid) in community_of.iter() {
364 community_members.entry(cid).or_default().push(node.clone());
365 }
366
367 let mut any_refined = false;
368 let mut next_cid = community_members.keys().copied().max().unwrap_or(0) + 1;
369
370 let community_ids: Vec<usize> = community_members.keys().copied().collect();
371 for cid in community_ids {
372 let members = match community_members.get(&cid) {
373 Some(m) if m.len() > 1 => m.clone(),
374 _ => continue,
375 };
376
377 let components = connected_components_within(adj, &members);
379 if components.len() <= 1 {
380 continue; }
382
383 debug!(
384 "Leiden refinement: community {} has {} disconnected components, splitting",
385 cid,
386 components.len()
387 );
388
389 let mut sorted_components = components;
393 sorted_components.sort_by_key(|c| std::cmp::Reverse(c.len()));
394
395 for component in sorted_components.iter().skip(1) {
397 let mut neighbor_cid_edges: HashMap<usize, f64> = HashMap::new();
399 for node in component {
400 if let Some(neighbors) = adj.get(node.as_str()) {
401 for (nbr, w) in neighbors {
402 let nbr_cid = community_of[nbr];
403 if nbr_cid != cid {
404 *neighbor_cid_edges.entry(nbr_cid).or_default() += w;
405 }
406 }
407 }
408 }
409
410 let target_cid = if let Some((&best_cid, _)) = neighbor_cid_edges
413 .iter()
414 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
415 {
416 let _component_set: HashSet<&str> = component.iter().map(|s| s.as_str()).collect();
418 let target_members: HashSet<&str> = community_members
419 .get(&best_cid)
420 .map(|s| s.iter().map(|x| x.as_str()).collect())
421 .unwrap_or_default();
422
423 let ki_sum: f64 = component.iter().map(|n| node_strength(adj, n)).sum();
424 let ki_in = component
425 .iter()
426 .map(|n| edges_to_community(adj, n, &target_members))
427 .sum::<f64>();
428 let sigma_t = community_strength(adj, &target_members);
429
430 let gain = ki_in / m - ki_sum * sigma_t / (2.0 * m * m);
431 if gain > 0.0 {
432 best_cid
433 } else {
434 let new_cid = next_cid;
435 next_cid += 1;
436 new_cid
437 }
438 } else {
439 let new_cid = next_cid;
440 next_cid += 1;
441 new_cid
442 };
443
444 for node in component {
446 community_of.insert(node.clone(), target_cid);
447 community_members
448 .entry(target_cid)
449 .or_default()
450 .push(node.clone());
451 }
452 any_refined = true;
453 }
454
455 if any_refined {
457 community_members.insert(cid, sorted_components.into_iter().next().unwrap());
458 }
459 }
460
461 any_refined
462}
463
/// Splits `members` into connected components of the subgraph they induce.
///
/// Only edges whose both endpoints are in `members` are followed. Components
/// are returned in the order their first node appears in `members`; within a
/// component nodes are listed in breadth-first order from that node.
fn connected_components_within(
    adj: &HashMap<String, Vec<(String, f64)>>,
    members: &[String],
) -> Vec<Vec<String>> {
    let allowed: HashSet<&str> = members.iter().map(String::as_str).collect();
    let mut seen: HashSet<&str> = HashSet::new();
    let mut components: Vec<Vec<String>> = Vec::new();

    for start in members.iter().map(String::as_str) {
        // `insert` returns false when the node already belongs to a component.
        if !seen.insert(start) {
            continue;
        }

        let mut component: Vec<String> = Vec::new();
        let mut frontier: VecDeque<&str> = VecDeque::new();
        frontier.push_back(start);

        while let Some(current) = frontier.pop_front() {
            component.push(current.to_owned());

            for (nbr, _) in adj.get(current).into_iter().flatten() {
                let nbr = nbr.as_str();
                if allowed.contains(nbr) && seen.insert(nbr) {
                    frontier.push_back(nbr);
                }
            }
        }

        components.push(component);
    }

    components
}
500
/// Renumbers community ids in place to the dense range `0..k`.
///
/// The relative order of ids is preserved: the smallest id in use becomes
/// `0`, the next one `1`, and so on.
fn compact_ids(community_of: &mut HashMap<String, usize>) {
    // Distinct ids in ascending order.
    let mut distinct: Vec<usize> = community_of.values().copied().collect();
    distinct.sort_unstable();
    distinct.dedup();

    // Old id -> its rank among the distinct ids.
    let dense: HashMap<usize, usize> = distinct.into_iter().zip(0..).collect();

    for cid in community_of.values_mut() {
        let old = *cid;
        *cid = dense[&old];
    }
}
519
520fn split_community(graph: &KnowledgeGraph, nodes: &[String]) -> Vec<Vec<String>> {
522 if nodes.len() < MIN_SPLIT_SIZE {
523 return vec![nodes.to_vec()];
524 }
525
526 let node_set: HashSet<&str> = nodes.iter().map(|s| s.as_str()).collect();
527
528 let mut sub_adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
530 for node in nodes {
531 sub_adj.entry(node.clone()).or_default();
532 }
533 for (src, tgt, edge) in graph.edges_with_endpoints() {
534 if node_set.contains(src) && node_set.contains(tgt) {
535 sub_adj
536 .entry(src.to_string())
537 .or_default()
538 .push((tgt.to_string(), edge.weight));
539 sub_adj
540 .entry(tgt.to_string())
541 .or_default()
542 .push((src.to_string(), edge.weight));
543 }
544 }
545
546 let m = total_weight(&sub_adj);
547 if m == 0.0 {
548 return nodes.iter().map(|n| vec![n.clone()]).collect();
549 }
550
551 let mut community_of: HashMap<String, usize> = nodes
553 .iter()
554 .enumerate()
555 .map(|(i, id)| (id.clone(), i))
556 .collect();
557
558 let node_list: Vec<String> = nodes.to_vec();
559 for _ in 0..5 {
560 let changed = louvain_phase(&sub_adj, &node_list, &mut community_of, m);
561 let refined = refinement_phase(&sub_adj, &mut community_of, m);
562 if !changed && !refined {
563 break;
564 }
565 }
566
567 let mut groups: HashMap<usize, Vec<String>> = HashMap::new();
569 for (node, cid) in &community_of {
570 groups.entry(*cid).or_default().push(node.clone());
571 }
572
573 let result: Vec<Vec<String>> = groups.into_values().filter(|s| !s.is_empty()).collect();
574
575 if result.len() <= 1 {
576 debug!("could not split community of {} nodes further", nodes.len());
577 return vec![nodes.to_vec()];
578 }
579
580 result
581}
582
// Unit tests for the clustering entry points, the cohesion scoring and the
// connectivity helper used by the refinement phase.
#[cfg(test)]
mod tests {
    use super::*;
    use graphify_core::confidence::Confidence;
    use graphify_core::graph::KnowledgeGraph;
    use graphify_core::model::{GraphEdge, GraphNode, NodeType};
    use std::collections::HashMap as StdMap;

    /// Builds a minimal node whose label equals its id.
    fn make_node(id: &str) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: id.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: StdMap::new(),
        }
    }

    /// Builds a unit-weight "calls" edge from `src` to `tgt`.
    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.into(),
            target: tgt.into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: StdMap::new(),
        }
    }

    /// Assembles a graph from node ids and (source, target) pairs; panics if
    /// the graph rejects a node or an edge.
    fn build_graph(nodes: &[&str], edges: &[(&str, &str)]) -> KnowledgeGraph {
        let mut g = KnowledgeGraph::new();
        for &id in nodes {
            g.add_node(make_node(id)).unwrap();
        }
        for &(s, t) in edges {
            g.add_edge(make_edge(s, t)).unwrap();
        }
        g
    }

    // An empty graph produces no communities.
    #[test]
    fn cluster_empty_graph() {
        let g = KnowledgeGraph::new();
        let result = cluster(&g);
        assert!(result.is_empty());
    }

    // Without edges every node ends up in its own community.
    #[test]
    fn cluster_no_edges() {
        let g = build_graph(&["a", "b", "c"], &[]);
        let result = cluster(&g);
        assert_eq!(result.len(), 3);
        for nodes in result.values() {
            assert_eq!(nodes.len(), 1);
        }
    }

    // A triangle: all nodes must be assigned, to at most three communities.
    #[test]
    fn cluster_single_clique() {
        let g = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c"), ("a", "c")]);
        let result = cluster(&g);
        let total_nodes: usize = result.values().map(|v| v.len()).sum();
        assert_eq!(total_nodes, 3);
        assert!(result.len() <= 3);
    }

    // Two triangles joined by one edge: only checks that all six nodes are
    // assigned, not how they are grouped.
    #[test]
    fn cluster_two_cliques() {
        let g = build_graph(
            &["a1", "a2", "a3", "b1", "b2", "b3"],
            &[
                ("a1", "a2"),
                ("a2", "a3"),
                ("a1", "a3"),
                ("b1", "b2"),
                ("b2", "b3"),
                ("b1", "b3"),
                ("a3", "b1"), // the single bridge between the two triangles
            ],
        );
        let result = cluster(&g);
        let total_nodes: usize = result.values().map(|v| v.len()).sum();
        assert_eq!(total_nodes, 6);
        assert!(!result.is_empty());
    }

    // A one-node community is treated as fully cohesive.
    #[test]
    fn cohesion_score_single_node() {
        let g = build_graph(&["a"], &[]);
        let score = cohesion_score(&g, &["a".to_string()]);
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    // Every possible edge present -> density 1.0.
    #[test]
    fn cohesion_score_complete_graph() {
        let g = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c"), ("a", "c")]);
        let score = cohesion_score(&g, &["a".to_string(), "b".to_string(), "c".to_string()]);
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    // No edges among several nodes -> density 0.0.
    #[test]
    fn cohesion_score_no_edges() {
        let g = build_graph(&["a", "b", "c"], &[]);
        let score = cohesion_score(&g, &["a".to_string(), "b".to_string(), "c".to_string()]);
        assert!((score - 0.0).abs() < f64::EPSILON);
    }

    // One of three possible edges -> 0.33 after rounding to two decimals.
    #[test]
    fn cohesion_score_partial() {
        let g = build_graph(&["a", "b", "c"], &[("a", "b")]);
        let score = cohesion_score(&g, &["a".to_string(), "b".to_string(), "c".to_string()]);
        assert!((score - 0.33).abs() < 0.01);
    }

    // score_all returns one score per community key.
    #[test]
    fn score_all_works() {
        let g = build_graph(&["a", "b"], &[("a", "b")]);
        let mut communities = HashMap::new();
        communities.insert(0, vec!["a".to_string(), "b".to_string()]);
        let scores = score_all(&g, &communities);
        assert_eq!(scores.len(), 1);
        assert!((scores[&0] - 1.0).abs() < f64::EPSILON);
    }

    // cluster_graph must also store the communities on the graph.
    #[test]
    fn cluster_graph_mutates_communities() {
        let mut g = build_graph(&["a", "b", "c"], &[("a", "b"), ("b", "c"), ("a", "c")]);
        let result = cluster_graph(&mut g);
        assert!(!result.is_empty());
        assert!(!g.communities.is_empty());
    }

    // Two triangles with no edge between them must come out as exactly two
    // communities of three nodes each.
    #[test]
    fn leiden_splits_disconnected_community() {
        let g = build_graph(
            &["a1", "a2", "a3", "b1", "b2", "b3"],
            &[
                ("a1", "a2"),
                ("a2", "a3"),
                ("a1", "a3"),
                ("b1", "b2"),
                ("b2", "b3"),
                ("b1", "b3"),
            ],
        );
        let result = cluster(&g);
        assert_eq!(
            result.len(),
            2,
            "disconnected cliques should form 2 communities"
        );
        for nodes in result.values() {
            assert_eq!(nodes.len(), 3);
        }
    }

    // Hand-built adjacency with two separate pairs (a-b and c-d).
    #[test]
    fn leiden_connected_components_within() {
        let mut adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
        for id in &["a", "b", "c", "d"] {
            adj.insert(id.to_string(), Vec::new());
        }
        adj.get_mut("a").unwrap().push(("b".into(), 1.0));
        adj.get_mut("b").unwrap().push(("a".into(), 1.0));
        adj.get_mut("c").unwrap().push(("d".into(), 1.0));
        adj.get_mut("d").unwrap().push(("c".into(), 1.0));

        let members: Vec<String> = vec!["a", "b", "c", "d"]
            .into_iter()
            .map(String::from)
            .collect();
        let components = connected_components_within(&adj, &members);
        assert_eq!(components.len(), 2);
    }

    // A path a-b-c is a single component.
    #[test]
    fn leiden_single_component() {
        let mut adj: HashMap<String, Vec<(String, f64)>> = HashMap::new();
        for id in &["a", "b", "c"] {
            adj.insert(id.to_string(), Vec::new());
        }
        adj.get_mut("a").unwrap().push(("b".into(), 1.0));
        adj.get_mut("b").unwrap().push(("a".into(), 1.0));
        adj.get_mut("b").unwrap().push(("c".into(), 1.0));
        adj.get_mut("c").unwrap().push(("b".into(), 1.0));

        let members: Vec<String> = vec!["a", "b", "c"].into_iter().map(String::from).collect();
        let components = connected_components_within(&adj, &members);
        assert_eq!(components.len(), 1);
    }
}