duskphantom_middle/transform/
func_inline.rs1use std::collections::HashMap;
18
19use crate::ir::instruction::downcast_ref;
20use crate::{
21 analysis::call_graph::{CallEdge, CallGraph},
22 context,
23 ir::{
24 instruction::{
25 downcast_mut,
26 misc_inst::{Call, Phi},
27 InstType,
28 },
29 BBPtr, FunPtr, InstPtr, Instruction, Operand, ParaPtr, ValueType,
30 },
31 Program,
32};
33use anyhow::{anyhow, Context, Result};
34use duskphantom_utils::paral_counter::ParalCounter;
35
36use super::Transform;
37
38pub fn optimize_program(program: &mut Program) -> Result<bool> {
39 let mut call_graph = CallGraph::new(program);
40 let counter = ParalCounter::new(0, usize::MAX);
41 let mut func_inline = FuncInline::new(program, &mut call_graph, counter);
42 func_inline.run_and_log()
43}
44
45pub struct FuncInline<'a> {
46 program: &'a mut Program,
47 call_graph: &'a mut CallGraph,
48 counter: ParalCounter,
49}
50
51impl<'a> Transform for FuncInline<'a> {
52 fn get_program_mut(&mut self) -> &mut Program {
53 self.program
54 }
55
56 fn name() -> String {
57 "func_inline".to_string()
58 }
59
60 fn run(&mut self) -> Result<bool> {
61 let mut whole_changed = false;
62 loop {
63 let mut changed = false;
64 for func in self.program.module.functions.clone() {
65 if func.is_lib() {
67 continue;
68 }
69
70 if !self.call_graph.get_calls(func).is_empty() {
72 continue;
73 }
74
75 changed |= self.process_func(func)?;
77 whole_changed |= changed;
78
79 self.call_graph.remove(func);
81 }
82 if !changed {
83 break;
84 }
85 }
86 Ok(whole_changed)
87 }
88}
89
90impl<'a> FuncInline<'a> {
91 pub fn new(
92 program: &'a mut Program,
93 call_graph: &'a mut CallGraph,
94 counter: ParalCounter,
95 ) -> Self {
96 Self {
97 program,
98 call_graph,
99 counter,
100 }
101 }
102
103 fn process_func(&mut self, func: FunPtr) -> Result<bool> {
104 let mut changed = false;
105
106 for call in self.call_graph.get_called_by(func) {
108 changed |= self.process_call(call)?;
109 }
110
111 if changed {
113 self.program.module.functions.retain(|&f| f != func);
114 }
115 Ok(changed)
116 }
117
118 fn process_call(&mut self, edge: CallEdge) -> Result<bool> {
119 let mut inst = edge.inst;
120 let call = downcast_ref::<Call>(inst.as_ref().as_ref());
121
122 let params = edge.callee.params.iter().cloned();
124 let args = inst.get_operand().iter().cloned();
125 let arg_map = params.zip(args).collect();
126
127 let new_fun = self.mirror_func(edge.callee, arg_map)?;
129 let mut before_entry = call.get_parent_bb().unwrap();
130 let after_exit = self.split_block_at(before_entry, inst)?;
131 let fun_entry = new_fun
132 .entry
133 .ok_or_else(|| anyhow!("function `{}` has no entry", new_fun.name))
134 .with_context(|| context!())?;
135 let mut fun_exit = new_fun
136 .exit
137 .ok_or_else(|| anyhow!("function `{}` has no exit", new_fun.name))
138 .with_context(|| context!())?;
139
140 before_entry.push_back(self.program.mem_pool.get_br(None));
142 before_entry.set_true_bb(fun_entry);
143
144 let mut ret = fun_exit.get_last_inst();
146 if inst.get_value_type() == ValueType::Void {
147 inst.remove_self();
148 } else {
149 let ret_val = ret
150 .get_operand()
151 .first()
152 .ok_or_else(|| anyhow!("function `{}` has no return value", new_fun.name))
153 .with_context(|| context!())?;
154 inst.replace_self(ret_val);
155 }
156 ret.remove_self();
157
158 fun_exit.push_back(self.program.mem_pool.get_br(None));
160 fun_exit.set_true_bb(after_exit);
161 Ok(true)
162 }
163
164 fn split_block_at(&mut self, mut entry: BBPtr, inst: InstPtr) -> Result<BBPtr> {
168 let exit_name = self.unique_name("split", &entry.name);
169 let mut exit = self.program.mem_pool.new_basicblock(exit_name);
170 let mut split = false;
171
172 for bb_inst in entry.iter() {
174 if bb_inst == inst {
175 split = true;
176 }
177 if split {
178 exit.push_back(bb_inst);
179 }
180 }
181
182 entry.replace_exit(exit);
184
185 Ok(exit)
187 }
188
189 fn mirror_func(&mut self, func: FunPtr, arg_map: HashMap<ParaPtr, Operand>) -> Result<FunPtr> {
192 let func_entry = func
193 .entry
194 .ok_or_else(|| anyhow!("function `{}` has no entry", func.name))
195 .with_context(|| context!())?;
196 let func_exit = func
197 .exit
198 .ok_or_else(|| anyhow!("function `{}` has no exit", func.name))
199 .with_context(|| context!())?;
200
201 let mut inst_map: HashMap<InstPtr, InstPtr> = HashMap::new();
203 let mut block_map: HashMap<BBPtr, BBPtr> = HashMap::new();
204 let mut new_fun = self
205 .program
206 .mem_pool
207 .new_function(String::new(), func.return_type.clone());
208
209 for bb in func.dfs_iter() {
211 let name = self.unique_name("inline", &bb.name);
212 let mut new_bb = self.program.mem_pool.new_basicblock(name);
213 block_map.insert(bb, new_bb);
214 for inst in bb.iter() {
215 let new_inst = self
216 .program
217 .mem_pool
218 .copy_instruction(inst.as_ref().as_ref());
219 inst_map.insert(inst, new_inst);
220 new_bb.push_back(new_inst);
221 }
222 }
223
224 new_fun.entry = block_map.get(&func_entry).cloned();
226 new_fun.exit = block_map.get(&func_exit).cloned();
227
228 for bb in func.dfs_iter() {
231 for inst in bb.iter() {
232 let mut new_inst = inst_map
233 .get(&inst)
234 .cloned()
235 .ok_or_else(|| anyhow!("instruction not found in inst_map: {}", inst))
236 .with_context(|| context!())?;
237 if inst.get_type() == InstType::Phi {
238 let inst = downcast_ref::<Phi>(inst.as_ref().as_ref());
239 let new_inst = downcast_mut::<Phi>(new_inst.as_mut());
240
241 for (old_op, old_bb) in inst.get_incoming_values().iter() {
243 let new_bb = block_map
244 .get(old_bb)
245 .cloned()
246 .ok_or_else(|| anyhow!("bb not found in block_map: {}", old_bb.name))
247 .with_context(|| context!())?;
248 if let Operand::Instruction(old_op) = old_op {
249 let new_op = inst_map.get(old_op).cloned().unwrap();
250 new_inst.add_incoming_value(new_op.into(), new_bb);
251 } else if let Operand::Parameter(old_op) = old_op {
252 let new_op = arg_map.get(old_op).cloned().unwrap();
253 new_inst.add_incoming_value(new_op, new_bb);
254 } else {
255 new_inst.add_incoming_value(old_op.clone(), new_bb);
257 }
258 }
259 } else {
260 for old_op in inst.get_operand().iter() {
262 if let Operand::Instruction(old_op) = old_op {
263 let new_op = inst_map.get(old_op).cloned().unwrap();
264 new_inst.add_operand(new_op.into());
265 } else if let Operand::Parameter(old_op) = old_op {
266 let new_op = arg_map.get(old_op).cloned().unwrap();
267 new_inst.add_operand(new_op);
268 } else {
269 new_inst.add_operand(old_op.clone());
271 }
272 }
273 }
274 }
275 }
276
277 for bb in func.dfs_iter() {
279 let mut new_bb = block_map.get(&bb).cloned().unwrap();
280 let succ_bb = bb.get_succ_bb();
281 if !succ_bb.is_empty() {
282 let new_succ = block_map.get(&succ_bb[0]).cloned().unwrap();
283 new_bb.set_true_bb(new_succ);
284 }
285 if succ_bb.len() >= 2 {
286 let new_succ = block_map.get(&succ_bb[1]).cloned().unwrap();
287 new_bb.set_false_bb(new_succ);
288 }
289 }
290
291 Ok(new_fun)
293 }
294
295 fn unique_name(&mut self, meta: &str, base_name: &str) -> String {
296 format!("{}_{}{}", base_name, meta, self.counter.get_id().unwrap())
297 }
298}