Skip to main content

luaur_code_gen/functions/
translate_inst_binary_numeric.rs

1use crate::enums::ir_cmd::IrCmd;
2use crate::enums::ir_op_kind::IrOpKind;
3use crate::enums::ir_op_kind::IrOpKind::*;
4use crate::functions::get_initialized_fallback::get_initialized_fallback;
5use crate::functions::is_userdata_bytecode_type::is_userdata_bytecode_type;
6use crate::functions::load_double_or_constant::load_double_or_constant;
7use crate::functions::translate_binary_numeric_fallback_if_required::translate_binary_numeric_fallback_if_required;
8use crate::macros::codegen_assert::CODEGEN_ASSERT;
9use crate::records::bytecode_types::BytecodeTypes;
10use crate::records::ir_builder::IrBuilder;
11use crate::records::ir_function::IrFunction;
12use crate::records::ir_op::IrOp;
13use crate::type_aliases::instruction_ir_builder::Instruction;
14use luaur_common::enums::luau_builtin_function::LuauBuiltinFunction;
15use luaur_common::enums::luau_bytecode_type::LuauBytecodeType;
16use luaur_vm::enums::lua_type::lua_Type;
17use luaur_vm::type_aliases::tms::TMS;
18
19pub fn translate_inst_binary_numeric(
20    build: &mut IrBuilder,
21    ra: i32,
22    rb: i32,
23    rc: i32,
24    opb: IrOp,
25    opc: IrOp,
26    pcpos: i32,
27    tm: TMS,
28) {
29    let mut fallback = IrOp::ir_op();
30
31    let bc_types = build.function.get_bytecode_types_at(pcpos);
32
33    // Special fast-paths for vectors, matching the cases we have in VM
34    if bc_types.a == LuauBytecodeType::LBC_TYPE_VECTOR.0 as u8
35        && bc_types.b == LuauBytecodeType::LBC_TYPE_VECTOR.0 as u8
36        && (tm == TMS::TM_ADD
37            || tm == TMS::TM_SUB
38            || tm == TMS::TM_MUL
39            || tm == TMS::TM_DIV
40            || tm == TMS::TM_IDIV)
41    {
42        let reg_rb = build.vm_reg(rb as u8);
43        let tag_b = build.inst_ir_cmd_ir_op(IrCmd::LOAD_TAG, reg_rb);
44        let vector_tag = build.const_tag(lua_Type::LUA_TVECTOR as u8);
45        let exit = build.vm_exit(pcpos as u32);
46        build.inst_ir_cmd_ir_op_ir_op_ir_op(IrCmd::CHECK_TAG, tag_b, vector_tag, exit);
47
48        let reg_rc = build.vm_reg(rc as u8);
49        let tag_c = build.inst_ir_cmd_ir_op(IrCmd::LOAD_TAG, reg_rc);
50        let vector_tag = build.const_tag(lua_Type::LUA_TVECTOR as u8);
51        let exit = build.vm_exit(pcpos as u32);
52        build.inst_ir_cmd_ir_op_ir_op_ir_op(IrCmd::CHECK_TAG, tag_c, vector_tag, exit);
53
54        let vb = build.inst_ir_cmd_ir_op_ir_op(IrCmd::LOAD_TVALUE, opb, IrOp::ir_op());
55        let vb = vb; // keep naming aligned with original
56
57        let vc = build.inst_ir_cmd_ir_op_ir_op(IrCmd::LOAD_TVALUE, opc, IrOp::ir_op());
58        let vc = vc;
59
60        let result = match tm {
61            TMS::TM_ADD => build.inst_ir_cmd_ir_op_ir_op(IrCmd::ADD_VEC, vb, vc),
62            TMS::TM_SUB => build.inst_ir_cmd_ir_op_ir_op(IrCmd::SUB_VEC, vb, vc),
63            TMS::TM_MUL => build.inst_ir_cmd_ir_op_ir_op(IrCmd::MUL_VEC, vb, vc),
64            TMS::TM_DIV => build.inst_ir_cmd_ir_op_ir_op(IrCmd::DIV_VEC, vb, vc),
65            TMS::TM_IDIV => build.inst_ir_cmd_ir_op_ir_op(IrCmd::IDIV_VEC, vb, vc),
66            _ => {
67                CODEGEN_ASSERT!(false);
68                IrOp::ir_op()
69            }
70        };
71
72        let result = build.inst_ir_cmd_ir_op_ir_op(IrCmd::TAG_VECTOR, result, IrOp::ir_op());
73        let reg_ra = build.vm_reg(ra as u8);
74        build.inst_ir_cmd_ir_op_ir_op(IrCmd::STORE_TVALUE, reg_ra, result);
75        return;
76    } else if !is_userdata_bytecode_type(bc_types.a)
77        && bc_types.b == LuauBytecodeType::LBC_TYPE_VECTOR.0 as u8
78        && (tm == TMS::TM_MUL || tm == TMS::TM_DIV || tm == TMS::TM_IDIV)
79    {
80        if rb != -1 {
81            let fallback_exit = if bc_types.a == LuauBytecodeType::LBC_TYPE_NUMBER.0 as u8 {
82                build.vm_exit(pcpos as u32)
83            } else {
84                get_initialized_fallback(build, &mut fallback, pcpos)
85            };
86
87            let rb_reg = build.vm_reg(rb as u8);
88            let tag_load = build.inst_ir_cmd_ir_op(IrCmd::LOAD_TAG, rb_reg);
89            let number_tag = build.const_tag(lua_Type::LUA_TNUMBER as u8);
90            build.inst_ir_cmd_ir_op_ir_op_ir_op(
91                IrCmd::CHECK_TAG,
92                tag_load,
93                number_tag,
94                fallback_exit,
95            );
96        }
97
98        let reg_rc = build.vm_reg(rc as u8);
99        let tag_rc = build.inst_ir_cmd_ir_op(IrCmd::LOAD_TAG, reg_rc);
100        let vector_tag = build.const_tag(lua_Type::LUA_TVECTOR as u8);
101        let exit = build.vm_exit(pcpos as u32);
102        build.inst_ir_cmd_ir_op_ir_op_ir_op(IrCmd::CHECK_TAG, tag_rc, vector_tag, exit);
103
104        let load_d = load_double_or_constant(build, opb);
105        let undef = IrOp::ir_op();
106        let num_float = build.inst_ir_cmd_ir_op_ir_op(IrCmd::NUM_TO_FLOAT, load_d, undef);
107        let vb = build.inst_ir_cmd_ir_op_ir_op(IrCmd::FLOAT_TO_VEC, num_float, IrOp::ir_op());
108
109        let vc = build.inst_ir_cmd_ir_op(IrCmd::LOAD_TVALUE, opc);
110        let result = match tm {
111            TMS::TM_MUL => build.inst_ir_cmd_ir_op_ir_op(IrCmd::MUL_VEC, vb, vc),
112            TMS::TM_DIV => build.inst_ir_cmd_ir_op_ir_op(IrCmd::DIV_VEC, vb, vc),
113            TMS::TM_IDIV => build.inst_ir_cmd_ir_op_ir_op(IrCmd::IDIV_VEC, vb, vc),
114            _ => {
115                CODEGEN_ASSERT!(false);
116                IrOp::ir_op()
117            }
118        };
119
120        let result = build.inst_ir_cmd_ir_op_ir_op(IrCmd::TAG_VECTOR, result, IrOp::ir_op());
121        let reg_ra = build.vm_reg(ra as u8);
122        build.inst_ir_cmd_ir_op_ir_op(IrCmd::STORE_TVALUE, reg_ra, result);
123
124        translate_binary_numeric_fallback_if_required(build, fallback, ra, opb, opc, tm, pcpos);
125        return;
126    } else if bc_types.a == LuauBytecodeType::LBC_TYPE_VECTOR.0 as u8
127        && !is_userdata_bytecode_type(bc_types.b)
128        && (tm == TMS::TM_MUL || tm == TMS::TM_DIV || tm == TMS::TM_IDIV)
129    {
130        let reg_rb = build.vm_reg(rb as u8);
131        let tag_rb = build.inst_ir_cmd_ir_op(IrCmd::LOAD_TAG, reg_rb);
132        let vector_tag = build.const_tag(lua_Type::LUA_TVECTOR as u8);
133        let exit = build.vm_exit(pcpos as u32);
134        build.inst_ir_cmd_ir_op_ir_op_ir_op(IrCmd::CHECK_TAG, tag_rb, vector_tag, exit);
135
136        if rc != -1 {
137            let fallback_exit = if bc_types.b == LuauBytecodeType::LBC_TYPE_NUMBER.0 as u8 {
138                build.vm_exit(pcpos as u32)
139            } else {
140                get_initialized_fallback(build, &mut fallback, pcpos)
141            };
142
143            let rc_reg = build.vm_reg(rc as u8);
144            let tag_load = build.inst_ir_cmd_ir_op(IrCmd::LOAD_TAG, rc_reg);
145            let number_tag = build.const_tag(lua_Type::LUA_TNUMBER as u8);
146            build.inst_ir_cmd_ir_op_ir_op_ir_op(
147                IrCmd::CHECK_TAG,
148                tag_load,
149                number_tag,
150                fallback_exit,
151            );
152        }
153
154        let vb = build.inst_ir_cmd_ir_op(IrCmd::LOAD_TVALUE, opb);
155        let load_d = load_double_or_constant(build, opc);
156        let num_float = build.inst_ir_cmd_ir_op_ir_op(IrCmd::NUM_TO_FLOAT, load_d, IrOp::ir_op());
157        let vc = build.inst_ir_cmd_ir_op_ir_op(IrCmd::FLOAT_TO_VEC, num_float, IrOp::ir_op());
158
159        let result = match tm {
160            TMS::TM_MUL => build.inst_ir_cmd_ir_op_ir_op(IrCmd::MUL_VEC, vb, vc),
161            TMS::TM_DIV => build.inst_ir_cmd_ir_op_ir_op(IrCmd::DIV_VEC, vb, vc),
162            TMS::TM_IDIV => build.inst_ir_cmd_ir_op_ir_op(IrCmd::IDIV_VEC, vb, vc),
163            _ => {
164                CODEGEN_ASSERT!(false);
165                IrOp::ir_op()
166            }
167        };
168
169        let result = build.inst_ir_cmd_ir_op_ir_op(IrCmd::TAG_VECTOR, result, IrOp::ir_op());
170        let reg_ra = build.vm_reg(ra as u8);
171        build.inst_ir_cmd_ir_op_ir_op(IrCmd::STORE_TVALUE, reg_ra, result);
172
173        translate_binary_numeric_fallback_if_required(build, fallback, ra, opb, opc, tm, pcpos);
174        return;
175    }
176
177    if is_userdata_bytecode_type(bc_types.a) || is_userdata_bytecode_type(bc_types.b) {
178        let savedpc = build.const_uint((pcpos + 1) as u32);
179        build.inst_ir_cmd_ir_op(IrCmd::SET_SAVEDPC, savedpc);
180        let reg_ra = build.vm_reg(ra as u8);
181        let tm_op = build.const_int(tm as i32);
182        build.inst_ir_cmd_ir_op_ir_op_ir_op_ir_op(IrCmd::DO_ARITH, reg_ra, opb, opc, tm_op);
183        return;
184    }
185
186    // fast-path: number
187    if rb != -1 {
188        let reg_rb = build.vm_reg(rb as u8);
189        let tb = build.inst_ir_cmd_ir_op(IrCmd::LOAD_TAG, reg_rb);
190        let exit_or_fallback = if bc_types.a == LuauBytecodeType::LBC_TYPE_NUMBER.0 as u8 {
191            build.vm_exit(pcpos as u32)
192        } else {
193            get_initialized_fallback(build, &mut fallback, pcpos)
194        };
195        let number_tag = build.const_tag(lua_Type::LUA_TNUMBER as u8);
196        build.inst_ir_cmd_ir_op_ir_op_ir_op(IrCmd::CHECK_TAG, tb, number_tag, exit_or_fallback);
197    }
198
199    if rc != -1 && rc != rb {
200        let reg_rc = build.vm_reg(rc as u8);
201        let tc = build.inst_ir_cmd_ir_op(IrCmd::LOAD_TAG, reg_rc);
202        let exit_or_fallback = if bc_types.b == LuauBytecodeType::LBC_TYPE_NUMBER.0 as u8 {
203            build.vm_exit(pcpos as u32)
204        } else {
205            get_initialized_fallback(build, &mut fallback, pcpos)
206        };
207        let number_tag = build.const_tag(lua_Type::LUA_TNUMBER as u8);
208        build.inst_ir_cmd_ir_op_ir_op_ir_op(IrCmd::CHECK_TAG, tc, number_tag, exit_or_fallback);
209    }
210
211    let vb = load_double_or_constant(build, opb);
212    let mut vc = IrOp::ir_op();
213    let mut result = IrOp::ir_op();
214
215    if opc.kind() == IrOpKind::VmConst {
216        let protok_index = crate::functions::vm_const_op::vm_const_op(opc);
217        CODEGEN_ASSERT!(build.function.proto.is_null() == false);
218        let protok = unsafe { *(*build.function.proto).k.add(protok_index as usize) };
219        CODEGEN_ASSERT!(protok.tt == lua_Type::LUA_TNUMBER as i32);
220
221        let n = unsafe { protok.value.n };
222        if tm == TMS::TM_POW && n == 0.5 {
223            result = build.inst_ir_cmd_ir_op(IrCmd::SQRT_NUM, vb);
224        } else if tm == TMS::TM_POW && n == 2.0 {
225            result = build.inst_ir_cmd_ir_op_ir_op(IrCmd::MUL_NUM, vb, vb);
226        } else if tm == TMS::TM_POW && n == 3.0 {
227            let vv = build.inst_ir_cmd_ir_op_ir_op(IrCmd::MUL_NUM, vb, vb);
228            result = build.inst_ir_cmd_ir_op_ir_op(IrCmd::MUL_NUM, vb, vv);
229        } else {
230            vc = build.const_double(n);
231        }
232    } else {
233        vc = build.inst_ir_cmd_ir_op(IrCmd::LOAD_DOUBLE, opc);
234    }
235
236    // If result is None, we need to emit the generic numeric op
237    if result.kind() == IrOpKind::None {
238        CODEGEN_ASSERT!(vc.kind() != IrOpKind::None);
239        result = match tm {
240            TMS::TM_ADD => build.inst_ir_cmd_ir_op_ir_op(IrCmd::ADD_NUM, vb, vc),
241            TMS::TM_SUB => build.inst_ir_cmd_ir_op_ir_op(IrCmd::SUB_NUM, vb, vc),
242            TMS::TM_MUL => build.inst_ir_cmd_ir_op_ir_op(IrCmd::MUL_NUM, vb, vc),
243            TMS::TM_DIV => build.inst_ir_cmd_ir_op_ir_op(IrCmd::DIV_NUM, vb, vc),
244            TMS::TM_IDIV => build.inst_ir_cmd_ir_op_ir_op(IrCmd::IDIV_NUM, vb, vc),
245            TMS::TM_MOD => build.inst_ir_cmd_ir_op_ir_op(IrCmd::MOD_NUM, vb, vc),
246            TMS::TM_POW => {
247                let pow = build.const_uint(LuauBuiltinFunction::LBF_MATH_POW as u32);
248                build.inst_ir_cmd_ir_op_ir_op_ir_op(IrCmd::INVOKE_LIBM, pow, vb, vc)
249            }
250            _ => {
251                CODEGEN_ASSERT!(false);
252                IrOp::ir_op()
253            }
254        };
255    }
256
257    let reg_ra = build.vm_reg(ra as u8);
258    build.inst_ir_cmd_ir_op_ir_op(IrCmd::STORE_DOUBLE, reg_ra, result);
259
260    if ra != rb && ra != rc {
261        let reg_ra = build.vm_reg(ra as u8);
262        let number_tag = build.const_tag(lua_Type::LUA_TNUMBER as u8);
263        build.inst_ir_cmd_ir_op_ir_op(IrCmd::STORE_TAG, reg_ra, number_tag);
264    }
265
266    translate_binary_numeric_fallback_if_required(build, fallback, ra, opb, opc, tm, pcpos);
267}