sparta/
wpo.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8use std::collections::BTreeSet;
9use std::collections::HashMap;
10use std::collections::HashSet;
11use std::fmt::Debug;
12use std::hash::Hash;
13use std::iter::FromIterator;
14
15use petgraph::unionfind::UnionFind;
16
17use crate::graph::SuccessorNodes;
18
/// Index of a node inside the WPO node vector; used to address nodes,
/// successors and predecessors throughout this module.
pub type WpoIdx = u32;
20
21pub struct WpoNodeData<NodeId: Copy + Hash + Ord> {
22    node: NodeId,
23    size: usize,
24    successors: BTreeSet<WpoIdx>,
25    predessors: BTreeSet<WpoIdx>,
26    num_outer_preds: HashMap<WpoIdx, u32>,
27}
28
29impl<NodeId> WpoNodeData<NodeId>
30where
31    NodeId: Copy + Hash + Ord,
32{
33    pub fn new(node: NodeId, size: usize) -> Self {
34        Self {
35            node,
36            size,
37            successors: BTreeSet::new(),
38            predessors: BTreeSet::new(),
39            num_outer_preds: HashMap::new(),
40        }
41    }
42}
43
/// Kind of a node in the weak partial ordering.
///
/// The type is a plain tag without payload, so it is `Copy`: callers can
/// store or pass a node kind by value without giving up the original.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WpoNodeType {
    /// Head of a non-trivial component.
    Head,
    /// Node that is neither a head nor an exit.
    Plain,
    /// Exit of a non-trivial component.
    Exit,
}
50
51pub struct WpoNode<NodeId: Copy + Ord + Hash> {
52    ty: WpoNodeType,
53    data: WpoNodeData<NodeId>,
54}
55
56impl<NodeId> WpoNode<NodeId>
57where
58    NodeId: Copy + Ord + Hash + Debug,
59{
60    pub fn plain(node: NodeId, size: usize) -> Self {
61        Self {
62            ty: WpoNodeType::Plain,
63            data: WpoNodeData::new(node, size),
64        }
65    }
66
67    pub fn head(node: NodeId, size: usize) -> Self {
68        Self {
69            ty: WpoNodeType::Head,
70            data: WpoNodeData::new(node, size),
71        }
72    }
73
74    pub fn exit(node: NodeId, size: usize) -> Self {
75        Self {
76            ty: WpoNodeType::Exit,
77            data: WpoNodeData::new(node, size),
78        }
79    }
80
81    pub fn new(ty: WpoNodeType, node: NodeId, size: usize) -> Self {
82        Self {
83            ty,
84            data: WpoNodeData::new(node, size),
85        }
86    }
87
88    pub fn is_plain(&self) -> bool {
89        self.ty == WpoNodeType::Plain
90    }
91
92    pub fn is_head(&self) -> bool {
93        self.ty == WpoNodeType::Head
94    }
95
96    pub fn is_exit(&self) -> bool {
97        self.ty == WpoNodeType::Exit
98    }
99
100    pub fn get_node(&self) -> NodeId {
101        self.data.node
102    }
103
104    pub fn get_successors(&self) -> &BTreeSet<WpoIdx> {
105        &self.data.successors
106    }
107
108    pub fn get_predecessors(&self) -> &BTreeSet<WpoIdx> {
109        &self.data.predessors
110    }
111
112    pub fn get_num_preds(&self) -> u32 {
113        self.get_predecessors().len() as u32
114    }
115
116    pub fn get_num_outer_preds(&self) -> &HashMap<WpoIdx, u32> {
117        assert_eq!(
118            self.ty,
119            WpoNodeType::Exit,
120            "Node {:#?} is not Exit",
121            self.data.node
122        );
123        &self.data.num_outer_preds
124    }
125
126    pub fn get_size(&self) -> usize {
127        self.data.size
128    }
129
130    fn add_successor(&mut self, idx: WpoIdx) {
131        self.data.successors.insert(idx);
132    }
133
134    fn add_predecessor(&mut self, idx: WpoIdx) {
135        self.data.predessors.insert(idx);
136    }
137
138    fn is_successor(&self, idx: WpoIdx) -> bool {
139        self.get_successors().contains(&idx)
140    }
141
142    pub fn inc_num_outer_preds(&mut self, idx: WpoIdx) {
143        assert_eq!(
144            self.ty,
145            WpoNodeType::Exit,
146            "Node {:#?} is not Exit",
147            self.data.node
148        );
149        *self.data.num_outer_preds.entry(idx).or_default() += 1;
150    }
151}
152
153pub struct WeakPartialOrdering<NodeId: Copy + Hash + Ord> {
154    /// All nodes under WPO.
155    nodes: Vec<WpoNode<NodeId>>,
156    /// All top level nodes.
157    toplevel: Vec<WpoIdx>,
158    /// Post depth first numbering for each node.
159    post_dfn: HashMap<NodeId, u32>,
160}
161
162impl<NodeId> WeakPartialOrdering<NodeId>
163where
164    NodeId: Copy + Hash + Ord + Debug,
165{
166    pub fn new<SN>(root: NodeId, size: usize, successors_nodes: &SN) -> Self
167    where
168        SN: SuccessorNodes<NodeId = NodeId>,
169    {
170        if successors_nodes.get_succ_nodes(root).is_empty() {
171            let mut wpo = Self {
172                nodes: vec![],
173                toplevel: vec![],
174                post_dfn: HashMap::new(),
175            };
176            wpo.nodes.push(WpoNode::plain(root, 1));
177            wpo.toplevel.push(0);
178            wpo.post_dfn.insert(root, 1);
179            wpo
180        } else {
181            WeakPartialOrderingImpl::new().build(size, root, successors_nodes)
182        }
183    }
184
185    pub fn size(&self) -> usize {
186        self.nodes.len()
187    }
188
189    pub fn get_entry(&self) -> WpoIdx {
190        (self.nodes.len() - 1) as WpoIdx
191    }
192
193    pub fn get_successors(&self, idx: WpoIdx) -> &BTreeSet<WpoIdx> {
194        self.nodes[idx as usize].get_successors()
195    }
196
197    pub fn get_predecessors(&self, idx: WpoIdx) -> &BTreeSet<WpoIdx> {
198        self.nodes[idx as usize].get_predecessors()
199    }
200
201    pub fn get_num_preds(&self, idx: WpoIdx) -> u32 {
202        self.nodes[idx as usize].get_num_preds()
203    }
204
205    pub fn get_num_outer_preds(&self, exit: WpoIdx) -> &HashMap<WpoIdx, u32> {
206        self.nodes[exit as usize].get_num_outer_preds()
207    }
208
209    pub fn get_head_of_exit(&self, exit: WpoIdx) -> WpoIdx {
210        exit + 1
211    }
212
213    pub fn get_exit_of_head(&self, head: WpoIdx) -> WpoIdx {
214        head - 1
215    }
216
217    pub fn get_node(&self, idx: WpoIdx) -> NodeId {
218        self.nodes[idx as usize].get_node()
219    }
220
221    pub fn is_plain(&self, idx: WpoIdx) -> bool {
222        self.nodes[idx as usize].is_plain()
223    }
224
225    pub fn is_head(&self, idx: WpoIdx) -> bool {
226        self.nodes[idx as usize].is_head()
227    }
228
229    pub fn is_exit(&self, idx: WpoIdx) -> bool {
230        self.nodes[idx as usize].is_exit()
231    }
232
233    pub fn is_from_outside(&self, head: NodeId, pred: NodeId) -> bool {
234        self.get_post_dfn(head) < self.get_post_dfn(pred)
235    }
236
237    fn get_post_dfn(&self, n: NodeId) -> u32 {
238        // If the key does not exist, meaning that node is not
239        // finished yet, return default value 0.
240        self.post_dfn.get(&n).copied().unwrap_or_default()
241    }
242}
243
244// This private type is only used to build the actual WPO.
245struct WeakPartialOrderingImpl<NodeId: Copy + Hash + Ord> {
246    nodes: Vec<WpoNode<NodeId>>,
247    toplevel: Vec<WpoIdx>,
248    post_dfn: HashMap<NodeId, u32>,
249    // A map from NodeId to post DFN.
250    dfn: HashMap<NodeId, u32>,
251    dfn_to_node: Vec<NodeId>,
252    cross_fwd_edges: HashMap<u32, Vec<(u32, u32)>>,
253    back_preds: HashMap<u32, Vec<u32>>,
254    // Tree edges (map from node to its predecessors).
255    non_back_preds: HashMap<u32, Vec<u32>>,
256    next_dfn: u32,
257    // Map from dfn to WpoIdx
258    dfn_to_wpo_idx: Vec<WpoIdx>,
259    // Next WpoIdx to assign
260    next_idx: WpoIdx,
261}
262
impl<NodeId> WeakPartialOrderingImpl<NodeId>
where
    NodeId: Copy + Hash + Ord + Debug,
{
    /// Creates an empty builder.
    ///
    /// `next_dfn` starts at 1 because DFN 0 is reserved for undiscovered
    /// nodes; every collection starts empty and `next_idx` starts at 0.
    pub fn new() -> Self {
        // I really don't want to add `Default` bound to `NodeId`, so let's
        // have a bit tedious code here to give user side more flexibility.
        Self {
            next_dfn: 1u32,
            nodes: vec![],
            toplevel: vec![],
            post_dfn: HashMap::new(),
            dfn: HashMap::new(),
            dfn_to_node: vec![],
            cross_fwd_edges: HashMap::new(),
            back_preds: HashMap::new(),
            non_back_preds: HashMap::new(),
            dfn_to_wpo_idx: vec![],
            next_idx: 0,
        }
    }

    /// Appends a WPO node of kind `ty` and registers its WPO index under the
    /// DFN slot `dfn_i`.
    ///
    /// `vertex` is the DFN whose graph node the new WPO node refers to, and
    /// `sz` is stored as the node's size.
    ///
    /// Panics if `dfn_i` is outside `dfn_to_wpo_idx` or if `vertex` has no
    /// entry in `dfn_to_node`.
    fn add_node(&mut self, dfn_i: u32, vertex: u32, sz: u32, ty: WpoNodeType) {
        self.dfn_to_wpo_idx[dfn_i as usize] = self.next_idx;
        self.next_idx += 1;
        self.nodes.push(WpoNode::new(
            ty,
            // dfn reserves 0, so should subtract 1 here.
            self.dfn_to_node[vertex as usize - 1],
            sz as usize,
        ));
    }

    /// Mutable access to the WPO node registered under the DFN slot `dfn_i`.
    fn node_of(&mut self, dfn_i: u32) -> &mut WpoNode<NodeId> {
        let idx = self.index_of(dfn_i) as usize;
        &mut self.nodes[idx]
    }

    /// WPO index registered under the DFN slot `dfn_i`.
    fn index_of(&self, dfn_i: u32) -> u32 {
        self.dfn_to_wpo_idx[dfn_i as usize]
    }

    /// Adds the WPO edge `from -> to` (both given as DFN slots) unless the
    /// edge already exists.
    ///
    /// When the edge is new and `outer_pred` is set, the pair
    /// `(index of to, index of exit)` is appended to `for_outer_preds` so the
    /// caller can later update the outer-predecessor counters.
    fn add_successor(
        &mut self,
        from: u32,
        to: u32,
        exit: u32,
        outer_pred: bool,
        for_outer_preds: &mut Vec<(WpoIdx, WpoIdx)>,
    ) {
        let from_idx = self.index_of(from);
        let to_idx = self.index_of(to);
        if !self.nodes[from_idx as usize].is_successor(to_idx) {
            if outer_pred {
                for_outer_preds.push((to_idx, self.index_of(exit)));
            }
            // Keep both directions of the edge in sync.
            self.nodes[from_idx as usize].add_successor(to_idx);
            self.nodes[to_idx as usize].add_predecessor(from_idx);
        }
    }

    /// Runs an iterative depth-first traversal from `root` and fills the
    /// auxiliary tables used by `build`.
    ///
    /// For every discovered node this assigns a DFN (`dfn`, `dfn_to_node`)
    /// and, once the node is finished, a post DFN (`post_dfn`). Edges to
    /// undiscovered nodes are followed; the traversal parent of each node is
    /// recorded in `non_back_preds`. Edges to finished nodes go to
    /// `cross_fwd_edges`, keyed by the ancestor recorded for the target's
    /// disjoint set. Edges to discovered-but-unfinished nodes go to
    /// `back_preds`.
    ///
    /// `size` is used to size the union-find (`size + 1` elements, since 0 is
    /// reserved). The function asserts at the end that no more than `size`
    /// DFNs were assigned.
    fn build_auxilary<SN>(&mut self, size: usize, root: NodeId, successors_nodes: &SN)
    where
        SN: SuccessorNodes<NodeId = NodeId>,
    {
        // Since 0 is reserved for undiscovered nodes, the total number of nodes
        // would be size + 1.
        let mut dft_dsets = UnionFind::<u32>::new(size + 1);
        // Work stack of (node, finished, DFN of the predecessor it was reached from).
        let mut stack = Vec::new();
        let mut next_post_dfn = 1u32;
        // DFNs of nodes whose traversal has finished.
        let mut visited = HashMap::new();
        // Maps the union-find root of a set to the DFN recorded as its ancestor.
        let mut ancestor = HashMap::new();

        let get_dfn = |n: NodeId, dfn: &HashMap<NodeId, u32>| {
            // If the key does not exist, meaning that node is not
            // discovered yet, return default value 0.
            dfn.get(&n).copied().unwrap_or_default()
        };
        let set_dfn = |n: NodeId, num: u32, dfn: &mut HashMap<NodeId, u32>| {
            dfn.insert(n, num);
        };

        // The root has no predecessor, which is encoded as DFN 0.
        stack.push((root, false, 0u32));

        while let Some((node, finished, pred)) = stack.pop() {
            if finished {
                // Second visit: all successors pushed for `node` have been processed.
                self.post_dfn.insert(node, next_post_dfn);
                next_post_dfn += 1;

                let vertex = get_dfn(node, &self.dfn);
                visited.insert(vertex, true);

                // Merge the finished node into its predecessor's set and make
                // the predecessor the ancestor of the merged set.
                dft_dsets.union(vertex, pred);
                ancestor.insert(dft_dsets.find_mut(pred), pred);
            } else {
                if get_dfn(node, &self.dfn) != 0 {
                    // Skip forward edges.
                    continue;
                }

                // First visit: assign the next DFN.
                let vertex = self.next_dfn;
                self.next_dfn += 1;
                self.dfn_to_node.push(node);
                set_dfn(node, vertex, &mut self.dfn);
                ancestor.insert(vertex, vertex);

                // Schedule the "finished" visit below the successors pushed next.
                stack.push((node, true, pred));

                let successors = successors_nodes.get_succ_nodes(node);
                // Pushed in reverse so they are popped in their original order.
                for &succ_node in successors.iter().rev() {
                    let succ = get_dfn(succ_node, &self.dfn);
                    if 0 == succ {
                        // Undiscovered successor: visit it with `vertex` as predecessor.
                        stack.push((succ_node, false, vertex));
                    } else if visited.get(&succ).copied().unwrap_or_default() {
                        // Successor already finished: cross/forward edge, stored
                        // under the ancestor recorded for the successor's set.
                        let lca = ancestor.get(&dft_dsets.find_mut(succ)).copied().unwrap();
                        self.cross_fwd_edges
                            .entry(lca)
                            .or_default()
                            .push((vertex, succ));
                    } else {
                        // Successor discovered but not finished: back edge.
                        self.back_preds.entry(succ).or_default().push(vertex);
                    }
                }

                if pred != 0 {
                    // Tree edge pred -> vertex.
                    self.non_back_preds.entry(vertex).or_default().push(pred);
                }
            }
        }

        // Number of dfn should be equal or smaller (if there is unreachable node)
        // than graph size + 1 (number 0 for undiscovered).
        assert!(self.next_dfn as usize <= size + 1);
    }

    /// Consumes the builder and constructs the WPO of the graph reachable
    /// from `root`.
    ///
    /// `size` is forwarded to `build_auxilary`, and `successors_nodes`
    /// supplies the successors of each node. DFNs are processed in
    /// descending order: a DFN with no back edge becomes a plain node, any
    /// other becomes an exit/head pair with the nested components merged into
    /// it. Scheduling edges are added along the way, and the
    /// outer-predecessor counters of exits are filled at the end.
    fn build<SN>(
        mut self,
        size: usize,
        root: NodeId,
        successors_nodes: &SN,
    ) -> WeakPartialOrdering<NodeId>
    where
        SN: SuccessorNodes<NodeId = NodeId>,
    {
        // Step 1: construct auxiliary data structures, including
        // classifying edges, finding lowest common ancestors
        // of cross/forward edges.
        self.build_auxilary(size, root, successors_nodes);

        // Step 2: start constructing WPO.
        let mut dsets = UnionFind::<u32>::new(self.next_dfn as usize);
        // DFN slot handed to the next exit node; exit slots are numbered
        // after all DFNs assigned by the traversal.
        let mut exit_next_dfn = self.next_dfn;
        // Initialization.
        // Union find does not guarantee that the root of a subset has
        // always the minimum DFN, so `rep` maps the union-find root of a
        // set to the DFN that represents it.
        let mut rep: Vec<u32> = (0..self.next_dfn).collect();
        // Maps a DFN to the DFN slot of its exit; a DFN without an exit maps
        // to itself.
        let mut exit: Vec<u32> = (0..self.next_dfn).collect();
        // Maps a DFN to the original non-back edges `(u, v)` that target it.
        let mut origin: Vec<Vec<(u32, u32)>> = (0..self.next_dfn)
            .map(|v| {
                self.non_back_preds
                    .get(&v)
                    .map_or_else(std::vec::Vec::new, |non_back_preds_v| {
                        non_back_preds_v.iter().map(|&p| (p, v)).collect()
                    })
            })
            .collect();

        // Twice the DFN count: one slot per DFN plus one possible exit slot each.
        self.dfn_to_wpo_idx.resize(2 * self.next_dfn as usize, 0);
        // Pairs (target WPO index, exit WPO index) collected by `add_successor`.
        let mut for_outer_preds = Vec::<(WpoIdx, WpoIdx)>::new();
        // Size of the component represented by each DFN.
        let mut components_sizes = vec![0u32; self.next_dfn as usize];
        // Maps the WPO index of a node to the WPO index of the head it was merged into.
        let mut parent = HashMap::<WpoIdx, WpoIdx>::new();

        // In descending order, excluding 0 which is for undiscovered.
        for h in (1..self.next_dfn).rev() {
            // Restore cross/forward edges
            if let Some(edges) = self.cross_fwd_edges.get(&h) {
                for &(u, v) in edges {
                    let rep_v = rep[dsets.find(v) as usize];
                    self.non_back_preds.entry(rep_v).or_default().push(u);
                    origin[rep_v as usize].push((u, v));
                }
            }

            // Find nested SCCs.
            let mut is_scc = false;
            let mut backpreds_h = HashSet::<u32>::new();
            if let Some(preds) = self.back_preds.get(&h) {
                for &v in preds {
                    if v != h {
                        backpreds_h.insert(rep[dsets.find(v) as usize]);
                    } else {
                        // Self-loop on h.
                        is_scc = true;
                    }
                }
            }

            if !backpreds_h.is_empty() {
                is_scc = true;
            }

            // Walk non-back predecessors backwards from the back-edge sources,
            // collecting every representative other than h itself.
            let mut nested_sccs_h = backpreds_h.clone();
            let mut worklist_h = Vec::from_iter(backpreds_h.iter().copied());
            while let Some(v) = worklist_h.pop() {
                if let Some(preds) = self.non_back_preds.get(&v) {
                    for &p in preds {
                        let rep_p = rep[dsets.find(p) as usize];
                        if !nested_sccs_h.contains(&rep_p) && rep_p != h {
                            worklist_h.push(rep_p);
                            nested_sccs_h.insert(rep_p);
                        }
                    }
                }
            }

            // h represents a trivial SCC.
            if !is_scc {
                components_sizes[h as usize] = 1;
                self.add_node(h, h, 1, WpoNodeType::Plain);
                continue;
            }

            // Initialize size to 2 for head and exit.
            let mut sz_h = 2;
            for &v in nested_sccs_h.iter() {
                sz_h += components_sizes[v as usize];
            }
            components_sizes[h as usize] = sz_h;

            // Add new exit.
            let x_h = exit_next_dfn;
            exit_next_dfn += 1;
            self.add_node(x_h, h, sz_h, WpoNodeType::Exit);
            // Wpo index of head is then exit + 1 for the same component.
            self.add_node(h, h, sz_h, WpoNodeType::Head);

            if backpreds_h.is_empty() {
                // Scheduling constraints from h to x_h.
                self.add_successor(h, x_h, x_h, false, &mut for_outer_preds);
            } else {
                // Scheduling constraints from the exit of each back-edge source to x_h.
                for p in backpreds_h {
                    self.add_successor(exit[p as usize], x_h, x_h, false, &mut for_outer_preds);
                }
            }

            // Scheduling constraints between WPOs for nested SCCs.
            for &v in nested_sccs_h.iter() {
                for &(u, vv) in origin[v as usize].iter() {
                    let x_u = exit[rep[dsets.find(u) as usize] as usize];
                    let x_v = exit[v as usize];
                    // `x_v != v` holds exactly when v has its own exit node.
                    self.add_successor(x_u, vv, x_v, x_v != v, &mut for_outer_preds);
                }
            }

            // Merging all reps in nested SCCs to h
            for &v in nested_sccs_h.iter() {
                dsets.union(v, h);
                rep[dsets.find(v) as usize] = h;
                parent.insert(self.index_of(v), self.index_of(h));
            }

            exit[h as usize] = x_h;
        }

        // Scheduling constraints between WPOs for maximal SCCs.
        self.toplevel.reserve(self.next_dfn as usize);
        for v in 1..self.next_dfn {
            // v is top level when it still represents its own set.
            if rep[dsets.find(v) as usize] == v {
                let v_idx = self.index_of(v);
                self.toplevel.push(v_idx);
                // A top-level node is its own parent.
                parent.insert(v_idx, v_idx);

                for &(u, vv) in origin[v as usize].iter() {
                    let x_u = exit[rep[dsets.find(u) as usize] as usize];
                    let x_v = exit[v as usize];
                    self.add_successor(x_u, vv, x_v, x_v != v, &mut for_outer_preds);
                }
            }
        }

        // Compute num_outer_preds.
        for &(v, x_max) in for_outer_preds.iter() {
            // Start from the innermost head at or above v.
            let mut h = if self.nodes[v as usize].is_head() {
                v
            } else {
                *parent.get(&v).unwrap()
            };
            // The exit of a component is stored right before its head.
            let mut x = h - 1;
            // Count v on every exit from the innermost one up to x_max.
            while x != x_max {
                self.nodes[x as usize].inc_num_outer_preds(v);
                h = *parent.get(&h).unwrap();
                x = h - 1;
            }
            self.nodes[x as usize].inc_num_outer_preds(v);
        }

        WeakPartialOrdering {
            nodes: self.nodes,
            toplevel: self.toplevel,
            post_dfn: self.post_dfn,
        }
    }
}
563        }
564    }
565}