1use hashbrown::HashMap;
2use hashbrown::HashSet;
3use rayon::iter::IntoParallelIterator;
4use rayon::iter::ParallelIterator;
5use std::collections::VecDeque;
6
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9
/// Label of the community a node is assigned to (see `CommunityAssignments`).
pub type CommunityId = u32;
11
/// An undirected, weighted graph with a payload `N` per node and `E` per edge.
///
/// Nodes are identified by dense `usize` ids (their index in `_nodes`). Every undirected
/// edge is stored once in `_edges` and referenced from the adjacency maps of both endpoints.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Graph<N, E> {
    // Node payloads, indexed by node id.
    _nodes: Vec<N>,
    // Edge records, indexed by edge id.
    _edges: Vec<EdgeInfo<E>>,
    // `_connections[n]` maps each neighbour of `n` to the id of the connecting edge.
    _connections: Vec<HashMap<usize, usize>>,
    // Sum of the weights passed to `add_edge` (each undirected edge counted once).
    _total_weight: f32,
}
20
21impl<N, E> Default for Graph<N, E> {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
/// A stored edge: its payload, accumulated weight and id.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct EdgeInfo<E> {
    /// Payload; replaced by the most recent `add_edge` call for the same node pair.
    pub edge_data: E,
    /// Weight; repeated `add_edge` calls for the same node pair add up.
    pub weight: f32,
    /// Index of this edge in the graph's edge list.
    pub id: usize,
}
34
/// One node of the community hierarchy produced by the Leiden driver.
///
/// `L1Community` leaves hold node ids of the original input graph; `LNCommunity` groups
/// communities found at a lower level.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub enum Community {
    /// A set of node ids of the original input graph.
    L1Community(HashSet<usize>),
    /// A group of lower-level communities.
    LNCommunity(Vec<Community>),
}

impl Community {
    /// Calls `f` once for every original-graph node id contained in this community,
    /// at any depth of nesting.
    ///
    /// Nested communities are walked depth-first in `Vec` order; ids inside a single
    /// `L1Community` are reported in that set's iteration order.
    pub fn collect_nodes(&self, f: &impl Fn(usize)) {
        // Explicit stack instead of recursion. Children are pushed in reverse so that they
        // are popped (and therefore visited) in their original order.
        let mut pending: Vec<&Community> = vec![self];
        while let Some(current) = pending.pop() {
            match current {
                Community::L1Community(members) => members.iter().for_each(|&node| f(node)),
                Community::LNCommunity(children) => pending.extend(children.iter().rev()),
            }
        }
    }
}
58
/// Policy object that steers modularity optimisation: it decides when a run of local-move
/// rounds has converged and from which graph size candidate moves are evaluated in parallel.
pub trait ModularityOptimizer {
    /// Returns `true` when optimisation should stop, given the modularity before
    /// (`previous`) and after (`current`) one round of local moves.
    fn is_converged(&mut self, previous: f32, current: f32) -> bool;
    /// Minimum node count at which local moves are evaluated in parallel; smaller graphs
    /// are processed sequentially.
    fn get_parallel_threshold(&self) -> usize;
}

/// A [`ModularityOptimizer`] with a fixed parallel threshold and an absolute tolerance.
pub struct TrivialModularityOptimizer {
    /// Graphs with at least this many nodes evaluate local moves in parallel.
    pub parallel_scale: usize,

    /// Optimisation stops once a round improves modularity by less than this amount.
    pub tol: f32,
}

impl ModularityOptimizer for TrivialModularityOptimizer {
    /// Converged when the gain `current - previous` is below `tol`. This also stops on a
    /// round that left modularity unchanged or made it worse.
    #[inline]
    fn is_converged(&mut self, previous: f32, current: f32) -> bool {
        // Modularity is being maximised, so progress is measured as `current - previous`.
        // (`previous - current` is negative for every improvement and would always be
        // below `tol`, stopping after the very first round.)
        current - previous < self.tol
    }

    #[inline]
    fn get_parallel_threshold(&self) -> usize {
        self.parallel_scale
    }
}
89
90impl<N, E> Graph<N, E> {
91 pub const fn new() -> Self {
92 Self {
93 _nodes: Vec::new(),
94 _edges: Vec::new(),
95 _connections: Vec::new(),
96 _total_weight: 0.0,
97 }
98 }
99
100 pub fn node_data_slice(&self) -> &[N] {
101 &self._nodes
102 }
103
104 pub fn add_node(&mut self, node_data: N) -> usize {
105 let id = self._nodes.len();
106 self._nodes.push(node_data);
107 self._connections.push(HashMap::new());
108 id
109 }
110
111 pub fn add_edge(&mut self, n1: usize, n2: usize, edge_data: E, weight: f32) -> Option<usize> {
112 if n1 == n2 {
113 return None;
114 }
115
116 let conn = &self._connections[n1];
117
118 if let Some(edge_id) = conn.get(&n2) {
119 let edge_id = *edge_id;
120 let old = &mut self._edges[edge_id];
121 old.weight += weight;
122 self._total_weight += weight;
123 old.edge_data = edge_data;
124 return Some(edge_id);
125 }
126
127 let edge_id = self._edges.len();
128 let edge_info = EdgeInfo {
129 edge_data,
130 weight,
131 id: edge_id,
132 };
133 self._edges.push(edge_info);
134 self._connections[n1].insert(n2, edge_id);
135 self._connections[n2].insert(n1, edge_id);
136 self._total_weight += weight;
137 Some(edge_id)
138 }
139
140 pub fn count_nodes(&self) -> usize {
141 self._nodes.len()
142 }
143
144 pub fn try_get_edge_between(&self, n1: usize, n2: usize) -> Option<&EdgeInfo<E>> {
145 self._connections[n1]
146 .get(&n2)
147 .map(|edge_id| &self._edges[*edge_id])
148 }
149}
150
/// A proposal to reassign `node` to `community`.
pub struct LocalMove {
    /// Id of the node being moved.
    pub node: usize,
    /// Destination community id.
    pub community: u32,
}
155
/// Maps each node id to the id of the community it currently belongs to.
type CommunityAssignments = HashMap<usize, CommunityId>;
157
158trait MaybeLocalMove {
159 fn get(&self) -> Option<&LocalMove>;
160}
161
162impl MaybeLocalMove for LocalMove {
163 #[inline]
164 fn get(&self) -> Option<&LocalMove> {
165 Some(self)
166 }
167}
168
169impl MaybeLocalMove for () {
170 #[inline]
171 fn get(&self) -> Option<&LocalMove> {
172 None
173 }
174}
175
176impl<N: Send + Sync, E: Send + Sync> Graph<N, E> {
177 pub fn initial_community(&self) -> CommunityAssignments {
178 let count_nodes = self.count_nodes();
179 let mut assignments = HashMap::with_capacity(count_nodes);
180 for i in 0..CommunityId::try_from(count_nodes).expect("nodes must be less than u32::MAX") {
181 assignments.insert(i as usize, i);
182 }
183 assignments
184 }
185
186 #[inline]
187 pub fn compute_modularity(&self, assignments: &CommunityAssignments) -> f32 {
188 self._compute_modularity_impl(assignments, ())
189 }
190
191 #[inline]
192 pub fn compute_modularity_with_local_move(
193 &self,
194 assignments: &CommunityAssignments,
195 local_move: LocalMove,
196 ) -> f32 {
197 self._compute_modularity_impl(assignments, local_move)
198 }
199
200 #[inline]
201 fn _compute_modularity_impl(
202 &self,
203 assignments: &CommunityAssignments,
204 local_move: impl MaybeLocalMove,
205 ) -> f32 {
206 let m = self._total_weight;
207 let node_count: usize = self.count_nodes();
208 let mut q = 0.0;
209
210 macro_rules! get_assignment {
211 ($i:ident) => {
212 match local_move.get() {
213 None => assignments[&$i],
214 Some(local_move) => {
215 if local_move.node == $i {
216 local_move.community
217 } else {
218 assignments[&$i]
219 }
220 }
221 }
222 };
223 }
224
225 for i in 0..node_count {
226 let assigni = get_assignment!(i);
227 let conn_i = &self._connections[i];
228 let ki = conn_i.len() as f32;
229 for j in (i + 1)..node_count {
230 let assignj = get_assignment!(j);
231 if assigni != assignj {
232 continue;
233 }
234
235 let kj = self._connections[j].len() as f32;
236
237 match conn_i.get(&j) {
238 Some(edge_ij) => {
239 let edge_ij = *edge_ij;
240 let edge_ij_weight = self._edges[edge_ij].weight;
241 q += edge_ij_weight - (ki * kj) / (m + m);
242 }
243 None => {
244 q += -ki * kj / (m + m);
245 }
246 }
247 }
248 }
249
250 return q / m;
251 }
252
253 fn _optimize_modularity_handle_pitfall(
255 &self,
256 assignments: &mut CommunityAssignments,
257 current_modularity: f32,
258 ) {
259 let node_count = self.count_nodes();
260 for i in 0..node_count {
261 let node = i;
262 if let Some(local_move) = self.fast_local_move(node, assignments, current_modularity) {
263 assignments.insert(local_move.node, local_move.community);
264 }
265 }
266 }
267
268 fn optimize_modularity(
269 &self,
270 assignments: &mut CommunityAssignments,
271 optimizer: &mut impl ModularityOptimizer,
272 ) {
273 let mut current_modularity = self.compute_modularity(assignments);
274 let node_count = self.count_nodes();
275 let parallel_threshold = optimizer.get_parallel_threshold();
276 let mut previous_modularity: f32;
277
278 if node_count < parallel_threshold {
279 let mut batch_moving: Vec<LocalMove> = Vec::new();
280
281 loop {
282 previous_modularity = current_modularity;
283
284 for i in 0..node_count {
285 let node = i;
286 if let Some(local_move) =
287 self.fast_local_move(node, assignments, current_modularity)
288 {
289 batch_moving.push(local_move);
290 }
291 }
292
293 if batch_moving.is_empty() {
294 break;
295 } else {
296 for local_move in batch_moving.iter() {
297 assignments.insert(local_move.node, local_move.community);
298 }
299 batch_moving.clear();
300 }
301
302 current_modularity = self.compute_modularity(assignments);
303 if current_modularity == previous_modularity {
304 self._optimize_modularity_handle_pitfall(assignments, current_modularity);
307 }
308 if optimizer.is_converged(previous_modularity, current_modularity) {
309 break;
310 }
311 }
312 } else {
313 let mut batch_moving: boxcar::Vec<LocalMove> = boxcar::Vec::new();
314
315 loop {
316 previous_modularity = current_modularity;
317
318 (0..node_count).into_par_iter().for_each(|node| {
319 if let Some(local_move) =
320 self.fast_local_move(node, assignments, current_modularity)
321 {
322 batch_moving.push(local_move);
323 }
324 });
325
326 if batch_moving.is_empty() {
327 break;
328 } else {
329 for (_, local_move) in batch_moving.iter() {
330 assignments.insert(local_move.node, local_move.community);
331 }
332
333 batch_moving.clear();
334 }
335
336 current_modularity = self.compute_modularity(assignments);
337
338 if current_modularity == previous_modularity {
339 self._optimize_modularity_handle_pitfall(assignments, current_modularity);
342 }
343
344 if optimizer.is_converged(previous_modularity, current_modularity) {
345 break;
346 }
347 }
348 }
349 }
350
351 fn fast_local_move(
352 &self,
353 node: usize,
354 assignments: &CommunityAssignments,
355 current_modularity: f32,
356 ) -> Option<LocalMove> {
357 let neighbors = &self._connections[node];
358 let mut best_assign = assignments[&node];
359 let mut changed = false;
360 let mut current_modularity = current_modularity;
361
362 for &neighbor in neighbors.keys() {
363 let neighbor_assign = assignments[&neighbor];
364 if neighbor_assign == best_assign {
365 continue;
366 }
367
368 let new_modularity = self.compute_modularity_with_local_move(
369 assignments,
370 LocalMove {
371 node,
372 community: neighbor_assign,
373 },
374 );
375
376 if new_modularity > current_modularity {
377 best_assign = neighbor_assign;
378 current_modularity = new_modularity;
379 changed = true;
380 }
381 }
382
383 if changed {
384 Some(LocalMove {
385 node,
386 community: best_assign,
387 })
388 } else {
389 None
390 }
391 }
392
393 fn refine(&self, assignments: &CommunityAssignments) -> Vec<HashSet<usize>> {
394 let mut communities_by_louvain: Vec<HashSet<usize>> = vec![];
398
399 {
401 let mut relabel: HashMap<u32, u32> = HashMap::new();
402
403 let mut assure_relabel_community =
404 |communities_by_louvain: &mut Vec<HashSet<usize>>, louvain_community: u32| -> u32 {
405 match relabel.get(&louvain_community) {
406 Some(community) => *community,
407 None => {
408 let relabel_community = relabel.len() as u32;
409 #[cfg(debug_assertions)]
410 {
411 debug_assert!(
412 relabel_community as usize == communities_by_louvain.len()
413 );
414 }
415 relabel.insert(louvain_community, relabel_community);
416 communities_by_louvain.push(HashSet::new());
417 relabel_community
418 }
419 }
420 };
421
422 for (&node, &louvain_community) in assignments.iter() {
423 let relabel_community =
424 assure_relabel_community(&mut communities_by_louvain, louvain_community);
425 communities_by_louvain[relabel_community as usize].insert(node);
426 }
427 }
428
429 let mut i = 0;
432 while i < communities_by_louvain.len() {
433 let community = &communities_by_louvain[i];
434
435 if community.len() == 1 {
436 i += 1;
437 continue;
438 }
439
440 debug_assert!(community.len() > 1);
441
442 let mut left_members = community.clone();
443 let mut queue = VecDeque::new();
444
445 queue.push_back(*community.iter().next().unwrap());
446
447 while let Some(node) = queue.pop_front() {
448 let newly_visited = left_members.remove(&node);
449 if !newly_visited {
450 continue;
452 }
453
454 let neighbors = &self._connections[node];
455
456 for neighbor in neighbors.keys() {
457 if !community.contains(neighbor) {
458 continue;
460 }
461
462 queue.push_back(*neighbor);
463 }
464 }
465
466 if left_members.is_empty() {
467 } else {
469 let community = &mut communities_by_louvain[i];
470 for _ in community.extract_if(|node| left_members.contains(node)) {
472 }
474 if left_members.len() > community.len() {
479 std::mem::swap(&mut left_members, community);
480 }
481 communities_by_louvain.push(left_members);
482 }
483 i += 1;
484 }
485
486 communities_by_louvain
487 }
488
489 pub fn leiden(
490 &self,
491 max_iter: Option<usize>,
492 optimizer: &mut impl ModularityOptimizer,
493 ) -> Graph<Community, ()> {
494 let mut high_level_graph: Graph<Community, ()>;
495 {
496 let g = leiden_l1(&self, optimizer);
497 let node_count_g1 = g.count_nodes();
498 let g = leiden_ln(g, optimizer);
499 let node_count_g2 = g.count_nodes();
500 if node_count_g2 == node_count_g1 {
501 return g;
502 }
503 high_level_graph = g;
504 }
505
506 let mut count = high_level_graph.count_nodes();
507 let mut previous: usize;
508
509 if let Some(mut max_iter) = max_iter {
510 loop {
511 previous = count;
512 high_level_graph = leiden_ln(high_level_graph, optimizer);
513 count = high_level_graph.count_nodes();
514 if (previous == count) | (max_iter == 0) {
515 break;
516 }
517 max_iter -= 1;
518 }
519 } else {
520 loop {
521 previous = count;
522 high_level_graph = leiden_ln(high_level_graph, optimizer);
523 count = high_level_graph.count_nodes();
524 if previous == count {
525 break;
526 }
527 }
528 }
529
530 high_level_graph
531 }
532}
533
534fn leiden_l1<N: Send + Sync, E: Send + Sync>(
535 graph: &Graph<N, E>,
536 optimizer: &mut impl ModularityOptimizer,
537) -> Graph<Community, ()> {
538 let mut community_assignments = graph.initial_community();
539 graph.optimize_modularity(&mut community_assignments, optimizer);
540 let communities = graph.refine(&community_assignments);
541 return compress_l1(graph, communities);
542}
543
544fn leiden_ln(
545 graph: Graph<Community, ()>,
546 optimizer: &mut impl ModularityOptimizer,
547) -> Graph<Community, ()> {
548 let mut community_assignments = graph.initial_community();
549 graph.optimize_modularity(&mut community_assignments, optimizer);
550 let communities = graph.refine(&community_assignments);
551 if communities.len() == graph._nodes.len() {
552 return graph;
553 }
554 return compress_ln(graph, communities);
555}
556
557fn compress_l1<N, E>(
558 graph: &Graph<N, E>,
559 relabeled_assignments: Vec<HashSet<usize>>,
560) -> Graph<Community, ()> {
561 let mut node_to_community: HashMap<usize, u32> = HashMap::new();
562 let mut new_graph = Graph::new();
563
564 for (i, community) in relabeled_assignments.into_iter().enumerate() {
565 for &node in community.iter() {
566 node_to_community.insert(node, i as u32);
567 }
568
569 let new_community = Community::L1Community(community);
570
571 let node = new_graph.add_node(new_community);
572 debug_assert!(node == i);
573 let _ = node;
574 }
575
576 let node_count = graph.count_nodes();
577 for i in 0..node_count {
578 let assigni = node_to_community[&i];
579
580 for j in i + 1..node_count {
581 let assignj = node_to_community[&j];
582
583 if assigni == assignj {
584 continue;
585 }
586
587 if let Some(edge_info) = graph.try_get_edge_between(i, j) {
588 new_graph.add_edge(assigni as usize, assignj as usize, (), edge_info.weight);
589 }
590 }
591 }
592
593 new_graph
594}
595
596fn compress_ln<E>(
597 mut graph: Graph<Community, E>,
598 relabeled_assignments: Vec<HashSet<usize>>,
599) -> Graph<Community, ()> {
600 let mut node_to_community: HashMap<usize, u32> = HashMap::new();
601 let mut new_graph = Graph::new();
602
603 for (i, community) in relabeled_assignments.into_iter().enumerate() {
604 for &node in community.iter() {
605 node_to_community.insert(node, i as u32);
606 }
607
608 let new_community = if community.len() == 1 {
609 let &c = community.iter().next().unwrap();
610 let x = std::mem::replace(&mut graph._nodes[c], Community::LNCommunity(Vec::new()));
611 #[cfg(debug_assertions)]
612 {
613 if let Community::L1Community(sub_communities) = &x {
614 debug_assert!(sub_communities.len() >= 1);
615 }
616 }
617 x
618 } else {
619 let sub_communities: Vec<Community> = community
620 .into_iter()
621 .map(|c| {
622 let x =
623 std::mem::replace(&mut graph._nodes[c], Community::LNCommunity(Vec::new()));
625 x
626 })
627 .collect();
628 debug_assert!(sub_communities.len() >= 1);
629 Community::LNCommunity(sub_communities)
630 };
631
632 let node = new_graph.add_node(new_community);
633 debug_assert!(node == i);
634 let _ = node;
635 }
636
637 let node_count = graph.count_nodes();
638 for i in 0..node_count {
639 let assigni = node_to_community[&i];
640
641 for j in i + 1..node_count {
642 let assignj = node_to_community[&j];
643
644 if assigni == assignj {
645 continue;
646 }
647
648 if let Some(edge_info) = graph.try_get_edge_between(i, j) {
649 new_graph.add_edge(assigni as usize, assignj as usize, (), edge_info.weight);
650 }
651 }
652 }
653
654 new_graph
655}
656
#[cfg(test)]
mod tests {
    use crate::{Graph, TrivialModularityOptimizer};
    use std::cell::RefCell;
    use std::collections::{HashMap, HashSet};

    /// Builds an undirected graph from `(from, to, weight)` triples. Each distinct label
    /// becomes one node (ids handed out in first-seen order) whose payload is the label.
    fn graph_from_edges<K>(edges: &[(K, K, f32)]) -> Graph<K, ()>
    where
        K: Copy + Eq + std::hash::Hash,
    {
        let mut ids: HashMap<K, usize> = HashMap::new();
        let mut graph = Graph::new();
        for &(from, to, weight) in edges {
            let from_id = *ids.entry(from).or_insert_with(|| graph.add_node(from));
            let to_id = *ids.entry(to).or_insert_with(|| graph.add_node(to));
            graph.add_edge(from_id, to_id, (), weight);
        }
        graph
    }

    #[test]
    fn test_example() {
        let edges: &[(&'static str, &'static str, f32)] = &[
            ("Fortran", "C", 0.5),
            ("Fortran", "LISP", 0.3),
            ("Fortran", "MATLAB", 0.6),
            ("C", "C++", 0.9),
            ("C", "Go", 0.6),
            ("LISP", "ML", 0.5),
            ("LISP", "OCaml", 0.2),
            ("LISP", "Haskell", 0.2),
            ("LISP", "Ruby", 0.5),
            ("LISP", "Julia", 0.6),
            ("ML", "OCaml", 0.8),
            ("ML", "Haskell", 0.5),
            ("OCaml", "Haskell", 0.3),
            ("OCaml", "F#", 0.6),
            ("Haskell", "Julia", 0.2),
            ("C++", "Python", 0.32),
            ("C++", "Ruby", 0.2),
            ("C++", "C#", 0.5),
            ("Python", "F#", 0.2),
            ("Python", "Julia", 0.4),
            ("C#", "F#", 0.3),
        ];

        let g = graph_from_edges(edges);
        let mut optimizer = TrivialModularityOptimizer {
            parallel_scale: 128,
            tol: 1e-11,
        };

        let hierarchy = g.leiden(Some(100), &mut optimizer);
        let labels = g.node_data_slice();
        for (community_index, community) in hierarchy.node_data_slice().iter().enumerate() {
            println!("community {}:", community_index);
            community.collect_nodes(&|node| println!(" {}", labels[node]));
        }
    }

    #[test]
    fn test_simplest() {
        // Three disjoint unit-weight triangles: {1,2,3}, {4,5,6}, {7,8,9}.
        let edges: [(usize, usize, f32); 9] = [
            (1, 2, 1.0),
            (1, 3, 1.0),
            (2, 3, 1.0),
            (4, 5, 1.0),
            (4, 6, 1.0),
            (5, 6, 1.0),
            (7, 8, 1.0),
            (7, 9, 1.0),
            (8, 9, 1.0),
        ];

        let g = graph_from_edges(&edges);
        let mut optimizer = TrivialModularityOptimizer {
            parallel_scale: 128,
            tol: 1e-13,
        };

        let assignments = RefCell::new(g.initial_community());

        let hierarchy = g.leiden(Some(100), &mut optimizer);
        let labels = g.node_data_slice();
        for (community_index, community) in hierarchy.node_data_slice().iter().enumerate() {
            println!("community {}:", community_index);
            community.collect_nodes(&|node| {
                assignments
                    .borrow_mut()
                    .insert(node, community_index as u32);
                println!(" {}", labels[node]);
            });
        }

        let distinct: HashSet<u32> = assignments.borrow().values().copied().collect();
        assert_eq!(distinct.len(), 3);

        println!(
            "real modularity: {}",
            g.compute_modularity(&assignments.borrow())
        );
    }
}