duskphantom_middle/transform/
mem2reg.rs

1// Copyright 2024 Duskphantom Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// SPDX-License-Identifier: Apache-2.0
16
17use std::{
18    collections::{BTreeMap, BTreeSet, HashSet},
19    pin::Pin,
20};
21
22use anyhow::{Context, Result};
23
24use crate::{
25    analysis::dominator_tree::DominatorTree,
26    context,
27    ir::{
28        instruction::{downcast_mut, misc_inst::Phi, InstType},
29        BBPtr, FunPtr, IRBuilder, InstPtr, Operand, ValueType,
30    },
31    Program,
32};
33use duskphantom_utils::frame_map::FrameMap;
34
35use super::Transform;
36
37pub fn optimize_program(program: &mut Program) -> Result<bool> {
38    Mem2Reg::new(program).run_and_log()
39}
40
/// Transform pass that applies [`mem2reg`] to every non-library function of a program.
pub struct Mem2Reg<'a> {
    /// Program being transformed, borrowed mutably for as long as the pass exists
    program: &'a mut Program,
}
44
45impl<'a> Transform for Mem2Reg<'a> {
46    fn get_program_mut(&mut self) -> &mut Program {
47        self.program
48    }
49
50    fn name() -> String {
51        "mem2reg".to_string()
52    }
53
54    fn run(&mut self) -> Result<bool> {
55        let mut changed = false;
56        for func in &self.program.module.functions {
57            if !func.is_lib() {
58                changed |= mem2reg(*func, &mut self.program.mem_pool)?;
59            }
60        }
61        Ok(changed)
62    }
63}
64
65impl<'a> Mem2Reg<'a> {
66    pub fn new(program: &'a mut Program) -> Self {
67        Self { program }
68    }
69}
70
/// A single incoming argument of a "phi" instruction:
/// the basic block the value flows in from, paired with that value
type PhiArg = (BBPtr, Operand);
73
/// Pack of a "phi" instruction with corresponding variable
/// The variable is an "alloca" instruction whose value the "phi" carries
struct PhiPack {
    /// The "phi" instruction placed at the front of a basic block
    inst: InstPtr,
    /// The "alloca" instruction this "phi" stands for
    variable: InstPtr,
}
80
81impl PhiPack {
82    /// Create a PhiPack from a variable
83    /// The variable is the "alloca" instruction to be eliminated
84    /// Errors when variable is not of pointer type
85    pub fn new_from_variable(
86        variable: InstPtr,
87        mem_pool: &mut Pin<Box<IRBuilder>>,
88        bb: &mut BBPtr,
89    ) -> Result<Self> {
90        // Get type of phi variable
91        let ValueType::Pointer(ty) = variable.get_value_type() else {
92            return Err(anyhow::anyhow!("variable type is not pointer"))
93                .with_context(|| context!());
94        };
95
96        // Get and insert empty "phi" instruction
97        let phi = mem_pool.get_phi(*ty, vec![]);
98        bb.push_front(phi);
99
100        // Return phi pack
101        Ok(Self {
102            inst: phi,
103            variable,
104        })
105    }
106
107    /// Add an argument to the "phi" instruction
108    pub fn add_argument(&mut self, phi_arg: PhiArg) {
109        // Get mutable reference of phi
110        let phi = downcast_mut::<Phi>(self.inst.as_mut());
111
112        // Add argument to phi
113        phi.add_incoming_value(phi_arg.1, phi_arg.0);
114    }
115}
116
117/// The mem2reg pass
118#[allow(unused)]
119pub fn mem2reg(func: FunPtr, mem_pool: &mut Pin<Box<IRBuilder>>) -> Result<bool> {
120    let entry = func.entry.unwrap();
121    let mut variable_to_phi_insertion: BTreeMap<InstPtr, BTreeSet<BBPtr>> =
122        get_variable_to_phi_insertion(func);
123    let mut block_to_phi_insertion: BTreeMap<BBPtr, Vec<PhiPack>> =
124        insert_empty_phi(entry, mem_pool, variable_to_phi_insertion)?;
125
126    /// For each "phi" insert position, decide the value for each argument
127    /// Errors when variable is not found in current_variable_value
128    fn decide_variable_value(
129        variable: InstPtr,
130        current_variable_value: &FrameMap<InstPtr, Operand>,
131    ) -> Result<Operand> {
132        if let Some(value) = current_variable_value.get(&variable) {
133            return Ok(value.clone());
134        }
135        let ValueType::Pointer(ptr) = variable.get_value_type() else {
136            return Err(anyhow::anyhow!("variable type is not pointer"))
137                .with_context(|| context!());
138        };
139
140        // Value not found can happen when out of scope of a variable, or not defined
141        // To keep consistent with LLVM, return default initializer
142        Ok(Operand::Constant(ptr.default_initializer()?))
143    }
144
145    /// Start from entry node, decide the value for each "phi" instruction
146    /// This will also remove "load" and "store" instructions when possible
147    fn decide_values_start_from(
148        parent_bb: Option<BBPtr>,
149        current_bb: BBPtr,
150        visited: &mut BTreeSet<BBPtr>,
151        current_variable_value: &mut FrameMap<InstPtr, Operand>,
152        block_to_phi_insertion: &mut BTreeMap<BBPtr, Vec<PhiPack>>,
153    ) -> Result<bool> {
154        let mut changed = false;
155
156        // Decide value for each "phi" instruction to add
157        for mut phi in block_to_phi_insertion
158            .get_mut(&current_bb)
159            .unwrap_or(&mut vec![])
160            .iter_mut()
161        {
162            let value = decide_variable_value(phi.variable, current_variable_value)?;
163            phi.add_argument((parent_bb.unwrap(), value));
164            current_variable_value.insert(phi.variable, Operand::Instruction(phi.inst));
165            changed = true;
166        }
167
168        // Do not continue if visited
169        // "phi" instruction can be added multiple times for each basic block,
170        // so that part is before this check
171        if visited.contains(&current_bb) {
172            return Ok(changed);
173        }
174        visited.insert(current_bb);
175
176        // Iterate all instructions and:
177        //
178        // 1. for each "store", update current variable value and remove the "store"
179        // 2. for each "load", replace with the current variable value
180        //
181        // Bypass if featured variable is not a constant pointer,
182        // for example if it's calculated from "getelementptr"
183        for mut inst in current_bb.iter() {
184            match inst.get_type() {
185                InstType::Store => {
186                    let store_operands = inst.get_operand();
187                    let store_ptr = &store_operands[1];
188                    let store_value = &store_operands[0];
189
190                    // Update only when store destination is a constant pointer
191                    if let Operand::Instruction(variable) = store_ptr {
192                        if variable.get_type() == InstType::Alloca {
193                            current_variable_value.insert(*variable, store_value.clone());
194                            inst.remove_self();
195                            changed = true;
196                        }
197                    }
198                }
199                InstType::Load => {
200                    let load_operands = inst.get_operand();
201                    let load_ptr = &load_operands[0];
202
203                    // Replace only when load source is a constant pointer
204                    if let Operand::Instruction(variable) = load_ptr {
205                        if variable.get_type() == InstType::Alloca {
206                            let current_value =
207                                decide_variable_value(*variable, current_variable_value)?;
208                            inst.replace_self(&current_value);
209                            changed = true;
210                        }
211                    }
212                }
213                _ => (),
214            }
215        }
216
217        // Visit all successors
218        let successors = current_bb.get_succ_bb();
219        for succ in successors {
220            changed |= decide_values_start_from(
221                Some(current_bb),
222                *succ,
223                visited,
224                &mut current_variable_value.branch(),
225                block_to_phi_insertion,
226            )?;
227        }
228        Ok(changed)
229    }
230
231    // Start mem2reg pass from the entry block
232    decide_values_start_from(
233        None,
234        entry,
235        &mut BTreeSet::new(),
236        &mut FrameMap::new(),
237        &mut block_to_phi_insertion,
238    )
239}
240
241/// Insert empty "phi" for basic blocks starting from `entry`
242/// Returns a mapping from basic block to inserted "phi" instructions
243#[allow(unused)]
244fn insert_empty_phi(
245    entry: BBPtr,
246    mem_pool: &mut Pin<Box<IRBuilder>>,
247    phi_insert_positions: BTreeMap<InstPtr, BTreeSet<BBPtr>>,
248) -> Result<BTreeMap<BBPtr, Vec<PhiPack>>> {
249    let mut block_to_phi_insertion: BTreeMap<BBPtr, Vec<PhiPack>> = BTreeMap::new();
250    for (variable, positions) in phi_insert_positions.into_iter() {
251        for mut position in positions.into_iter() {
252            block_to_phi_insertion
253                .entry(position)
254                .or_default()
255                .push(PhiPack::new_from_variable(
256                    variable,
257                    mem_pool,
258                    &mut position,
259                )?);
260        }
261    }
262    Ok(block_to_phi_insertion)
263}
264
265/// Get places to insert "phi" instructions for each "alloca" instruction
266#[allow(unused)]
267fn get_variable_to_phi_insertion(func: FunPtr) -> BTreeMap<InstPtr, BTreeSet<BBPtr>> {
268    let entry = func.entry.unwrap();
269    let mut phi_positions: BTreeMap<InstPtr, BTreeSet<BBPtr>> = BTreeMap::new();
270    let mut store_positions: BTreeMap<InstPtr, BTreeSet<BBPtr>> = BTreeMap::new();
271    let mut dom_tree = DominatorTree::new(func);
272
273    /// Build a mapping from variable to store positions
274    fn build_store_positions(
275        current_bb: BBPtr,
276        visited: &mut HashSet<BBPtr>,
277        store_positions: &mut BTreeMap<InstPtr, BTreeSet<BBPtr>>,
278    ) {
279        if visited.contains(&current_bb) {
280            return;
281        }
282        visited.insert(current_bb);
283        for inst in current_bb.iter() {
284            if inst.get_type() == InstType::Store {
285                let store = inst;
286                let store_operands = store.get_operand();
287                let store_ptr = &store_operands[1];
288
289                // Only insert "phi" when store destination is a constant pointer
290                if let Operand::Instruction(inst) = store_ptr {
291                    if inst.get_type() == InstType::Alloca {
292                        store_positions.entry(*inst).or_default().insert(current_bb);
293                    }
294                }
295            }
296        }
297        for succ in current_bb.get_succ_bb() {
298            build_store_positions(*succ, visited, store_positions);
299        }
300    }
301
302    // For each variable, find all dominance frontiers and insert "phi" instructions
303    // After inserting "phi" at a block, find its dominance frontiers and insert "phi" recursively
304    build_store_positions(entry, &mut HashSet::new(), &mut store_positions);
305    for (variable, vis) in store_positions.iter_mut() {
306        let mut positions = vis.clone();
307        while !positions.is_empty() {
308            let position = positions.pop_first().unwrap();
309            let df = dom_tree.get_df(position);
310            for bb in df {
311                phi_positions.entry(*variable).or_default().insert(bb);
312
313                // Only insert positions never considered before
314                if (!vis.contains(&bb)) {
315                    vis.insert(bb);
316                    positions.insert(bb);
317                }
318            }
319        }
320    }
321
322    // Return result
323    phi_positions
324}
325
#[cfg(test)]
pub mod tests_mem2reg {
    use super::*;
    use crate::{ir::ValueType, Program};

    /// Creates three "alloca" of int, then one "store" of constant 1 into each of them
    /// Returns the allocas and the stores, index-aligned; nothing is inserted into a block
    fn three_allocas_with_stores(program: &mut Program) -> (Vec<InstPtr>, Vec<InstPtr>) {
        let allocas: Vec<InstPtr> = (0..3)
            .map(|_| program.mem_pool.get_alloca(ValueType::Int, 1))
            .collect();
        let stores: Vec<InstPtr> = allocas
            .iter()
            .map(|&alloca| {
                program
                    .mem_pool
                    .get_store(Operand::Constant(1.into()), Operand::Instruction(alloca))
            })
            .collect();
        (allocas, stores)
    }

    #[test]
    fn test_phi_insert_positions_single() {
        let mut program = Program::new();

        // A single block holding all allocas first, then all stores
        let mut entry = program.mem_pool.new_basicblock("entry".to_string());
        let (allocas, stores) = three_allocas_with_stores(&mut program);
        for inst in allocas.into_iter().chain(stores) {
            entry.push_back(inst);
        }
        let mut func = program
            .mem_pool
            .new_function("no_name".to_string(), ValueType::Void);
        func.entry = Some(entry);
        func.exit = Some(entry);

        // Calculate phi insert positions
        let phi_positions = get_variable_to_phi_insertion(func);

        // With one block only, no variable needs a "phi"
        assert!(phi_positions.is_empty());
    }

    #[test]
    fn test_phi_insert_positions_nested() {
        let mut program = Program::new();

        // Nested if-else graph:
        // entry -> then | alt, then -> then_then | then_alt, all leaves -> end
        let mut entry = program.mem_pool.new_basicblock("entry".to_string());
        let mut then = program.mem_pool.new_basicblock("then".to_string());
        let mut then_then = program.mem_pool.new_basicblock("then_then".to_string());
        let mut then_alt = program.mem_pool.new_basicblock("then_alt".to_string());
        let mut alt = program.mem_pool.new_basicblock("alt".to_string());
        let end = program.mem_pool.new_basicblock("end".to_string());
        entry.set_true_bb(then);
        entry.set_false_bb(alt);
        then.set_true_bb(then_then);
        then.set_false_bb(then_alt);
        then_then.set_true_bb(end);
        then_alt.set_true_bb(end);
        alt.set_true_bb(end);
        let mut func = program
            .mem_pool
            .new_function("no_name".to_string(), ValueType::Void);
        func.entry = Some(entry);
        func.exit = Some(end);

        // One alloca + store pair in each of "then", "then_alt" and "alt"
        let (allocas, stores) = three_allocas_with_stores(&mut program);
        for (index, mut host) in vec![then, then_alt, alt].into_iter().enumerate() {
            host.push_back(allocas[index]);
            host.push_back(stores[index]);
        }

        // Calculate phi insert positions
        let phi_positions = get_variable_to_phi_insertion(func);

        // Every variable needs exactly one "phi", located in "end"
        assert_eq!(phi_positions.len(), 3);
        for alloca in &allocas {
            assert_eq!(phi_positions[alloca].len(), 1);
            assert!(phi_positions[alloca].contains(&end));
        }
    }
}
424}