// duskphantom_middle/analysis/loop_tools.rs

// Copyright 2024 Duskphantom Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
16
17use dominator_tree::DominatorTree;
18
19use duskphantom_utils::mem::{ObjPool, ObjPtr};
20
21use super::*;
22
23pub type LoopPtr = ObjPtr<LoopTree>;
24
25#[allow(dead_code)]
26pub struct LoopTree {
27    pub pre_header: Option<BBPtr>,
28    pub head: BBPtr,
29    pub blocks: HashSet<BBPtr>,
30    pub parent_loop: Option<LoopPtr>,
31    pub sub_loops: Vec<LoopPtr>,
32}
33
34#[allow(dead_code)]
35impl LoopTree {
36    pub fn is_in_cur_loop(&self, bb: &BBPtr) -> bool {
37        self.blocks.contains(bb)
38    }
39
40    pub fn is_in_loop(&self, bb: &BBPtr) -> bool {
41        self.blocks.contains(bb) || self.sub_loops.iter().any(|lo| lo.is_in_loop(bb))
42    }
43}
44
45#[allow(dead_code)]
46pub struct LoopForest {
47    pool: ObjPool<LoopTree>,
48    // 只包含最外层循环,内部循环通过LoopPtr的sub_loops域访问
49    pub forest: Vec<LoopPtr>,
50}
51
52#[allow(dead_code)]
53impl LoopForest {
54    // tarjan算法变体
55    pub fn make_forest(func: FunPtr) -> Option<LoopForest> {
56        let mut domin = DominatorTree::new(func);
57        let mut stack;
58        if let Some(x) = func.entry {
59            stack = vec![x];
60        } else {
61            return None;
62        }
63
64        let mut loop_tree_meta = Vec::new();
65        let mut id_map = HashMap::new();
66
67        while let Some(bb) = stack.pop() {
68            if let std::collections::hash_map::Entry::Vacant(e) = id_map.entry(bb) {
69                // 第一次遍历该结点
70
71                // 初始化
72                e.insert(HashSet::from([loop_tree_meta.len()]));
73                loop_tree_meta.push((bb, HashSet::from([bb]), None::<BBPtr>));
74                stack.push(bb);
75
76                // dfs
77                if let Some(next) = bb.get_succ_bb().first() {
78                    if !id_map.contains_key(next) {
79                        stack.push(*next);
80                    }
81                }
82            } else if bb
83                .get_succ_bb()
84                .iter()
85                .all(|next| id_map.contains_key(next))
86            {
87                // bb的所有分支均已访问
88
89                // 合并下游bb序号集中小于当前bb的序号
90                let cur_id = *id_map.get(&bb).unwrap().iter().next().unwrap();
91                let cur_map: HashSet<usize> = bb
92                    .get_succ_bb()
93                    .iter()
94                    .map(|x| id_map.get(x).unwrap().clone())
95                    .reduce(|acc, e| &acc | &e)
96                    .unwrap_or(HashSet::new())
97                    .into_iter()
98                    .filter(|x| *x <= cur_id && domin.is_dominate(loop_tree_meta[*x].0, bb))
99                    .collect();
100
101                // 获取当前序号集中最大两个数(如有)
102                let mut max_two = [-1, -1];
103                cur_map.iter().for_each(|&x| {
104                    let x = x as i32;
105                    if x > max_two[0] {
106                        max_two[1] = max_two[0];
107                        max_two[0] = x;
108                    } else if x > max_two[1] {
109                        max_two[1] = x;
110                    }
111                });
112                id_map.insert(bb, cur_map).unwrap();
113                id_map.get_mut(&bb).unwrap().insert(cur_id);
114
115                // 最大的数为当前循环的head bb的id
116                // 次之为父循环的head bb的id
117                // 依次类推
118                match max_two {
119                    [-1, -1] => {}
120                    [x, -1] => {
121                        loop_tree_meta[x as usize].1.insert(bb);
122                    }
123                    [x, y] => {
124                        loop_tree_meta[x as usize].1.insert(bb);
125                        if x as usize == cur_id {
126                            loop_tree_meta[x as usize].2 = Some(loop_tree_meta[y as usize].0);
127                        }
128                    }
129                }
130            } else {
131                // 当前bb有双分支,且第二个分支还未访问
132                stack.push(bb);
133                stack.push(bb.get_succ_bb()[1]);
134            }
135        }
136
137        let mut forest = LoopForest {
138            pool: ObjPool::new(),
139            forest: Vec::new(),
140        };
141        let mut forest_map = HashMap::new();
142
143        for (head, blocks, parent_loop) in loop_tree_meta.into_iter() {
144            let loop_ptr = forest.pool.alloc(LoopTree {
145                pre_header: None,
146                head,
147                blocks,
148                parent_loop: parent_loop.and_then(|x| forest_map.get(&x).cloned()),
149                sub_loops: Vec::new(),
150            });
151            if let Some(par) = loop_ptr.parent_loop {
152                forest_map
153                    .get_mut(&par.head)
154                    .iter_mut()
155                    .for_each(|x| x.sub_loops.push(loop_ptr));
156            } else {
157                forest.forest.push(loop_ptr);
158            };
159            forest_map.insert(head, loop_ptr);
160        }
161
162        forest.forest.retain(|x| {
163            x.blocks.len() > 1
164                || !x.sub_loops.is_empty()
165                || x.head.get_succ_bb().iter().any(|succ| *succ == x.head)
166        });
167
168        Some(forest)
169    }
170}
171
#[cfg(test)]
mod test_loop {
    use crate::ir::IRBuilder;
    use std::iter::zip;

    use super::*;

    /// Builds a function whose control-flow graph is described by `down_stream`.
    ///
    /// Entry `i` gives the successors of block `i` as `[true_bb, false_bb]`,
    /// with `-1` marking an absent successor (`[-1, f]` with `f != -1` is not
    /// supported). Block 0 becomes the entry and the last block the exit.
    /// Returns the function and its blocks in index order.
    fn gen_graph(pool: &mut IRBuilder, down_stream: Vec<[i32; 2]>) -> (FunPtr, Vec<BBPtr>) {
        let bb_vec: Vec<BBPtr> = (0..down_stream.len())
            .map(|_| pool.new_basicblock("no_name".to_string()))
            .collect();
        for (mut bb, down) in zip(bb_vec.clone(), down_stream) {
            match down {
                [-1, -1] => {}
                [t, -1] => bb.set_true_bb(bb_vec[t as usize]),
                [t, f] => {
                    bb.set_true_bb(bb_vec[t as usize]);
                    bb.set_false_bb(bb_vec[f as usize]);
                }
            }
        }
        let mut func = pool.new_function("no_name".to_string(), crate::ir::ValueType::Void);
        func.entry = Some(bb_vec[0]);
        func.exit = Some(bb_vec[bb_vec.len() - 1]);
        (func, bb_vec)
    }

    /// Builds the graph described by `bb_vec` (see `gen_graph`) and runs
    /// `LoopForest::make_forest` on it. The builder is returned so callers can
    /// keep it bound for the whole test (presumably it owns the IR objects).
    fn gen_forest(bb_vec: Vec<[i32; 2]>) -> (IRBuilder, Option<LoopForest>, Vec<BBPtr>) {
        let mut pool = IRBuilder::new();
        let (func, bb_vec) = gen_graph(&mut pool, bb_vec);
        let forest = LoopForest::make_forest(func);
        (pool, forest, bb_vec)
    }

    #[test]
    fn one_loop() {
        let (_pool, forest, bb_vec) = gen_forest(vec![[1, -1], [2, 3], [1, -1], [-1, -1]]);
        assert!(forest.is_some());

        let forest = forest.unwrap();
        assert_eq!(forest.forest.len(), 1);

        let lo = forest.forest.first().unwrap();
        assert!(!lo.is_in_loop(&bb_vec[0]));
        assert!(lo.is_in_loop(&bb_vec[1]));
        assert!(lo.is_in_loop(&bb_vec[2]));
        assert!(!lo.is_in_loop(&bb_vec[3]));
    }

    #[test]
    fn two_loop() {
        let (_pool, forest, bb_vec) = gen_forest(vec![[1, -1], [0, 2], [3, -1], [2, 4], [-1, -1]]);
        assert!(forest.is_some());

        let forest = forest.unwrap();
        assert_eq!(forest.forest.len(), 2);

        let first = forest.forest[0];
        assert!(first.is_in_loop(&bb_vec[0]));
        assert!(first.is_in_loop(&bb_vec[1]));
        assert!(!first.is_in_loop(&bb_vec[2]));
        assert!(!first.is_in_loop(&bb_vec[4]));

        let second = forest.forest[1];
        assert!(second.is_in_loop(&bb_vec[2]));
        assert!(second.is_in_loop(&bb_vec[3]));
        assert!(!second.is_in_loop(&bb_vec[4]));
    }

    #[test]
    fn component_loop() {
        let (_pool, forest, bb_vec) = gen_forest(vec![
            [1, -1],
            [0, 2],
            [0, 3],
            [4, -1],
            [5, -1],
            [4, 6],
            [4, -1],
        ]);

        assert!(forest.is_some());
        let forest = forest.unwrap();
        assert_eq!(forest.forest.len(), 2);

        let first = forest.forest[0];
        assert_eq!(first.blocks.len(), 3);
        assert!(first.is_in_loop(&bb_vec[0]));
        assert!(first.is_in_loop(&bb_vec[1]));
        assert!(first.is_in_loop(&bb_vec[2]));

        let second = forest.forest[1];
        assert_eq!(second.blocks.len(), 3);
        assert!(second.is_in_loop(&bb_vec[4]));
        assert!(second.is_in_loop(&bb_vec[5]));
        assert!(second.is_in_loop(&bb_vec[6]));
    }

    #[test]
    fn branch_loop() {
        let (_pool, forest, bb_vec) = gen_forest(vec![[1, 2], [3, -1], [3, -1], [0, -1]]);

        assert!(forest.is_some());
        let forest = forest.unwrap();
        assert_eq!(forest.forest.len(), 1);

        let lo = forest.forest[0];
        assert_eq!(lo.blocks.len(), 4);
        assert_eq!(lo.sub_loops.len(), 0);
        assert!(lo.parent_loop.is_none());
        assert!(lo.is_in_cur_loop(&bb_vec[0]));
        assert!(lo.is_in_cur_loop(&bb_vec[1]));
        assert!(lo.is_in_cur_loop(&bb_vec[2]));
        assert!(lo.is_in_cur_loop(&bb_vec[3]));
    }

    #[test]
    fn nested_loop() {
        let (_pool, forest, bb_vec) = gen_forest(vec![[1, -1], [2, -1], [0, 1]]);
        assert!(forest.is_some());
        let forest = forest.unwrap();
        assert_eq!(forest.forest.len(), 1);
        let lo = forest.forest[0];
        assert_eq!(lo.blocks.len(), 1);
        assert_eq!(lo.head, bb_vec[0]);
        assert!(lo.is_in_cur_loop(&bb_vec[0]));

        assert_eq!(lo.sub_loops.len(), 1);
        let sub = lo.sub_loops[0];
        assert_eq!(sub.blocks.len(), 2);
        assert_eq!(sub.head, bb_vec[1]);
        assert!(sub.is_in_cur_loop(&bb_vec[1]));
        assert!(sub.is_in_cur_loop(&bb_vec[2]));
    }

    #[test]
    fn nested_branch_loop() {
        let (_pool, forest, bb_vec) =
            gen_forest(vec![[1, 5], [2, 3], [4, -1], [4, -1], [1, 0], [-1, -1]]);
        assert!(forest.is_some());
        let forest = forest.unwrap();
        assert_eq!(forest.forest.len(), 1);
        let lo = forest.forest[0];
        assert_eq!(lo.head, bb_vec[0]);
        assert!(lo.is_in_cur_loop(&bb_vec[0]));
        assert!(lo.is_in_loop(&bb_vec[1]));
        assert!(lo.is_in_loop(&bb_vec[2]));
        assert!(lo.is_in_loop(&bb_vec[3]));
        assert!(lo.is_in_loop(&bb_vec[4]));
        assert_eq!(lo.blocks.len(), 1);
        assert_eq!(lo.sub_loops.len(), 1);

        let sub = lo.sub_loops[0];
        assert_eq!(sub.blocks.len(), 4);
        assert!(sub.is_in_loop(&bb_vec[1]));
        assert!(sub.is_in_loop(&bb_vec[2]));
        assert!(sub.is_in_loop(&bb_vec[3]));
        assert!(sub.is_in_loop(&bb_vec[4]));
    }

    /// A single block branching to itself is a loop even though it has only one
    /// block and no nested loops; the trivial entry for the exit block is dropped.
    #[test]
    fn self_loop() {
        let (_pool, forest, bb_vec) = gen_forest(vec![[0, 1], [-1, -1]]);
        assert!(forest.is_some());
        let forest = forest.unwrap();
        assert_eq!(forest.forest.len(), 1);

        let lo = forest.forest[0];
        assert_eq!(lo.head, bb_vec[0]);
        assert_eq!(lo.blocks.len(), 1);
        assert!(lo.sub_loops.is_empty());
        assert!(lo.parent_loop.is_none());
        assert!(lo.is_in_cur_loop(&bb_vec[0]));
        assert!(!lo.is_in_loop(&bb_vec[1]));
    }
}