Skip to main content

luaur_analysis/methods/
constraint_generator_check_constraint_generator_dispatcher.rs

//! Hand-ported dispatcher: ConstraintGenerator::check(const ScopePtr&, AstExpr*, ...)
//! Source: `Analysis/src/ConstraintGenerator.cpp:3042`.
//!
//! The top-level expression `check` overload. C++ overloads `check` on the
//! static type of the expression; here we recover that with RTTI dispatch to
//! the per-node `check_scope_ptr_ast_expr_*` methods. Three convenience entry
//! points are provided to mirror the defaulted C++ arguments
//! (`expectedType = {}`, `forceSingleton = false`, `generalize = true`).
9use crate::records::constraint_generator::ConstraintGenerator;
10use crate::records::inference::Inference;
11use crate::records::module::Module;
12use crate::type_aliases::scope_ptr_constraint_generator::ScopePtr;
13use crate::type_aliases::type_id::TypeId;
14use luaur_ast::records::ast_expr::AstExpr;
15use luaur_ast::records::ast_expr_binary::AstExprBinary;
16use luaur_ast::records::ast_expr_call::AstExprCall;
17use luaur_ast::records::ast_expr_constant_bool::AstExprConstantBool;
18use luaur_ast::records::ast_expr_constant_integer::AstExprConstantInteger;
19use luaur_ast::records::ast_expr_constant_nil::AstExprConstantNil;
20use luaur_ast::records::ast_expr_constant_number::AstExprConstantNumber;
21use luaur_ast::records::ast_expr_constant_string::AstExprConstantString;
22use luaur_ast::records::ast_expr_error::AstExprError;
23use luaur_ast::records::ast_expr_function::AstExprFunction;
24use luaur_ast::records::ast_expr_global::AstExprGlobal;
25use luaur_ast::records::ast_expr_group::AstExprGroup;
26use luaur_ast::records::ast_expr_if_else::AstExprIfElse;
27use luaur_ast::records::ast_expr_index_expr::AstExprIndexExpr;
28use luaur_ast::records::ast_expr_index_name::AstExprIndexName;
29use luaur_ast::records::ast_expr_instantiate::AstExprInstantiate;
30use luaur_ast::records::ast_expr_interp_string::AstExprInterpString;
31use luaur_ast::records::ast_expr_table::AstExprTable;
32use luaur_ast::records::ast_expr_type_assertion::AstExprTypeAssertion;
33use luaur_ast::records::ast_expr_unary::AstExprUnary;
34use luaur_ast::records::ast_expr_varargs::AstExprVarargs;
35use luaur_ast::records::ast_node::AstNode;
36use luaur_ast::rtti::ast_node_as;
37use luaur_common::DFInt;
38use luaur_common::LUAU_ASSERT;
39
40impl ConstraintGenerator {
41    /// Convenience: `check(scope, expr)` with all C++ defaults.
42    pub fn check_scope_ptr_ast_expr(&mut self, scope: &ScopePtr, expr: *mut AstExpr) -> Inference {
43        self.check_scope_ptr_ast_expr_optional_type_id_bool_bool(scope, expr, None, false, true)
44    }
45
46    /// Convenience: `check(scope, expr, expectedType)`.
47    pub fn check_scope_ptr_ast_expr_optional_type_id(
48        &mut self,
49        scope: &ScopePtr,
50        expr: *mut AstExpr,
51        expected_type: Option<TypeId>,
52    ) -> Inference {
53        self.check_scope_ptr_ast_expr_optional_type_id_bool_bool(
54            scope,
55            expr,
56            expected_type,
57            false,
58            true,
59        )
60    }
61
62    /// Convenience: `check(scope, expr, expectedType, forceSingleton)`.
63    pub fn check_scope_ptr_ast_expr_optional_type_id_bool(
64        &mut self,
65        scope: &ScopePtr,
66        expr: *mut AstExpr,
67        expected_type: Option<TypeId>,
68        force_singleton: bool,
69    ) -> Inference {
70        self.check_scope_ptr_ast_expr_optional_type_id_bool_bool(
71            scope,
72            expr,
73            expected_type,
74            force_singleton,
75            true,
76        )
77    }
78
79    pub fn check_scope_ptr_ast_expr_optional_type_id_bool_bool(
80        &mut self,
81        scope: &ScopePtr,
82        expr: *mut AstExpr,
83        expected_type: Option<TypeId>,
84        force_singleton: bool,
85        generalize: bool,
86    ) -> Inference {
87        // RecursionCounter counter{&recursionCount};
88        self.recursion_count += 1;
89        let result =
90            self.check_dispatch_impl(scope, expr, expected_type, force_singleton, generalize);
91        self.recursion_count -= 1;
92        result
93    }
94
95    fn check_dispatch_impl(
96        &mut self,
97        scope: &ScopePtr,
98        expr: *mut AstExpr,
99        expected_type: Option<TypeId>,
100        force_singleton: bool,
101        generalize: bool,
102    ) -> Inference {
103        unsafe {
104            if self.recursion_count >= DFInt::LuauConstraintGeneratorRecursionLimit.get() as i32 {
105                self.report_code_too_complex((*expr).base.location);
106                return Inference::inference_type_id_refinement_id(
107                    (*self.builtin_types).errorType,
108                    core::ptr::null_mut(),
109                );
110            }
111
112            // We may recurse a given expression more than once when checking
113            // compound assignment, so we store and cache expressions here.
114            if self.inferred_expr_cache.contains(&expr) {
115                return self.inferred_expr_cache.get_or_insert(expr).clone();
116            }
117
118            let node = expr as *mut AstNode;
119
120            let result: Inference = {
121                let group = ast_node_as::<AstExprGroup>(node);
122                if !group.is_null() {
123                    self.check_scope_ptr_ast_expr_optional_type_id_bool_bool(
124                        scope,
125                        (*group).expr,
126                        expected_type,
127                        force_singleton,
128                        generalize,
129                    )
130                } else if !ast_node_as::<AstExprConstantString>(node).is_null() {
131                    let string_expr = ast_node_as::<AstExprConstantString>(node);
132                    self.check_scope_ptr_ast_expr_constant_string_optional_type_id_bool(
133                        scope,
134                        string_expr,
135                        expected_type,
136                        force_singleton,
137                    )
138                } else if !ast_node_as::<AstExprConstantNumber>(node).is_null() {
139                    Inference::inference_type_id_refinement_id(
140                        (*self.builtin_types).numberType,
141                        core::ptr::null_mut(),
142                    )
143                } else if !ast_node_as::<AstExprConstantInteger>(node).is_null() {
144                    Inference::inference_type_id_refinement_id(
145                        (*self.builtin_types).integerType,
146                        core::ptr::null_mut(),
147                    )
148                } else if !ast_node_as::<AstExprConstantBool>(node).is_null() {
149                    let bool_expr = ast_node_as::<AstExprConstantBool>(node);
150                    self.check_scope_ptr_ast_expr_constant_bool_optional_type_id_bool(
151                        scope,
152                        bool_expr,
153                        expected_type,
154                        force_singleton,
155                    )
156                } else if !ast_node_as::<AstExprConstantNil>(node).is_null() {
157                    Inference::inference_type_id_refinement_id(
158                        (*self.builtin_types).nilType,
159                        core::ptr::null_mut(),
160                    )
161                } else if !ast_node_as::<AstExprLocal>(node).is_null() {
162                    let local = ast_node_as::<AstExprLocal>(node);
163                    self.check_scope_ptr_ast_expr_local(scope, local)
164                } else if !ast_node_as::<AstExprGlobal>(node).is_null() {
165                    let global = ast_node_as::<AstExprGlobal>(node);
166                    self.check_scope_ptr_ast_expr_global(scope, global)
167                } else if !ast_node_as::<AstExprVarargs>(node).is_null() {
168                    let pack = self.check_pack_scope_ptr_ast_expr_vector_optional_type_id_bool(
169                        scope,
170                        expr,
171                        &alloc::vec::Vec::new(),
172                        true,
173                    );
174                    self.flatten_pack(scope, (*expr).base.location, pack)
175                } else if !ast_node_as::<AstExprCall>(node).is_null() {
176                    let call = ast_node_as::<AstExprCall>(node);
177                    let pack = self.check_pack_scope_ptr_ast_expr_call(scope, call);
178                    self.flatten_pack(scope, (*expr).base.location, pack)
179                } else if !ast_node_as::<AstExprFunction>(node).is_null() {
180                    let func = ast_node_as::<AstExprFunction>(node);
181                    self.check_scope_ptr_ast_expr_function_optional_type_id_bool(
182                        scope,
183                        func,
184                        expected_type,
185                        generalize,
186                    )
187                } else if !ast_node_as::<AstExprIndexName>(node).is_null() {
188                    let index_name = ast_node_as::<AstExprIndexName>(node);
189                    self.check_scope_ptr_ast_expr_index_name(scope, index_name)
190                } else if !ast_node_as::<AstExprIndexExpr>(node).is_null() {
191                    let index_expr = ast_node_as::<AstExprIndexExpr>(node);
192                    self.check_scope_ptr_ast_expr_index_expr(scope, index_expr)
193                } else if !ast_node_as::<AstExprTable>(node).is_null() {
194                    let table = ast_node_as::<AstExprTable>(node);
195                    self.check_scope_ptr_ast_expr_table_optional_type_id(
196                        scope,
197                        table,
198                        expected_type,
199                    )
200                } else if !ast_node_as::<AstExprUnary>(node).is_null() {
201                    let unary = ast_node_as::<AstExprUnary>(node);
202                    self.check_scope_ptr_ast_expr_unary(scope, unary)
203                } else if !ast_node_as::<AstExprBinary>(node).is_null() {
204                    // C++: check(scope, binary, expectedType) returns the full
205                    // Inference (type + refinement). The
206                    // `check_scope_ptr_ast_expr_binary_optional_type_id` wrapper
207                    // discards the refinement (returns only `.ty`), so dispatch
208                    // to `check_ast_expr_binary` directly to stay faithful.
209                    let binary = ast_node_as::<AstExprBinary>(node);
210                    self.check_ast_expr_binary(
211                        scope,
212                        (*binary).base.base.location,
213                        (*binary).op,
214                        (*binary).left,
215                        (*binary).right,
216                        expected_type,
217                    )
218                } else if !ast_node_as::<AstExprIfElse>(node).is_null() {
219                    let if_else = ast_node_as::<AstExprIfElse>(node);
220                    self.check_scope_ptr_ast_expr_if_else_optional_type_id(
221                        scope,
222                        if_else,
223                        expected_type,
224                    )
225                } else if !ast_node_as::<AstExprTypeAssertion>(node).is_null() {
226                    let type_assert = ast_node_as::<AstExprTypeAssertion>(node);
227                    self.check_scope_ptr_ast_expr_type_assertion(scope, type_assert)
228                } else if !ast_node_as::<AstExprInterpString>(node).is_null() {
229                    let interp_string = ast_node_as::<AstExprInterpString>(node);
230                    self.check_scope_ptr_ast_expr_interp_string(scope, interp_string)
231                } else if !ast_node_as::<AstExprInstantiate>(node).is_null() {
232                    let instantiate = ast_node_as::<AstExprInstantiate>(node);
233                    self.check_scope_ptr_ast_expr_instantiate(scope, instantiate)
234                } else {
235                    let err = ast_node_as::<AstExprError>(node);
236                    if !err.is_null() {
237                        // Open question: Should we traverse into this?
238                        let expressions = (*err).expressions;
239                        for i in 0..expressions.size as usize {
240                            let sub_expr = *expressions.data.add(i);
241                            self.check_scope_ptr_ast_expr(scope, sub_expr);
242                        }
243                        Inference::inference_type_id_refinement_id(
244                            (*self.builtin_types).errorType,
245                            core::ptr::null_mut(),
246                        )
247                    } else {
248                        LUAU_ASSERT!(false);
249                        Inference::inference_type_id_refinement_id(
250                            self.fresh_type(scope, self.polarity),
251                            core::ptr::null_mut(),
252                        )
253                    }
254                }
255            };
256
257            *self.inferred_expr_cache.get_or_insert(expr) = result.clone();
258
259            LUAU_ASSERT!(!result.ty.is_null());
260
261            if let Some(module) = &self.module {
262                let module_ptr = alloc::sync::Arc::as_ptr(module) as *mut Module;
263                *(*module_ptr)
264                    .ast_types
265                    .get_or_insert(expr as *const AstExpr) = result.ty;
266                if let Some(et) = expected_type {
267                    *(*module_ptr)
268                        .ast_expected_types
269                        .get_or_insert(expr as *const AstExpr) = et;
270                }
271            }
272
273            result
274        }
275    }
276}
277
278use luaur_ast::records::ast_expr_local::AstExprLocal;