fa_leiden_cd/
lib.rs

1use hashbrown::HashMap;
2use hashbrown::HashSet;
3use rayon::iter::IntoParallelIterator;
4use rayon::iter::ParallelIterator;
5use std::collections::VecDeque;
6
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9
/// Identifier of a community; community labels are plain `u32` values.
pub type CommunityId = u32;
11
12#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
13#[derive(Debug)]
14pub struct Graph<N, E> {
15    _nodes: Vec<N>,
16    _edges: Vec<EdgeInfo<E>>,
17    _connections: Vec<HashMap<usize, usize>>,
18    _total_weight: f32,
19}
20
21impl<N, E> Default for Graph<N, E> {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
/// Everything the graph stores about a single edge.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct EdgeInfo<E> {
    /// Caller-supplied payload attached to the edge.
    pub edge_data: E,
    /// Weight of the edge.
    pub weight: f32,
    /// Id of the edge, i.e. its position in the graph's edge list.
    pub id: usize,
}
34
/// A community in the hierarchy produced by community detection.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Community {
    /// Lowest-level community: a set of node ids.
    L1Community(HashSet<usize>),
    /// Higher-level community: a group of communities from the level below.
    LNCommunity(Vec<Community>),
}
41
42impl Community {
43    pub fn collect_nodes(&self, f: &impl Fn(usize)) {
44        match self {
45            Community::L1Community(nodes) => {
46                for &node in nodes.iter() {
47                    f(node);
48                }
49            }
50            Community::LNCommunity(communities) => {
51                for community in communities.iter() {
52                    community.collect_nodes(f);
53                }
54            }
55        }
56    }
57}
58
/// Strategy object consulted by the modularity optimization loop.
pub trait ModularityOptimizer {
    /// Decides whether optimization should stop, given the modularity before
    /// (`previous`) and after (`current`) a round of local moves.
    fn is_converged(&mut self, previous: f32, current: f32) -> bool;
    /// Node count at or above which local moves are evaluated in parallel;
    /// smaller graphs are processed sequentially.
    fn get_parallel_threshold(&self) -> usize;
}
63
/// Simple `ModularityOptimizer` configured by two fixed values.
#[derive(Debug, Clone)]
pub struct TrivialModularityOptimizer {
    /// Parallel scale
    ///
    /// If the node count reaches or exceeds this value, the optimizer will
    /// use parallel optimization; smaller graphs are optimized sequentially.
    pub parallel_scale: usize,

    /// Tolerance for modularity change
    ///
    /// If the modularity change is less than this value, the optimizer will
    /// consider the optimization converged.
    pub tol: f32,
}
77
78impl ModularityOptimizer for TrivialModularityOptimizer {
79    #[inline]
80    fn is_converged(&mut self, previous: f32, current: f32) -> bool {
81        previous - current < self.tol
82    }
83
84    #[inline]
85    fn get_parallel_threshold(&self) -> usize {
86        self.parallel_scale
87    }
88}
89
90impl<N, E> Graph<N, E> {
91    pub const fn new() -> Self {
92        Self {
93            _nodes: Vec::new(),
94            _edges: Vec::new(),
95            _connections: Vec::new(),
96            _total_weight: 0.0,
97        }
98    }
99
100    pub fn node_data_slice(&self) -> &[N] {
101        &self._nodes
102    }
103
104    pub fn add_node(&mut self, node_data: N) -> usize {
105        let id = self._nodes.len();
106        self._nodes.push(node_data);
107        self._connections.push(HashMap::new());
108        id
109    }
110
111    pub fn add_edge(&mut self, n1: usize, n2: usize, edge_data: E, weight: f32) -> Option<usize> {
112        if n1 == n2 {
113            return None;
114        }
115
116        let conn = &self._connections[n1];
117
118        if let Some(edge_id) = conn.get(&n2) {
119            let edge_id = *edge_id;
120            let old = &mut self._edges[edge_id];
121            old.weight += weight;
122            self._total_weight += weight;
123            old.edge_data = edge_data;
124            return Some(edge_id);
125        }
126
127        let edge_id = self._edges.len();
128        let edge_info = EdgeInfo {
129            edge_data,
130            weight,
131            id: edge_id,
132        };
133        self._edges.push(edge_info);
134        self._connections[n1].insert(n2, edge_id);
135        self._connections[n2].insert(n1, edge_id);
136        self._total_weight += weight;
137        Some(edge_id)
138    }
139
140    pub fn count_nodes(&self) -> usize {
141        self._nodes.len()
142    }
143
144    pub fn try_get_edge_between(&self, n1: usize, n2: usize) -> Option<&EdgeInfo<E>> {
145        self._connections[n1]
146            .get(&n2)
147            .map(|edge_id| &self._edges[*edge_id])
148    }
149}
150
/// A proposal to move one node into a community.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LocalMove {
    /// Id of the node to move.
    pub node: usize,
    /// Id of the community the node moves into.
    pub community: u32,
}
155
/// Maps each node id to the id of the community it is assigned to.
type CommunityAssignments = HashMap<usize, CommunityId>;
157
/// Abstraction over "a pending move, or none at all".
///
/// Implemented by [`LocalMove`] (a move is pending) and by `()` (no move),
/// so one generic modularity routine can serve both cases.
trait MaybeLocalMove {
    /// Returns the pending move, or `None` when there is none.
    fn get(&self) -> Option<&LocalMove>;
}
161
/// A concrete [`LocalMove`] always represents a pending move.
impl MaybeLocalMove for LocalMove {
    /// Returns `Some(self)`.
    #[inline]
    fn get(&self) -> Option<&LocalMove> {
        Some(self)
    }
}
168
/// The unit type stands for "no pending move".
impl MaybeLocalMove for () {
    /// Always returns `None`.
    #[inline]
    fn get(&self) -> Option<&LocalMove> {
        None
    }
}
175
176impl<N: Send + Sync, E: Send + Sync> Graph<N, E> {
177    pub fn initial_community(&self) -> CommunityAssignments {
178        let count_nodes = self.count_nodes();
179        let mut assignments = HashMap::with_capacity(count_nodes);
180        for i in 0..CommunityId::try_from(count_nodes).expect("nodes must be less than u32::MAX") {
181            assignments.insert(i as usize, i);
182        }
183        assignments
184    }
185
    /// Computes the modularity of the partition described by `assignments`.
    ///
    /// # Panics
    ///
    /// Panics if `assignments` has no entry for some node of the graph.
    #[inline]
    pub fn compute_modularity(&self, assignments: &CommunityAssignments) -> f32 {
        self._compute_modularity_impl(assignments, ())
    }
190
    /// Computes the modularity the partition would have if `local_move` were
    /// applied on top of `assignments`; `assignments` itself is not modified.
    ///
    /// # Panics
    ///
    /// Panics if `assignments` has no entry for some node of the graph other
    /// than the moved one.
    #[inline]
    pub fn compute_modularity_with_local_move(
        &self,
        assignments: &CommunityAssignments,
        local_move: LocalMove,
    ) -> f32 {
        self._compute_modularity_impl(assignments, local_move)
    }
199
200    #[inline]
201    fn _compute_modularity_impl(
202        &self,
203        assignments: &CommunityAssignments,
204        local_move: impl MaybeLocalMove,
205    ) -> f32 {
206        let m = self._total_weight;
207        let node_count: usize = self.count_nodes();
208        let mut q = 0.0;
209
210        macro_rules! get_assignment {
211            ($i:ident) => {
212                match local_move.get() {
213                    None => assignments[&$i],
214                    Some(local_move) => {
215                        if local_move.node == $i {
216                            local_move.community
217                        } else {
218                            assignments[&$i]
219                        }
220                    }
221                }
222            };
223        }
224
225        for i in 0..node_count {
226            let assigni = get_assignment!(i);
227            let conn_i = &self._connections[i];
228            let ki = conn_i.len() as f32;
229            for j in (i + 1)..node_count {
230                let assignj = get_assignment!(j);
231                if assigni != assignj {
232                    continue;
233                }
234
235                let kj = self._connections[j].len() as f32;
236
237                match conn_i.get(&j) {
238                    Some(edge_ij) => {
239                        let edge_ij = *edge_ij;
240                        let edge_ij_weight = self._edges[edge_ij].weight;
241                        q += edge_ij_weight - (ki * kj) / (m + m);
242                    }
243                    None => {
244                        q += -ki * kj / (m + m);
245                    }
246                }
247            }
248        }
249
250        return q / m;
251    }
252
253    /// Move a single node to avoid local minimal
254    fn _optimize_modularity_handle_pitfall(
255        &self,
256        assignments: &mut CommunityAssignments,
257        current_modularity: f32,
258    ) {
259        let node_count = self.count_nodes();
260        for i in 0..node_count {
261            let node = i;
262            if let Some(local_move) = self.fast_local_move(node, assignments, current_modularity) {
263                assignments.insert(local_move.node, local_move.community);
264            }
265        }
266    }
267
268    fn optimize_modularity(
269        &self,
270        assignments: &mut CommunityAssignments,
271        optimizer: &mut impl ModularityOptimizer,
272    ) {
273        let mut current_modularity = self.compute_modularity(assignments);
274        let node_count = self.count_nodes();
275        let parallel_threshold = optimizer.get_parallel_threshold();
276        let mut previous_modularity: f32;
277
278        if node_count < parallel_threshold {
279            let mut batch_moving: Vec<LocalMove> = Vec::new();
280
281            loop {
282                previous_modularity = current_modularity;
283
284                for i in 0..node_count {
285                    let node = i;
286                    if let Some(local_move) =
287                        self.fast_local_move(node, assignments, current_modularity)
288                    {
289                        batch_moving.push(local_move);
290                    }
291                }
292
293                if batch_moving.is_empty() {
294                    break;
295                } else {
296                    for local_move in batch_moving.iter() {
297                        assignments.insert(local_move.node, local_move.community);
298                    }
299                    batch_moving.clear();
300                }
301
302                current_modularity = self.compute_modularity(assignments);
303                if current_modularity == previous_modularity {
304                    // but batch_moving is not empty
305                    // in this case we randomly choose a node to move to avoid local minimal
306                    self._optimize_modularity_handle_pitfall(assignments, current_modularity);
307                }
308                if optimizer.is_converged(previous_modularity, current_modularity) {
309                    break;
310                }
311            }
312        } else {
313            let mut batch_moving: boxcar::Vec<LocalMove> = boxcar::Vec::new();
314
315            loop {
316                previous_modularity = current_modularity;
317
318                (0..node_count).into_par_iter().for_each(|node| {
319                    if let Some(local_move) =
320                        self.fast_local_move(node, assignments, current_modularity)
321                    {
322                        batch_moving.push(local_move);
323                    }
324                });
325
326                if batch_moving.is_empty() {
327                    break;
328                } else {
329                    for (_, local_move) in batch_moving.iter() {
330                        assignments.insert(local_move.node, local_move.community);
331                    }
332
333                    batch_moving.clear();
334                }
335
336                current_modularity = self.compute_modularity(assignments);
337
338                if current_modularity == previous_modularity {
339                    // but batch_moving is not empty
340                    // in this case we randomly choose a node to move to avoid local minimal
341                    self._optimize_modularity_handle_pitfall(assignments, current_modularity);
342                }
343
344                if optimizer.is_converged(previous_modularity, current_modularity) {
345                    break;
346                }
347            }
348        }
349    }
350
351    fn fast_local_move(
352        &self,
353        node: usize,
354        assignments: &CommunityAssignments,
355        current_modularity: f32,
356    ) -> Option<LocalMove> {
357        let neighbors = &self._connections[node];
358        let mut best_assign = assignments[&node];
359        let mut changed = false;
360        let mut current_modularity = current_modularity;
361
362        for &neighbor in neighbors.keys() {
363            let neighbor_assign = assignments[&neighbor];
364            if neighbor_assign == best_assign {
365                continue;
366            }
367
368            let new_modularity = self.compute_modularity_with_local_move(
369                assignments,
370                LocalMove {
371                    node,
372                    community: neighbor_assign,
373                },
374            );
375
376            if new_modularity > current_modularity {
377                best_assign = neighbor_assign;
378                current_modularity = new_modularity;
379                changed = true;
380            }
381        }
382
383        if changed {
384            Some(LocalMove {
385                node,
386                community: best_assign,
387            })
388        } else {
389            None
390        }
391    }
392
393    fn refine(&self, assignments: &CommunityAssignments) -> Vec<HashSet<usize>> {
394        // this is the community assignments by louvain
395        // each community might get split into multiple communities
396        // if there are partitions that are not connected to each other
397        let mut communities_by_louvain: Vec<HashSet<usize>> = vec![];
398
399        // fill and relabel
400        {
401            let mut relabel: HashMap<u32, u32> = HashMap::new();
402
403            let mut assure_relabel_community =
404                |communities_by_louvain: &mut Vec<HashSet<usize>>, louvain_community: u32| -> u32 {
405                    match relabel.get(&louvain_community) {
406                        Some(community) => *community,
407                        None => {
408                            let relabel_community = relabel.len() as u32;
409                            #[cfg(debug_assertions)]
410                            {
411                                debug_assert!(
412                                    relabel_community as usize == communities_by_louvain.len()
413                                );
414                            }
415                            relabel.insert(louvain_community, relabel_community);
416                            communities_by_louvain.push(HashSet::new());
417                            relabel_community
418                        }
419                    }
420                };
421
422            for (&node, &louvain_community) in assignments.iter() {
423                let relabel_community =
424                    assure_relabel_community(&mut communities_by_louvain, louvain_community);
425                communities_by_louvain[relabel_community as usize].insert(node);
426            }
427        }
428
429        // XXX: parallelize?
430        // validate the inner connections in each community
431        let mut i = 0;
432        while i < communities_by_louvain.len() {
433            let community = &communities_by_louvain[i];
434
435            if community.len() == 1 {
436                i += 1;
437                continue;
438            }
439
440            debug_assert!(community.len() > 1);
441
442            let mut left_members = community.clone();
443            let mut queue = VecDeque::new();
444
445            queue.push_back(*community.iter().next().unwrap());
446
447            while let Some(node) = queue.pop_front() {
448                let newly_visited = left_members.remove(&node);
449                if !newly_visited {
450                    // already visited, skip
451                    continue;
452                }
453
454                let neighbors = &self._connections[node];
455
456                for neighbor in neighbors.keys() {
457                    if !community.contains(neighbor) {
458                        // the sub-community shall not get connected via this node
459                        continue;
460                    }
461
462                    queue.push_back(*neighbor);
463                }
464            }
465
466            if left_members.is_empty() {
467                // all members are connected, no need to split
468            } else {
469                let community = &mut communities_by_louvain[i];
470                // split the community into two
471                for _ in community.extract_if(|node| left_members.contains(node)) {
472                    /* force consume to perform the elimination */
473                }
474                // optimization:
475                // we already know that `left_members` are connected, and
476                // if `left_members` are larger, we swap them as `communities_by_louvain[i]`
477                // so that it will not be resolved in the later rounds.
478                if left_members.len() > community.len() {
479                    std::mem::swap(&mut left_members, community);
480                }
481                communities_by_louvain.push(left_members);
482            }
483            i += 1;
484        }
485
486        communities_by_louvain
487    }
488
489    pub fn leiden(
490        &self,
491        max_iter: Option<usize>,
492        optimizer: &mut impl ModularityOptimizer,
493    ) -> Graph<Community, ()> {
494        let mut high_level_graph: Graph<Community, ()>;
495        {
496            let g = leiden_l1(&self, optimizer);
497            let node_count_g1 = g.count_nodes();
498            let g = leiden_ln(g, optimizer);
499            let node_count_g2 = g.count_nodes();
500            if node_count_g2 == node_count_g1 {
501                return g;
502            }
503            high_level_graph = g;
504        }
505
506        let mut count = high_level_graph.count_nodes();
507        let mut previous: usize;
508
509        if let Some(mut max_iter) = max_iter {
510            loop {
511                previous = count;
512                high_level_graph = leiden_ln(high_level_graph, optimizer);
513                count = high_level_graph.count_nodes();
514                if (previous == count) | (max_iter == 0) {
515                    break;
516                }
517                max_iter -= 1;
518            }
519        } else {
520            loop {
521                previous = count;
522                high_level_graph = leiden_ln(high_level_graph, optimizer);
523                count = high_level_graph.count_nodes();
524                if previous == count {
525                    break;
526                }
527            }
528        }
529
530        high_level_graph
531    }
532}
533
534fn leiden_l1<N: Send + Sync, E: Send + Sync>(
535    graph: &Graph<N, E>,
536    optimizer: &mut impl ModularityOptimizer,
537) -> Graph<Community, ()> {
538    let mut community_assignments = graph.initial_community();
539    graph.optimize_modularity(&mut community_assignments, optimizer);
540    let communities = graph.refine(&community_assignments);
541    return compress_l1(graph, communities);
542}
543
544fn leiden_ln(
545    graph: Graph<Community, ()>,
546    optimizer: &mut impl ModularityOptimizer,
547) -> Graph<Community, ()> {
548    let mut community_assignments = graph.initial_community();
549    graph.optimize_modularity(&mut community_assignments, optimizer);
550    let communities = graph.refine(&community_assignments);
551    if communities.len() == graph._nodes.len() {
552        return graph;
553    }
554    return compress_ln(graph, communities);
555}
556
557fn compress_l1<N, E>(
558    graph: &Graph<N, E>,
559    relabeled_assignments: Vec<HashSet<usize>>,
560) -> Graph<Community, ()> {
561    let mut node_to_community: HashMap<usize, u32> = HashMap::new();
562    let mut new_graph = Graph::new();
563
564    for (i, community) in relabeled_assignments.into_iter().enumerate() {
565        for &node in community.iter() {
566            node_to_community.insert(node, i as u32);
567        }
568
569        let new_community = Community::L1Community(community);
570
571        let node = new_graph.add_node(new_community);
572        debug_assert!(node == i);
573        let _ = node;
574    }
575
576    let node_count = graph.count_nodes();
577    for i in 0..node_count {
578        let assigni = node_to_community[&i];
579
580        for j in i + 1..node_count {
581            let assignj = node_to_community[&j];
582
583            if assigni == assignj {
584                continue;
585            }
586
587            if let Some(edge_info) = graph.try_get_edge_between(i, j) {
588                new_graph.add_edge(assigni as usize, assignj as usize, (), edge_info.weight);
589            }
590        }
591    }
592
593    new_graph
594}
595
596fn compress_ln<E>(
597    mut graph: Graph<Community, E>,
598    relabeled_assignments: Vec<HashSet<usize>>,
599) -> Graph<Community, ()> {
600    let mut node_to_community: HashMap<usize, u32> = HashMap::new();
601    let mut new_graph = Graph::new();
602
603    for (i, community) in relabeled_assignments.into_iter().enumerate() {
604        for &node in community.iter() {
605            node_to_community.insert(node, i as u32);
606        }
607
608        let new_community = if community.len() == 1 {
609            let &c = community.iter().next().unwrap();
610            let x = std::mem::replace(&mut graph._nodes[c], Community::LNCommunity(Vec::new()));
611            #[cfg(debug_assertions)]
612            {
613                if let Community::L1Community(sub_communities) = &x {
614                    debug_assert!(sub_communities.len() >= 1);
615                }
616            }
617            x
618        } else {
619            let sub_communities: Vec<Community> = community
620                .into_iter()
621                .map(|c| {
622                    let x =
623                // graph._nodes is taken out
624                std::mem::replace(&mut graph._nodes[c], Community::LNCommunity(Vec::new()));
625                    x
626                })
627                .collect();
628            debug_assert!(sub_communities.len() >= 1);
629            Community::LNCommunity(sub_communities)
630        };
631
632        let node = new_graph.add_node(new_community);
633        debug_assert!(node == i);
634        let _ = node;
635    }
636
637    let node_count = graph.count_nodes();
638    for i in 0..node_count {
639        let assigni = node_to_community[&i];
640
641        for j in i + 1..node_count {
642            let assignj = node_to_community[&j];
643
644            if assigni == assignj {
645                continue;
646            }
647
648            if let Some(edge_info) = graph.try_get_edge_between(i, j) {
649                new_graph.add_edge(assigni as usize, assignj as usize, (), edge_info.weight);
650            }
651        }
652    }
653
654    new_graph
655}
656
#[cfg(test)]
mod tests {
    use crate::{Graph, TrivialModularityOptimizer};
    use std::cell::RefCell;
    use std::collections::{HashMap, HashSet};

    /// Runs detection on a small weighted graph of programming languages and
    /// prints the members of every resulting community.
    #[test]
    fn test_example() {
        let edges: &[(&'static str, &'static str, f32)] = &[
            ("Fortran", "C", 0.5),
            ("Fortran", "LISP", 0.3),
            ("Fortran", "MATLAB", 0.6),
            ("C", "C++", 0.9),
            // ("C", "Java", 0.2),
            ("C", "Go", 0.6),
            ("LISP", "ML", 0.5),
            ("LISP", "OCaml", 0.2),
            ("LISP", "Haskell", 0.2),
            ("LISP", "Ruby", 0.5),
            ("LISP", "Julia", 0.6),
            ("ML", "OCaml", 0.8),
            ("ML", "Haskell", 0.5),
            ("OCaml", "Haskell", 0.3),
            ("OCaml", "F#", 0.6),
            ("Haskell", "Julia", 0.2),
            // ("C++", "Java", 0.5),
            ("C++", "Python", 0.32),
            ("C++", "Ruby", 0.2),
            ("C++", "C#", 0.5),
            // ("Java", "Ruby", 0.4),
            // ("Java", "Python", 0.5),
            // ("Java", "C#", 0.6),
            // ("Java", "Go", 0.45),
            // ("Java", "Julia", 0.1),
            ("Python", "F#", 0.2),
            ("Python", "Julia", 0.4),
            ("C#", "F#", 0.3),
        ];

        // Nodes are created lazily the first time a name is seen.
        let mut ids: HashMap<&'static str, usize> = HashMap::new();
        let mut graph = Graph::new();
        for &(from, to, weight) in edges {
            let from_id = *ids.entry(from).or_insert_with(|| graph.add_node(from));
            let to_id = *ids.entry(to).or_insert_with(|| graph.add_node(to));
            graph.add_edge(from_id, to_id, (), weight);
        }

        let mut optimizer = TrivialModularityOptimizer {
            parallel_scale: 128,
            tol: 1e-11,
        };

        let hierarchy = graph.leiden(Some(100), &mut optimizer);
        for (index, community) in hierarchy.node_data_slice().iter().enumerate() {
            println!("community {}:", index);
            community.collect_nodes(&|member| {
                let name = graph.node_data_slice()[member];
                println!("     {}", name);
            });
        }
    }

    /// Three disjoint triangles must end up as exactly three communities.
    #[test]
    fn test_simplest() {
        let edges = vec![
            (1, 2, 1.0),
            (1, 3, 1.0),
            (2, 3, 1.0),
            (4, 5, 1.0),
            (4, 6, 1.0),
            (5, 6, 1.0),
            (7, 8, 1.0),
            (7, 9, 1.0),
            (8, 9, 1.0),
        ];

        let mut ids: HashMap<usize, usize> = HashMap::new();
        let mut graph = Graph::new();
        for (from, to, weight) in edges {
            let from_id = *ids.entry(from).or_insert_with(|| graph.add_node(from));
            let to_id = *ids.entry(to).or_insert_with(|| graph.add_node(to));
            graph.add_edge(from_id, to_id, (), weight);
        }

        let mut optimizer = TrivialModularityOptimizer {
            parallel_scale: 128,
            tol: 1e-13,
        };

        // `collect_nodes` takes an `Fn`, so results are recorded through a RefCell.
        let assignments = RefCell::new(graph.initial_community());

        let hierarchy = graph.leiden(Some(100), &mut optimizer);
        for (index, community) in hierarchy.node_data_slice().iter().enumerate() {
            println!("community {}:", index);
            community.collect_nodes(&|member| {
                assignments.borrow_mut().insert(member, index as u32);
                let label = graph.node_data_slice()[member];
                println!("     {}", label);
            });
        }

        let distinct: HashSet<u32> = assignments.borrow().values().copied().collect();
        assert!(distinct.len() == 3);

        println!(
            "real modularity: {}",
            graph.compute_modularity(&assignments.borrow())
        );
    }
}
766}