Skip to main content

luaur_compiler/records/
shape_visitor.rs

//! Source: `Compiler/src/TableShape.cpp:27-149`
2
use crate::records::hasher::Hasher;
use crate::records::table_shape::TableShape;
use luaur_ast::records::ast_expr::AstExpr;
use luaur_ast::records::ast_expr_constant_number::AstExprConstantNumber;
use luaur_ast::records::ast_expr_index_expr::AstExprIndexExpr;
use luaur_ast::records::ast_expr_index_name::AstExprIndexName;
use luaur_ast::records::ast_expr_local::AstExprLocal;
use luaur_ast::records::ast_expr_table::AstExprTable;
use luaur_ast::records::ast_local::AstLocal;
use luaur_ast::records::ast_name::AstName;
use luaur_ast::records::ast_node::AstNode;
use luaur_ast::records::ast_stat_assign::AstStatAssign;
use luaur_ast::records::ast_stat_for::AstStatFor;
use luaur_ast::records::ast_stat_function::AstStatFunction;
use luaur_ast::records::ast_stat_local::AstStatLocal;
use luaur_ast::records::ast_visitor::AstVisitor;
use luaur_common::records::dense_hash_map::DenseHashMap;
use luaur_common::records::dense_hash_set::DenseHashSet;
20
21#[allow(non_camel_case_types)]
22#[derive(Debug)]
23pub struct ShapeVisitor<'a> {
24    pub(crate) shapes: &'a mut DenseHashMap<*mut AstExprTable, TableShape>,
25    pub(crate) tables: DenseHashMap<*mut AstLocal, *mut AstExprTable>,
26    pub(crate) fields: DenseHashSet<(*mut AstExprTable, AstName), Hasher>,
27    pub(crate) loops: DenseHashMap<*mut AstLocal, core::ffi::c_uint>,
28}
29
30impl<'a> ShapeVisitor<'a> {
31    pub fn new(shapes: &'a mut DenseHashMap<*mut AstExprTable, TableShape>) -> Self {
32        ShapeVisitor {
33            shapes,
34            tables: DenseHashMap::new(core::ptr::null_mut()),
35            fields: DenseHashSet::new((core::ptr::null_mut(), AstName::new())),
36            loops: DenseHashMap::new(core::ptr::null_mut()),
37        }
38    }
39
40    fn assign_field_name(&mut self, expr: *mut AstExpr, index: AstName) {
41        let lv = unsafe {
42            luaur_ast::rtti::ast_node_as::<AstExprLocal>(
43                expr as *mut luaur_ast::records::ast_node::AstNode,
44            )
45        };
46        if lv.is_null() {
47            return;
48        }
49
50        let table_opt = self.tables.find(&unsafe { (*lv).local });
51        if let Some(&table) = table_opt {
52            let field = (table, index);
53
54            if !self.fields.contains(&field) {
55                self.fields.insert(field);
56                // C++ `shapes[*table].hashSize += 1` — operator[] inserts a default
57                // shape on miss. `find_mut` does NOT insert, so the FIRST field of
58                // every table found no shape and never counted -> predictions 0.
59                self.shapes.get_or_insert(table).hash_size += 1;
60            }
61        }
62    }
63
64    fn assign_field_expr(&mut self, expr: *mut AstExpr, index: *mut AstExpr) {
65        let lv = unsafe {
66            luaur_ast::rtti::ast_node_as::<AstExprLocal>(
67                expr as *mut luaur_ast::records::ast_node::AstNode,
68            )
69        };
70        if lv.is_null() {
71            return;
72        }
73
74        let table_opt = self.tables.find(&unsafe { (*lv).local });
75        let table = match table_opt {
76            Some(t) => *t,
77            None => return,
78        };
79
80        let number = unsafe {
81            luaur_ast::rtti::ast_node_as::<AstExprConstantNumber>(
82                index as *mut luaur_ast::records::ast_node::AstNode,
83            )
84        };
85        if !number.is_null() {
86            // C++ `shapes[*table]` inserts-on-miss; `find_mut` did not, so array
87            // predictions never started.
88            let shape = self.shapes.get_or_insert(table);
89            if unsafe { (*number).value } == (shape.array_size as f64 + 1.0) {
90                shape.array_size += 1;
91            }
92        } else {
93            let iter = unsafe {
94                luaur_ast::rtti::ast_node_as::<AstExprLocal>(
95                    index as *mut luaur_ast::records::ast_node::AstNode,
96                )
97            };
98            if !iter.is_null() {
99                if let Some(&bound) = self.loops.find(&unsafe { (*iter).local }) {
100                    let shape = self.shapes.get_or_insert(table);
101                    if shape.array_size == 0 {
102                        shape.array_size = bound;
103                    }
104                }
105            }
106        }
107    }
108
109    fn assign(&mut self, var: *mut AstExpr) {
110        let index_name = unsafe {
111            luaur_ast::rtti::ast_node_as::<AstExprIndexName>(
112                var as *mut luaur_ast::records::ast_node::AstNode,
113            )
114        };
115        if !index_name.is_null() {
116            self.assign_field_name(unsafe { (*index_name).expr }, unsafe {
117                (*index_name).index
118            });
119            return;
120        }
121
122        let index_expr = unsafe {
123            luaur_ast::rtti::ast_node_as::<AstExprIndexExpr>(
124                var as *mut luaur_ast::records::ast_node::AstNode,
125            )
126        };
127        if !index_expr.is_null() {
128            self.assign_field_expr(unsafe { (*index_expr).expr }, unsafe {
129                (*index_expr).index
130            });
131        }
132    }
133}
134
135impl<'a> AstVisitor for ShapeVisitor<'a> {
136    fn visit_stat_local(&mut self, node: *mut core::ffi::c_void) -> bool {
137        let node = unsafe { &mut *(node as *mut AstStatLocal) };
138
139        if node.vars.size == 1 && node.values.size == 1 {
140            let value = unsafe { *node.values.data.add(0) };
141            // C++ uses getTableHint, which unwraps `setmetatable(table_literal, ...)` to the
142            // inner table literal. Casting the initializer straight to AstExprTable missed
143            // that form, so a table behind setmetatable was never tracked and its predicted
144            // shape stayed (0,0) -> NEWTABLE with size 0.
145            let table = crate::functions::get_table_hint::get_table_hint(value);
146            if !table.is_null() && unsafe { (*table).items.size } == 0 {
147                let var = unsafe { *node.vars.data.add(0) };
148                self.tables.try_insert(var, table);
149            }
150        }
151
152        true
153    }
154
155    fn visit_stat_assign(&mut self, node: *mut core::ffi::c_void) -> bool {
156        let node = unsafe { &mut *(node as *mut AstStatAssign) };
157
158        for i in 0..node.vars.size as usize {
159            let var = unsafe { *node.vars.data.add(i) };
160            self.assign(var);
161        }
162
163        for i in 0..node.values.size as usize {
164            let value = unsafe { *node.values.data.add(i) };
165            unsafe { luaur_ast::visit::ast_expr_visit(value, self as &mut dyn AstVisitor) };
166        }
167
168        false
169    }
170
171    fn visit_stat_function(&mut self, node: *mut core::ffi::c_void) -> bool {
172        let node = unsafe { &mut *(node as *mut AstStatFunction) };
173
174        self.assign(node.name);
175
176        unsafe {
177            luaur_ast::visit::ast_expr_visit(
178                node.func as *mut luaur_ast::records::ast_expr::AstExpr,
179                self as &mut dyn AstVisitor,
180            )
181        };
182
183        false
184    }
185
186    fn visit_stat_for(&mut self, node: *mut core::ffi::c_void) -> bool {
187        let node = unsafe { &mut *(node as *mut AstStatFor) };
188
189        let from = unsafe {
190            luaur_ast::rtti::ast_node_as::<AstExprConstantNumber>(
191                node.from as *mut luaur_ast::records::ast_node::AstNode,
192            )
193        };
194        let to = unsafe {
195            luaur_ast::rtti::ast_node_as::<AstExprConstantNumber>(
196                node.to as *mut luaur_ast::records::ast_node::AstNode,
197            )
198        };
199
200        if !from.is_null() && !to.is_null() {
201            let from_val = unsafe { (*from).value };
202            let to_val = unsafe { (*to).value };
203
204            if from_val == 1.0 && to_val >= 1.0 && to_val <= 16.0 && node.step.is_null() {
205                self.loops.try_insert(node.var, to_val as core::ffi::c_uint);
206            }
207        }
208
209        true
210    }
211}