duskphantom_middle/transform/
func_inline.rs

1// Copyright 2024 Duskphantom Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// SPDX-License-Identifier: Apache-2.0
16
17use std::collections::HashMap;
18
19use crate::ir::instruction::downcast_ref;
20use crate::{
21    analysis::call_graph::{CallEdge, CallGraph},
22    context,
23    ir::{
24        instruction::{
25            downcast_mut,
26            misc_inst::{Call, Phi},
27            InstType,
28        },
29        BBPtr, FunPtr, InstPtr, Instruction, Operand, ParaPtr, ValueType,
30    },
31    Program,
32};
33use anyhow::{anyhow, Context, Result};
34use duskphantom_utils::paral_counter::ParalCounter;
35
36use super::Transform;
37
38pub fn optimize_program(program: &mut Program) -> Result<bool> {
39    let mut call_graph = CallGraph::new(program);
40    let counter = ParalCounter::new(0, usize::MAX);
41    let mut func_inline = FuncInline::new(program, &mut call_graph, counter);
42    func_inline.run_and_log()
43}
44
45pub struct FuncInline<'a> {
46    program: &'a mut Program,
47    call_graph: &'a mut CallGraph,
48    counter: ParalCounter,
49}
50
51impl<'a> Transform for FuncInline<'a> {
52    fn get_program_mut(&mut self) -> &mut Program {
53        self.program
54    }
55
56    fn name() -> String {
57        "func_inline".to_string()
58    }
59
60    fn run(&mut self) -> Result<bool> {
61        let mut whole_changed = false;
62        loop {
63            let mut changed = false;
64            for func in self.program.module.functions.clone() {
65                // Do not process library function
66                if func.is_lib() {
67                    continue;
68                }
69
70                // If functions calls other functions, do not process it
71                if !self.call_graph.get_calls(func).is_empty() {
72                    continue;
73                }
74
75                // Process function
76                changed |= self.process_func(func)?;
77                whole_changed |= changed;
78
79                // Update call graph
80                self.call_graph.remove(func);
81            }
82            if !changed {
83                break;
84            }
85        }
86        Ok(whole_changed)
87    }
88}
89
90impl<'a> FuncInline<'a> {
91    pub fn new(
92        program: &'a mut Program,
93        call_graph: &'a mut CallGraph,
94        counter: ParalCounter,
95    ) -> Self {
96        Self {
97            program,
98            call_graph,
99            counter,
100        }
101    }
102
103    fn process_func(&mut self, func: FunPtr) -> Result<bool> {
104        let mut changed = false;
105
106        // Eliminate call to func
107        for call in self.call_graph.get_called_by(func) {
108            changed |= self.process_call(call)?;
109        }
110
111        // Delete func to reduce code size
112        if changed {
113            self.program.module.functions.retain(|&f| f != func);
114        }
115        Ok(changed)
116    }
117
118    fn process_call(&mut self, edge: CallEdge) -> Result<bool> {
119        let mut inst = edge.inst;
120        let call = downcast_ref::<Call>(inst.as_ref().as_ref());
121
122        // Build argument map
123        let params = edge.callee.params.iter().cloned();
124        let args = inst.get_operand().iter().cloned();
125        let arg_map = params.zip(args).collect();
126
127        // Mirror function, focus on interface basic blocks
128        let new_fun = self.mirror_func(edge.callee, arg_map)?;
129        let mut before_entry = call.get_parent_bb().unwrap();
130        let after_exit = self.split_block_at(before_entry, inst)?;
131        let fun_entry = new_fun
132            .entry
133            .ok_or_else(|| anyhow!("function `{}` has no entry", new_fun.name))
134            .with_context(|| context!())?;
135        let mut fun_exit = new_fun
136            .exit
137            .ok_or_else(|| anyhow!("function `{}` has no exit", new_fun.name))
138            .with_context(|| context!())?;
139
140        // Wire before_entry -> fun_entry
141        before_entry.push_back(self.program.mem_pool.get_br(None));
142        before_entry.set_true_bb(fun_entry);
143
144        // Replace call with operand of return, remove return
145        let mut ret = fun_exit.get_last_inst();
146        if inst.get_value_type() == ValueType::Void {
147            inst.remove_self();
148        } else {
149            let ret_val = ret
150                .get_operand()
151                .first()
152                .ok_or_else(|| anyhow!("function `{}` has no return value", new_fun.name))
153                .with_context(|| context!())?;
154            inst.replace_self(ret_val);
155        }
156        ret.remove_self();
157
158        // Wire func_exit -> after_exit
159        fun_exit.push_back(self.program.mem_pool.get_br(None));
160        fun_exit.set_true_bb(after_exit);
161        Ok(true)
162    }
163
164    /// Split given basic block at the position of given instruction.
165    /// Given instruction and instruction afterwards will be put to exit block.
166    /// Returns new exit block.
167    fn split_block_at(&mut self, mut entry: BBPtr, inst: InstPtr) -> Result<BBPtr> {
168        let exit_name = self.unique_name("split", &entry.name);
169        let mut exit = self.program.mem_pool.new_basicblock(exit_name);
170        let mut split = false;
171
172        // Copy instructions after found target instruction
173        for bb_inst in entry.iter() {
174            if bb_inst == inst {
175                split = true;
176            }
177            if split {
178                exit.push_back(bb_inst);
179            }
180        }
181
182        // Replace `entry` with `entry -> exit`
183        entry.replace_exit(exit);
184
185        // Return created block
186        Ok(exit)
187    }
188
189    /// Mirror a function with given mapping.
190    /// The function should not be added to program, please wire entry and exit to existing function.
191    fn mirror_func(&mut self, func: FunPtr, arg_map: HashMap<ParaPtr, Operand>) -> Result<FunPtr> {
192        let func_entry = func
193            .entry
194            .ok_or_else(|| anyhow!("function `{}` has no entry", func.name))
195            .with_context(|| context!())?;
196        let func_exit = func
197            .exit
198            .ok_or_else(|| anyhow!("function `{}` has no exit", func.name))
199            .with_context(|| context!())?;
200
201        // Initialize inst and block mapping and new function
202        let mut inst_map: HashMap<InstPtr, InstPtr> = HashMap::new();
203        let mut block_map: HashMap<BBPtr, BBPtr> = HashMap::new();
204        let mut new_fun = self
205            .program
206            .mem_pool
207            .new_function(String::new(), func.return_type.clone());
208
209        // Copy blocks and instructions
210        for bb in func.dfs_iter() {
211            let name = self.unique_name("inline", &bb.name);
212            let mut new_bb = self.program.mem_pool.new_basicblock(name);
213            block_map.insert(bb, new_bb);
214            for inst in bb.iter() {
215                let new_inst = self
216                    .program
217                    .mem_pool
218                    .copy_instruction(inst.as_ref().as_ref());
219                inst_map.insert(inst, new_inst);
220                new_bb.push_back(new_inst);
221            }
222        }
223
224        // Set entry and exit for new function
225        new_fun.entry = block_map.get(&func_entry).cloned();
226        new_fun.exit = block_map.get(&func_exit).cloned();
227
228        // Copy operands from old instruction to new instruction,
229        // replace operands to local instruction and inlined argument
230        for bb in func.dfs_iter() {
231            for inst in bb.iter() {
232                let mut new_inst = inst_map
233                    .get(&inst)
234                    .cloned()
235                    .ok_or_else(|| anyhow!("instruction not found in inst_map: {}", inst))
236                    .with_context(|| context!())?;
237                if inst.get_type() == InstType::Phi {
238                    let inst = downcast_ref::<Phi>(inst.as_ref().as_ref());
239                    let new_inst = downcast_mut::<Phi>(new_inst.as_mut());
240
241                    // Replace operand for phi instruction
242                    for (old_op, old_bb) in inst.get_incoming_values().iter() {
243                        let new_bb = block_map
244                            .get(old_bb)
245                            .cloned()
246                            .ok_or_else(|| anyhow!("bb not found in block_map: {}", old_bb.name))
247                            .with_context(|| context!())?;
248                        if let Operand::Instruction(old_op) = old_op {
249                            let new_op = inst_map.get(old_op).cloned().unwrap();
250                            new_inst.add_incoming_value(new_op.into(), new_bb);
251                        } else if let Operand::Parameter(old_op) = old_op {
252                            let new_op = arg_map.get(old_op).cloned().unwrap();
253                            new_inst.add_incoming_value(new_op, new_bb);
254                        } else {
255                            // Copy operands manually because `copy_instruction` does not copy them
256                            new_inst.add_incoming_value(old_op.clone(), new_bb);
257                        }
258                    }
259                } else {
260                    // Replace operand for normal instruction
261                    for old_op in inst.get_operand().iter() {
262                        if let Operand::Instruction(old_op) = old_op {
263                            let new_op = inst_map.get(old_op).cloned().unwrap();
264                            new_inst.add_operand(new_op.into());
265                        } else if let Operand::Parameter(old_op) = old_op {
266                            let new_op = arg_map.get(old_op).cloned().unwrap();
267                            new_inst.add_operand(new_op);
268                        } else {
269                            // Copy operands manually because `copy_instruction` does not copy them
270                            new_inst.add_operand(old_op.clone());
271                        }
272                    }
273                }
274            }
275        }
276
277        // Assign mapped basic blocks to successor
278        for bb in func.dfs_iter() {
279            let mut new_bb = block_map.get(&bb).cloned().unwrap();
280            let succ_bb = bb.get_succ_bb();
281            if !succ_bb.is_empty() {
282                let new_succ = block_map.get(&succ_bb[0]).cloned().unwrap();
283                new_bb.set_true_bb(new_succ);
284            }
285            if succ_bb.len() >= 2 {
286                let new_succ = block_map.get(&succ_bb[1]).cloned().unwrap();
287                new_bb.set_false_bb(new_succ);
288            }
289        }
290
291        // Return new function
292        Ok(new_fun)
293    }
294
295    fn unique_name(&mut self, meta: &str, base_name: &str) -> String {
296        format!("{}_{}{}", base_name, meta, self.counter.get_id().unwrap())
297    }
298}