duskphantom_middle/transform/
sink_code.rs

1// Copyright 2024 Duskphantom Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// SPDX-License-Identifier: Apache-2.0
16
17use std::collections::{HashMap, HashSet};
18
19use anyhow::{anyhow, Context, Result};
20
21use crate::context;
22use crate::analysis::dominator_tree::DominatorTree;
23use crate::analysis::effect_analysis::EffectAnalysis;
24use crate::ir::instruction::downcast_ref;
25use crate::ir::instruction::misc_inst::Phi;
26use crate::ir::instruction::{downcast_mut, InstType};
27use crate::ir::{BBPtr, InstPtr, Operand};
28use crate::Program;
29
30use super::Transform;
31
32#[allow(unused)]
33pub fn optimize_program(program: &mut Program) -> Result<bool> {
34    let effect_analysis = EffectAnalysis::new(program);
35    SinkCode::new(program, &effect_analysis).run_and_log()
36}
37
38#[allow(unused)]
39pub struct SinkCode<'a> {
40    program: &'a mut Program,
41    effect_analysis: &'a EffectAnalysis,
42}
43
44#[allow(unused)]
45impl<'a> Transform for SinkCode<'a> {
46    fn get_program_mut(&mut self) -> &mut Program {
47        self.program
48    }
49
50    fn name() -> String {
51        "sink_code".to_string()
52    }
53
54    fn run(&mut self) -> Result<bool> {
55        let mut changed = false;
56        for func in self.program.module.functions.clone() {
57            if func.is_lib() {
58                continue;
59            }
60            let mut dom_tree = DominatorTree::new(func);
61            for bb in func.po_iter() {
62                for inst in bb.iter_rev() {
63                    changed |= self.sink_inst(inst, &mut dom_tree)?;
64                }
65            }
66        }
67        Ok(true)
68    }
69}
70
71#[allow(unused)]
72impl<'a> SinkCode<'a> {
73    pub fn new(program: &'a mut Program, effect_analysis: &'a EffectAnalysis) -> Self {
74        Self {
75            program,
76            effect_analysis,
77        }
78    }
79
80    fn sink_inst(&mut self, mut inst: InstPtr, dom_tree: &mut DominatorTree) -> Result<bool> {
81        let mut changed = false;
82
83        // Refuse to sink instruction with side effect
84        if self.is_fixed(inst) {
85            return Ok(changed);
86        }
87
88        // If any user is in the same block, do not sink
89        // TODO even in same BB it can sink as low as possible
90        let root = inst
91            .get_parent_bb()
92            .ok_or_else(|| anyhow!("Instruction {} has no parent BB", inst))
93            .with_context(|| context!())?;
94        for user in FakeInst::from_inst_users(inst)? {
95            let parent_bb = user.get_parent_bb()?;
96            if root == parent_bb {
97                return Ok(changed);
98            }
99        }
100
101        // If there are two successors, sink into both and create necessary phi.
102        //
103        // Suppose bb dominates the two successors (A, B), and other blocks (C, D),
104        // if there are users in (A OR B) branch, we can only sink it into (A AND B).
105        // To remove partial redundancy we insert phi for (C, D), and phi can't be sunk.
106        //
107        // Otherwise we can sink it into (C, D), and sink them recursively.
108        // Time complexity is O(n * log(n)) because each time users are partitioned.
109        //
110        // TODO below is a temporary implementation, it refuses to sink if there are users in (C OR D).
111        if root.get_succ_bb().len() == 2 {
112            let mut block_to_user: HashMap<BBPtr, HashSet<FakeInst>> = HashMap::new();
113            for user in FakeInst::from_inst_users(inst)? {
114                let user_bb = user.get_parent_bb()?;
115                block_to_user.entry(user_bb).or_default().insert(user);
116            }
117
118            // Get mapping from dominatee to users
119            let mut dominatee_to_user: HashMap<BBPtr, HashSet<FakeInst>> = HashMap::new();
120            for (bb, _) in block_to_user.iter() {
121                let mut cursor = *bb;
122                let dominatee = loop {
123                    let idom = dom_tree
124                        .get_idom(cursor)
125                        .ok_or_else(|| {
126                            anyhow!("{} has no immediate dominator ({})", cursor.name, bb.name)
127                        })
128                        .with_context(|| context!())?;
129                    if idom == root {
130                        break cursor;
131                    }
132                    cursor = idom;
133                };
134                dominatee_to_user
135                    .entry(dominatee)
136                    .or_default()
137                    .extend(block_to_user[bb].clone());
138            }
139
140            // Check if there are users in (C OR D) branch
141            for (k, v) in dominatee_to_user.iter() {
142                if !root.get_succ_bb().contains(k) && !v.is_empty() {
143                    return Ok(changed);
144                }
145            }
146
147            // Check if each dominatee is not loop header
148            for dominatee in dominatee_to_user.keys() {
149                for pred in dominatee.get_pred_bb() {
150                    if dom_tree.is_dominate(*dominatee, *pred) {
151                        return Ok(changed);
152                    }
153                }
154            }
155
156            // Sink into each successor
157            changed = true;
158            for succ in root.get_succ_bb() {
159                let user = dominatee_to_user
160                    .get(succ)
161                    .cloned()
162                    .unwrap_or(HashSet::new());
163                if !user.is_empty() {
164                    let mut new_inst = self
165                        .program
166                        .mem_pool
167                        .copy_instruction(inst.as_ref().as_ref());
168                    for op in inst.get_operand() {
169                        new_inst.add_operand(op.clone());
170                    }
171
172                    // Get insert position (as low as possible)
173                    let user_in_succ = block_to_user.get(succ).cloned().unwrap_or(HashSet::new());
174                    let mut frontier = None;
175                    for inst in succ.iter() {
176                        if user_in_succ.contains(&FakeInst::Normal(inst)) {
177                            frontier = Some(inst);
178                            break;
179                        }
180                    }
181
182                    // Insert new instruction and maintain use-def chain
183                    if let Some(mut frontier) = frontier {
184                        frontier.insert_before(new_inst);
185                    } else {
186                        succ.get_last_inst().insert_before(new_inst);
187                    }
188                    for mut user in user {
189                        user.replace_operand(&inst.into(), &new_inst.into());
190                    }
191
192                    // Sink recursively
193                    self.sink_inst(new_inst, dom_tree)?;
194                }
195            }
196
197            // Remove the original instruction
198            inst.remove_self();
199        }
200
201        // If there is only one successor, and it's dominated
202        // (the successor has only one predecessor), sink into it
203        if root.get_succ_bb().len() == 1 {
204            let succ = root.get_succ_bb().first().unwrap();
205            if succ.get_pred_bb().len() == 1 {
206                changed = true;
207                succ.clone().push_front(inst);
208                self.sink_inst(inst, dom_tree)?;
209            }
210        }
211
212        Ok(changed)
213    }
214
215    fn is_fixed(&mut self, inst: InstPtr) -> bool {
216        matches!(
217            inst.get_type(),
218            InstType::Load | InstType::Store | InstType::Ret | InstType::Br | InstType::Phi
219        ) || self.effect_analysis.has_effect(inst)
220    }
221}
222
223/// We treat `phi` as multiple fake instructions, each corresponds to one incoming value.
224/// This way it's unified with normal instruction.
225#[derive(PartialEq, Eq, Hash, Clone)]
226enum FakeInst {
227    Normal(InstPtr),
228    Phi(InstPtr, BBPtr),
229}
230
231impl FakeInst {
232    fn from_inst_users(inst: InstPtr) -> Result<Vec<FakeInst>> {
233        inst.get_user()
234            .iter()
235            .map(|user| FakeInst::from_inst_user(inst, *user))
236            .collect::<Result<Vec<Vec<_>>>>()
237            .map(|v| v.into_iter().flatten().collect())
238    }
239
240    fn from_inst_user(inst: InstPtr, user: InstPtr) -> Result<Vec<FakeInst>> {
241        if user.get_type() == InstType::Phi {
242            let phi = downcast_ref::<Phi>(user.as_ref().as_ref());
243            let mut result = Vec::new();
244            for (op, bb) in phi.get_incoming_values() {
245                if op == &Operand::Instruction(inst) {
246                    result.push(FakeInst::Phi(user, *bb));
247                }
248            }
249            Ok(result)
250        } else {
251            Ok(vec![FakeInst::Normal(user)])
252        }
253    }
254
255    fn get_parent_bb(&self) -> Result<BBPtr> {
256        match self {
257            FakeInst::Normal(inst) => inst
258                .get_parent_bb()
259                .ok_or_else(|| anyhow!("Instruction {} has no parent BB", inst))
260                .with_context(|| context!()),
261            // Phi FakeInst locates in one of data source block
262            FakeInst::Phi(_, bb) => Ok(*bb),
263        }
264    }
265
266    fn replace_operand(&mut self, old: &Operand, new: &Operand) {
267        match self {
268            FakeInst::Normal(inst) => {
269                inst.replace_operand(old, new);
270            }
271            FakeInst::Phi(inst, bb) => {
272                let phi = downcast_mut::<Phi>(inst.as_mut());
273                phi.replace_incoming_value_at(*bb, new.clone());
274            }
275        }
276    }
277}