Skip to main content

luaur_compiler/methods/
compiler_compile_expr_interp_string.rs

1use crate::enums::type_constant_folding::Type;
2use crate::functions::escape_and_append::escapeAndAppend;
3use crate::functions::sref_compiler::sref_ast_name;
4use crate::functions::sref_compiler_alt_c::sref_ast_array_c_char;
5use crate::records::compile_error::CompileError;
6use crate::records::compiler::Compiler;
7use crate::records::constant::Constant;
8use crate::records::reg_scope::RegScope;
9use alloc::vec::Vec;
10use core::ffi::c_char;
11use luaur_ast::records::ast_array::AstArray;
12use luaur_ast::records::ast_expr_interp_string::AstExprInterpString;
13use luaur_ast::records::ast_name::AstName;
14use luaur_bytecode::methods::bytecode_builder_get_string_hash::bytecode_builder_get_string_hash;
15use luaur_common::enums::luau_opcode::LuauOpcode;
16
17impl Compiler {
18    pub fn compile_expr_interp_string(
19        &mut self,
20        expr: *mut AstExprInterpString,
21        target: u8,
22        target_temp: bool,
23    ) {
24        unsafe {
25            let expr_ref = &*expr;
26            let mut format_capacity = 0;
27            for string in expr_ref.strings.iter() {
28                format_capacity +=
29                    (*string).size + (*string).iter().filter(|&&c| c == b'%' as i8).count();
30            }
31
32            let mut skipped_sub_expr = 0;
33            for i in 0..expr_ref.expressions.size {
34                let sub_expr = *expr_ref.expressions.data.add(i);
35                if let Some(c) = self.constants.find(&sub_expr) {
36                    if c.r#type == Type::Type_String {
37                        format_capacity += c.string_length as usize
38                            + c.get_string().iter().filter(|&&c| c == b'%' as i8).count();
39                        skipped_sub_expr += 1;
40                    } else {
41                        format_capacity += 2;
42                    }
43                } else {
44                    format_capacity += 2;
45                }
46            }
47
48            let mut format_string = Vec::with_capacity(format_capacity);
49            for i in 0..expr_ref.strings.size {
50                let string = *expr_ref.strings.data.add(i);
51                escapeAndAppend(&mut format_string, string.data, string.size);
52                if i < expr_ref.expressions.size {
53                    let sub_expr = *expr_ref.expressions.data.add(i);
54                    if let Some(c) = self.constants.find(&sub_expr) {
55                        if c.r#type == Type::Type_String {
56                            escapeAndAppend(
57                                &mut format_string,
58                                c.get_string().data,
59                                c.string_length as usize,
60                            );
61                        } else {
62                            format_string.extend_from_slice(b"%*");
63                        }
64                    } else {
65                        format_string.extend_from_slice(b"%*");
66                    }
67                }
68            }
69
70            let format_string_index = if format_string.is_empty() {
71                let interned = (*self.names).get_or_add(c"".as_ptr(), 0);
72                (*self.bytecode).add_constant_string(sref_ast_name(interned))
73            } else {
74                let interned = (*self.names)
75                    .get_or_add(format_string.as_ptr() as *const c_char, format_string.len());
76                let format_string_array = AstArray {
77                    data: interned.value as *mut c_char,
78                    size: format_string.len(),
79                };
80                (*self.bytecode).add_constant_string(sref_ast_array_c_char(format_string_array))
81            };
82
83            if format_string_index < 0 {
84                CompileError::raise(
85                    &expr_ref.base.base.location,
86                    format_args!("Exceeded constant limit; simplify the code to compile"),
87                );
88            }
89
90            let mut rs = self.reg_scope_compiler();
91            let reg_count = 2 + expr_ref.expressions.size - skipped_sub_expr;
92            let target_top = luaur_common::FFlag::LuauCompileStringInterpTargetTop.get()
93                && target_temp
94                && target as u32 == self.reg_top - 1;
95            let base_reg = if target_top {
96                self.alloc_reg(expr as *mut _, (reg_count - 1) as u32) - 1
97            } else {
98                self.alloc_reg(expr as *mut _, reg_count as u32)
99            };
100
101            self.emit_load_k(base_reg, format_string_index);
102
103            let mut skipped = 0;
104            for i in 0..expr_ref.expressions.size {
105                let sub_expr = *expr_ref.expressions.data.add(i);
106                if self
107                    .constants
108                    .find(&sub_expr)
109                    .map_or(true, |c| c.r#type != Type::Type_String)
110                {
111                    self.compile_expr_temp_top(sub_expr, base_reg + 2 + i as u8 - skipped as u8);
112                } else {
113                    skipped += 1;
114                }
115            }
116
117            let format_method = sref_ast_name(AstName::ast_name_c_char(c"format".as_ptr()));
118            let format_method_index = (*self.bytecode).add_constant_string(format_method);
119            if format_method_index < 0 {
120                CompileError::raise(
121                    &expr_ref.base.base.location,
122                    format_args!("Exceeded constant limit; simplify the code to compile"),
123                );
124            }
125
126            (*self.bytecode).emit_abc(
127                LuauOpcode::LOP_NAMECALL,
128                base_reg,
129                base_reg,
130                bytecode_builder_get_string_hash(format_method) as u8,
131            );
132            (*self.bytecode).emit_aux(format_method_index as u32);
133            (*self.bytecode).emit_abc(
134                LuauOpcode::LOP_CALL,
135                base_reg,
136                (expr_ref.expressions.size + 2 - skipped_sub_expr) as u8,
137                2,
138            );
139            if target != base_reg {
140                (*self.bytecode).emit_abc(LuauOpcode::LOP_MOVE, target, base_reg, 0);
141            }
142        }
143    }
144}