duskphantom_middle/transform/loop_simplify.rs

1// Copyright 2024 Duskphantom Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// SPDX-License-Identifier: Apache-2.0
16
17use std::pin::Pin;
18
19use anyhow::{Ok, Result};
20
21use crate::{
22    analysis::loop_tools::{LoopForest, LoopPtr},
23    ir::{
24        instruction::{downcast_mut, misc_inst::Phi, InstType},
25        Instruction, Operand,
26    },
27    transform::loop_optimization::loop_forest_post_order,
28    IRBuilder,
29};
30
/// Shorthand for the pinned, boxed [`IRBuilder`] that the pass borrows to create new
/// basic blocks and instructions.
type IRBuilderWraper = Pin<Box<IRBuilder>>;

/// Loop-simplification pass: gives every loop a preheader and funnels all of its
/// backedges through a single block.
pub struct LoopSimplifier<'a> {
    /// Builder used to allocate the new preheader / backedge blocks, branches and phis.
    ir_builder: &'a mut IRBuilderWraper,
}
36
37impl<'a> LoopSimplifier<'a> {
38    pub fn new(ir_builder: &'a mut IRBuilderWraper) -> LoopSimplifier {
39        Self { ir_builder }
40    }
41
42    pub fn run(&mut self, loop_forest: &mut LoopForest) -> Result<()> {
43        loop_forest_post_order(loop_forest, |x| self.simplify_one_loop(x))
44    }
45
46    fn simplify_one_loop(&mut self, lo: LoopPtr) -> Result<()> {
47        if lo.pre_header.is_none() {
48            self.insert_preheader(lo)?;
49        }
50
51        self.insert_unique_backedge_block(lo)?;
52        Ok(())
53    }
54
55    fn insert_unique_backedge_block(&mut self, mut lo: LoopPtr) -> Result<()> {
56        let head = lo.head;
57        let backedge_blocks_index = head
58            .get_pred_bb()
59            .iter()
60            .enumerate()
61            .filter_map(|(index, &bb)| {
62                if bb != lo.pre_header.unwrap() {
63                    Some(index)
64                } else {
65                    None
66                }
67            })
68            .collect::<Vec<_>>();
69
70        if backedge_blocks_index.len() == 1 {
71            return Ok(());
72        }
73
74        let mut unique_backedge_block = self
75            .ir_builder
76            .new_basicblock("uni_backedge_".to_owned() + &lo.head.name);
77        let mut tail = self.ir_builder.get_br(None);
78        unique_backedge_block.push_back(tail);
79
80        let mut inst = head.get_first_inst();
81        while let InstType::Phi = inst.get_type() {
82            let phi = downcast_mut::<Phi>(inst.as_mut());
83
84            let incoming_values = backedge_blocks_index
85                .iter()
86                .map(|index| phi.get_incoming_values()[*index].clone())
87                .collect::<Vec<_>>();
88
89            let new_phi = self
90                .ir_builder
91                .get_phi(phi.get_value_type(), incoming_values);
92
93            tail.insert_before(new_phi);
94
95            for (i, index) in backedge_blocks_index.iter().enumerate() {
96                phi.get_incoming_values_mut().remove(index - i);
97                unsafe { phi.get_manager_mut().remove_operand(index - i) };
98            }
99            phi.add_incoming_value(Operand::Instruction(new_phi), unique_backedge_block);
100
101            if let Some(next) = inst.get_next() {
102                inst = next;
103            }
104        }
105
106        backedge_blocks_index
107            .into_iter()
108            .map(|index| head.get_pred_bb()[index])
109            .collect::<Vec<_>>()
110            .into_iter()
111            .for_each(|mut bb| {
112                bb.replace_succ_bb_only(head, unique_backedge_block);
113            }); //
114
115        unique_backedge_block.set_true_bb(head);
116        lo.blocks.insert(unique_backedge_block);
117
118        Ok(())
119    }
120
121    fn insert_preheader(&mut self, mut lo: LoopPtr) -> Result<()> {
122        let header = lo.head;
123
124        // 获得不在循环中的bb和对应的index
125        let out_bb = header
126            .get_pred_bb()
127            .iter()
128            .enumerate()
129            .filter_map(|(index, bb)| {
130                if !lo.is_in_loop(bb) {
131                    Some((index, *bb))
132                } else {
133                    None
134                }
135            })
136            .collect::<Vec<_>>();
137
138        if out_bb.len() == 1 && out_bb[0].1.get_succ_bb().len() == 1 {
139            lo.pre_header = Some(out_bb[0].1);
140            return Ok(());
141        }
142
143        let mut preheader = self
144            .ir_builder
145            .new_basicblock("preheader".to_string() + &header.name);
146        let out_bb_index = out_bb
147            .into_iter()
148            .map(|(index, mut out_bb)| {
149                out_bb.replace_succ_bb_only(header, preheader);
150                index
151            })
152            .collect::<Vec<_>>();
153
154        preheader.set_true_bb(header);
155
156        let mut pre_header_jump = self.ir_builder.get_br(None);
157        preheader.push_back(pre_header_jump);
158        // 构建对应的phi结点
159        for mut phi in header.iter() {
160            if InstType::Phi != phi.get_type() {
161                break;
162            }
163
164            let phi = downcast_mut::<Phi>(phi.as_mut());
165            let incoming_values = out_bb_index
166                .iter()
167                .map(|&index| phi.get_incoming_values()[index].clone())
168                .collect::<Vec<_>>();
169
170            out_bb_index.iter().enumerate().for_each(|(i, index)| {
171                phi.get_incoming_values_mut().remove(index - i);
172                unsafe { phi.get_manager_mut().remove_operand(index - i) };
173            });
174
175            let new_phi = self
176                .ir_builder
177                .get_phi(phi.get_value_type(), incoming_values);
178            pre_header_jump.insert_before(new_phi);
179
180            phi.add_incoming_value(Operand::Instruction(new_phi), preheader);
181        }
182
183        // 如果是子循环,则preheader会存在上层循环中
184        if let Some(mut plo) = lo.parent_loop {
185            plo.blocks.insert(preheader);
186        }
187
188        lo.pre_header = Some(preheader);
189
190        Ok(())
191    }
192}