// luaur_compiler/methods/compiler_compile_inlined_call.rs

1use crate::enums::type_constant_folding::Type;
2use crate::functions::analyze_builtins::analyze_builtins;
3use crate::functions::fold_constants::fold_constants;
4use crate::functions::undo_changes_constant_folding::undo_changes_dense_hash_map_ast_expr_constant_expr_constant_change_log;
5use crate::functions::undo_changes_constant_folding_alt_b::undo_changes_dense_hash_map_ast_local_constant_local_constant_change_log;
6use crate::records::compiler::Compiler;
7use crate::records::constant::{Constant, ConstantData};
8use crate::records::inline_arg::InlineArg;
9use crate::records::inline_frame::InlineFrame;
10use luaur_ast::records::ast_expr::AstExpr;
11use luaur_ast::records::ast_expr_call::AstExprCall;
12use luaur_ast::records::ast_expr_function::AstExprFunction;
13use luaur_ast::records::ast_expr_varargs::AstExprVarargs;
14use luaur_ast::records::ast_local::AstLocal;
15use luaur_ast::records::ast_node::AstNode;
16use luaur_ast::records::ast_stat::AstStat;
17use luaur_common::enums::luau_opcode::LuauOpcode;
18
/// Register placeholder for inline arguments that are folded as constants instead of living in
/// a register (used for missing and constant-valued arguments).
const K_INVALID_REG: u8 = u8::MAX;
/// Allocation-pc placeholder for inline arguments whose register was not allocated by the
/// inliner itself (folded constants and reused local registers).
const K_DEFAULT_ALLOC_PC: u32 = u32::MAX;
21
22impl Compiler {
23    pub fn compile_inlined_call(
24        &mut self,
25        expr: *mut AstExprCall,
26        func: *mut AstExprFunction,
27        target: u8,
28        target_count: u8,
29    ) {
30        unsafe {
31            let mut rs = self.reg_scope_compiler();
32            let _ = &mut rs;
33            let old_locals = self.local_stack.len();
34            let mut args: Vec<InlineArg> = Vec::new();
35            args.reserve((*func).args.size);
36
37            let func_args_size = (*func).args.size;
38            let expr_args_size = (*expr).args.size;
39
40            // evaluate all arguments; note that we don't emit code for constant arguments (relying on constant folding)
41            let mut i = 0usize;
42            while i < func_args_size {
43                let var: *mut AstLocal = *(*func).args.data.add(i);
44                let arg: *mut AstExpr = if i < expr_args_size {
45                    *(*expr).args.data.add(i)
46                } else {
47                    core::ptr::null_mut()
48                };
49
50                if i + 1 == expr_args_size
51                    && func_args_size > expr_args_size
52                    && self.is_expr_mult_ret(arg)
53                {
54                    let tail: u32 = (func_args_size - expr_args_size) as u32 + 1;
55                    let reg = self.alloc_reg(arg as *mut AstNode, tail);
56                    let allocpc = (*self.bytecode).get_debug_pc();
57
58                    let call = luaur_ast::rtti::ast_node_as::<AstExprCall>(arg as *mut AstNode);
59                    if !call.is_null() {
60                        self.compile_expr_call(call, reg, tail as u8, true, false);
61                    } else {
62                        let va =
63                            luaur_ast::rtti::ast_node_as::<AstExprVarargs>(arg as *mut AstNode);
64                        if !va.is_null() {
65                            self.compile_expr_varargs(va, reg, tail as u8, false);
66                        } else {
67                            luaur_common::macros::luau_assert::LUAU_ASSERT!(
68                                false,
69                                "Unexpected expression type"
70                            );
71                        }
72                    }
73
74                    let mut j = i;
75                    while j < func_args_size {
76                        args.push(InlineArg {
77                            local: *(*func).args.data.add(j),
78                            reg: reg + (j - i) as u8,
79                            value: Constant {
80                                r#type: Type::Type_Unknown,
81                                string_length: 0,
82                                data: ConstantData { value_number: 0.0 },
83                            },
84                            allocpc,
85                            init: core::ptr::null_mut(),
86                        });
87                        j += 1;
88                    }
89                    break;
90                } else if {
91                    let vv = self.variables.find(&var);
92                    vv.map_or(false, |vv| vv.written)
93                } {
94                    let reg = self.alloc_reg(arg as *mut AstNode, 1u32);
95                    let allocpc = (*self.bytecode).get_debug_pc();
96                    if !arg.is_null() {
97                        self.compile_expr_temp(arg, reg);
98                    } else {
99                        (*self.bytecode).emit_abc(LuauOpcode::LOP_LOADNIL, reg, 0, 0);
100                    }
101                    args.push(InlineArg {
102                        local: var,
103                        reg,
104                        value: Constant {
105                            r#type: Type::Type_Unknown,
106                            string_length: 0,
107                            data: ConstantData { value_number: 0.0 },
108                        },
109                        allocpc,
110                        init: core::ptr::null_mut(),
111                    });
112                } else if arg.is_null() {
113                    args.push(InlineArg {
114                        local: var,
115                        reg: K_INVALID_REG,
116                        value: Constant {
117                            r#type: Type::Type_Nil,
118                            string_length: 0,
119                            data: ConstantData { value_number: 0.0 },
120                        },
121                        allocpc: K_DEFAULT_ALLOC_PC,
122                        init: core::ptr::null_mut(),
123                    });
124                } else if {
125                    let cv = self.constants.find(&arg);
126                    cv.map_or(false, |cv| cv.r#type != Type::Type_Unknown)
127                } {
128                    let cv = *self.constants.find(&arg).unwrap();
129                    args.push(InlineArg {
130                        local: var,
131                        reg: K_INVALID_REG,
132                        value: cv,
133                        allocpc: K_DEFAULT_ALLOC_PC,
134                        init: core::ptr::null_mut(),
135                    });
136                } else {
137                    let le = self.get_expr_local(arg);
138                    let lv_written = if !le.is_null() {
139                        self.variables.find(&(*le).local).map(|v| v.written)
140                    } else {
141                        None
142                    };
143                    let reg: i32 = if !le.is_null() {
144                        self.get_expr_local_reg(le as *mut AstExpr)
145                    } else {
146                        -1
147                    };
148                    if reg >= 0 && (lv_written.is_none() || lv_written == Some(false)) {
149                        let lv_init = if !le.is_null() {
150                            self.variables
151                                .find(&(*le).local)
152                                .map_or(core::ptr::null_mut(), |v| v.init)
153                        } else {
154                            core::ptr::null_mut()
155                        };
156                        args.push(InlineArg {
157                            local: var,
158                            reg: reg as u8,
159                            value: Constant {
160                                r#type: Type::Type_Unknown,
161                                string_length: 0,
162                                data: ConstantData { value_number: 0.0 },
163                            },
164                            allocpc: K_DEFAULT_ALLOC_PC,
165                            init: lv_init,
166                        });
167                    } else {
168                        let temp = self.alloc_reg(arg as *mut AstNode, 1u32);
169                        let allocpc = (*self.bytecode).get_debug_pc();
170                        self.compile_expr_temp(arg, temp);
171                        args.push(InlineArg {
172                            local: var,
173                            reg: temp,
174                            value: Constant {
175                                r#type: Type::Type_Unknown,
176                                string_length: 0,
177                                data: ConstantData { value_number: 0.0 },
178                            },
179                            allocpc,
180                            init: arg,
181                        });
182                    }
183                }
184
185                i += 1;
186            }
187
188            // evaluate extra expressions for side effects
189            let mut k = func_args_size;
190            while k < expr_args_size {
191                let side = *(*expr).args.data.add(k);
192                self.compile_expr_side(side);
193                k += 1;
194            }
195
196            // apply all evaluated arguments to the compiler state
197            for arg in &args {
198                if arg.value.r#type == Type::Type_Unknown {
199                    self.push_local(arg.local, arg.reg, arg.allocpc);
200                    if !arg.init.is_null() {
201                        if let Some(lv) = self.variables.find_mut(&arg.local) {
202                            lv.init = arg.init;
203                        }
204                    }
205                } else {
206                    *self.locstants.get_or_insert(arg.local) = arg.value;
207                }
208            }
209
210            self.inline_frames.push(InlineFrame {
211                func,
212                local_offset: old_locals,
213                target,
214                target_count,
215                return_jumps: Vec::new(),
216            });
217
218            let func_body = (*func).body;
219
220            {
221                let ib = &mut self.inline_builtins as *mut _;
222                analyze_builtins(
223                    &mut *ib,
224                    &self.globals,
225                    &self.variables,
226                    &self.options,
227                    func_body as *mut AstNode,
228                    &*self.names,
229                );
230            }
231
232            if !self.inline_builtins.is_empty() {
233                let entries: Vec<(*mut AstExprCall, i32)> =
234                    self.inline_builtins.iter().map(|(k, v)| (*k, *v)).collect();
235                for (call_expr, bfid) in entries {
236                    let builtin = *self.builtins.get_or_insert(call_expr);
237                    if bfid != builtin {
238                        *self.inline_builtins_backup.get_or_insert(call_expr) = builtin;
239                        *self.builtins.get_or_insert(call_expr) = bfid;
240                    }
241                }
242                self.inline_builtins.clear();
243            }
244
245            let record_changes = luaur_common::FFlag::LuauCompilePropagateTableProps2.get()
246                && luaur_common::FFlag::LuauCompileFoldOptimize.get();
247
248            if record_changes {
249                self.expr_changes.clear();
250                self.local_changes.clear();
251            }
252
253            fold_constants(
254                &mut self.constants,
255                &mut self.variables,
256                &mut self.locstants,
257                self.builtins_fold,
258                self.builtins_fold_library_k,
259                self.options.library_member_constant_cb,
260                func_body as *mut AstNode,
261                &mut *self.names,
262                &self.table_constants,
263                &mut self.expr_changes as *mut _,
264                &mut self.local_changes as *mut _,
265            );
266
267            let mut terminates_early = false;
268            let body_size = (*func_body).body.size;
269            let mut bi = 0usize;
270            while bi < body_size {
271                let stat: *mut AstStat = *(*func_body).body.data.add(bi);
272                self.compile_stat(stat);
273                if self.always_terminates(stat) {
274                    terminates_early = true;
275                    let curr_frame = self.inline_frames.last_mut().unwrap();
276                    if !curr_frame.return_jumps.is_empty() {
277                        let last_jump = *curr_frame.return_jumps.last().unwrap();
278                        if last_jump == (*self.bytecode).emit_label() - 1 {
279                            (*self.bytecode).undo_emit(LuauOpcode::LOP_JUMP);
280                            curr_frame.return_jumps.pop();
281                        }
282                    }
283                    break;
284                }
285                bi += 1;
286            }
287
288            if !terminates_early {
289                let mut t = 0usize;
290                while t < target_count as usize {
291                    (*self.bytecode).emit_abc(LuauOpcode::LOP_LOADNIL, target + t as u8, 0, 0);
292                    t += 1;
293                }
294                self.close_locals(old_locals);
295            }
296
297            self.pop_locals(old_locals);
298            let return_label = (*self.bytecode).emit_label();
299            let rj = &mut self.inline_frames.last_mut().unwrap().return_jumps as *mut Vec<usize>;
300            self.patch_jumps(expr as *mut AstNode, &mut *rj, return_label);
301            self.inline_frames.pop();
302
303            // clean up constant state for future inlining attempts
304            let mut ci = 0usize;
305            while ci < func_args_size {
306                let local: *mut AstLocal = *(*func).args.data.add(ci);
307                if let Some(var) = self.locstants.find_mut(&local) {
308                    var.r#type = Type::Type_Unknown;
309                }
310                if let Some(lv) = self.variables.find_mut(&local) {
311                    lv.init = core::ptr::null_mut();
312                }
313                ci += 1;
314            }
315
316            if !self.inline_builtins_backup.is_empty() {
317                let entries: Vec<(*mut AstExprCall, i32)> = self
318                    .inline_builtins_backup
319                    .iter()
320                    .map(|(k, v)| (*k, *v))
321                    .collect();
322                for (call_expr, bfid) in entries {
323                    *self.builtins.get_or_insert(call_expr) = bfid;
324                }
325                self.inline_builtins_backup.clear();
326            }
327
328            if record_changes {
329                undo_changes_dense_hash_map_ast_expr_constant_expr_constant_change_log(
330                    &mut self.constants,
331                    &self.expr_changes,
332                );
333                undo_changes_dense_hash_map_ast_local_constant_local_constant_change_log(
334                    &mut self.locstants,
335                    &self.local_changes,
336                );
337            } else {
338                fold_constants(
339                    &mut self.constants,
340                    &mut self.variables,
341                    &mut self.locstants,
342                    self.builtins_fold,
343                    self.builtins_fold_library_k,
344                    self.options.library_member_constant_cb,
345                    func_body as *mut AstNode,
346                    &mut *self.names,
347                    &self.table_constants,
348                    core::ptr::null_mut(),
349                    core::ptr::null_mut(),
350                );
351            }
352        }
353    }
354}