pb_rs/
scc.rs

1use crate::types::{FileDescriptor, Frequency, MessageIndex};
2use std::cmp::min;
3use std::collections::HashMap;
4
5/// A recursive strongly connected component function
6///
7/// Uses Tarjan's algorithm
8/// https://www.geeksforgeeks.org/tarjan-algorithm-find-strongly-connected-components/
9fn scc(
10    vertices: &[MessageIndex],
11    desc: &FileDescriptor,
12    u: usize,
13    count: &mut isize,
14    low: &mut [isize],
15    disc: &mut [isize],
16    stack: &mut Vec<usize>,
17    sccs: &mut Vec<Vec<usize>>,
18    ids: &HashMap<MessageIndex, usize>,
19) {
20    disc[u] = *count;
21    low[u] = *count;
22    *count += 1;
23    stack.push(u);
24
25    for &v in vertices[u]
26        .get_message(desc)
27        .all_fields()
28        .filter(|f| !f.boxed && f.frequency != Frequency::Repeated)
29        .filter_map(|f| f.typ.message())
30        .filter_map(|m| ids.get(m))
31    {
32        if disc[v] == -1 {
33            scc(vertices, desc, v, count, low, disc, stack, sccs, ids);
34            low[u] = min(low[u], low[v]);
35        } else if stack.contains(&v) {
36            low[u] = min(low[u], disc[v]);
37        }
38    }
39
40    if low[u] == disc[u] {
41        let mut scc = Vec::new();
42        while let Some(w) = stack.pop() {
43            scc.push(w);
44            if w == u {
45                break;
46            }
47        }
48        sccs.push(scc);
49    }
50}
51
52impl FileDescriptor {
53    fn flatten_messages(&self) -> Vec<MessageIndex> {
54        let mut all_msgs = self
55            .messages
56            .iter()
57            .map(|m| m.index.clone())
58            .collect::<Vec<_>>();
59        let mut vertices = Vec::with_capacity(all_msgs.len());
60        while let Some(m) = all_msgs.pop() {
61            all_msgs.extend(m.get_message(self).messages.iter().map(|m| m.index.clone()));
62            vertices.push(m);
63        }
64        vertices
65    }
66
67    pub fn sccs(&self) -> Vec<Vec<MessageIndex>> {
68        let vertices = self.flatten_messages();
69        let ids = vertices
70            .iter()
71            .enumerate()
72            .map(|(i, m)| (m.get_message(self).index.clone(), i))
73            .collect::<HashMap<_, _>>();
74        let mut low = vec![-1; vertices.len()];
75        let mut disc = vec![-1; vertices.len()];
76        let mut stack: Vec<usize> = Vec::new();
77        let mut count = 0isize;
78        let mut sccs: Vec<Vec<usize>> = Vec::new();
79        for u in 0..vertices.len() {
80            if disc[u] == -1 {
81                scc(
82                    &vertices, self, u, &mut count, &mut low, &mut disc, &mut stack, &mut sccs,
83                    &ids,
84                );
85            }
86        }
87        sccs.into_iter()
88            .map(|scc| scc.into_iter().map(|i| vertices[i].clone()).collect())
89            .collect()
90    }
91}