1use crate::records::constraint_generator::ConstraintGenerator;
10use crate::records::inference::Inference;
11use crate::records::module::Module;
12use crate::type_aliases::scope_ptr_constraint_generator::ScopePtr;
13use crate::type_aliases::type_id::TypeId;
14use luaur_ast::records::ast_expr::AstExpr;
15use luaur_ast::records::ast_expr_binary::AstExprBinary;
16use luaur_ast::records::ast_expr_call::AstExprCall;
17use luaur_ast::records::ast_expr_constant_bool::AstExprConstantBool;
18use luaur_ast::records::ast_expr_constant_integer::AstExprConstantInteger;
19use luaur_ast::records::ast_expr_constant_nil::AstExprConstantNil;
20use luaur_ast::records::ast_expr_constant_number::AstExprConstantNumber;
21use luaur_ast::records::ast_expr_constant_string::AstExprConstantString;
22use luaur_ast::records::ast_expr_error::AstExprError;
23use luaur_ast::records::ast_expr_function::AstExprFunction;
24use luaur_ast::records::ast_expr_global::AstExprGlobal;
25use luaur_ast::records::ast_expr_group::AstExprGroup;
26use luaur_ast::records::ast_expr_if_else::AstExprIfElse;
27use luaur_ast::records::ast_expr_index_expr::AstExprIndexExpr;
28use luaur_ast::records::ast_expr_index_name::AstExprIndexName;
29use luaur_ast::records::ast_expr_instantiate::AstExprInstantiate;
30use luaur_ast::records::ast_expr_interp_string::AstExprInterpString;
31use luaur_ast::records::ast_expr_table::AstExprTable;
32use luaur_ast::records::ast_expr_type_assertion::AstExprTypeAssertion;
33use luaur_ast::records::ast_expr_unary::AstExprUnary;
34use luaur_ast::records::ast_expr_varargs::AstExprVarargs;
35use luaur_ast::records::ast_node::AstNode;
36use luaur_ast::rtti::ast_node_as;
37use luaur_common::DFInt;
38use luaur_common::LUAU_ASSERT;
39
40impl ConstraintGenerator {
41 pub fn check_scope_ptr_ast_expr(&mut self, scope: &ScopePtr, expr: *mut AstExpr) -> Inference {
43 self.check_scope_ptr_ast_expr_optional_type_id_bool_bool(scope, expr, None, false, true)
44 }
45
46 pub fn check_scope_ptr_ast_expr_optional_type_id(
48 &mut self,
49 scope: &ScopePtr,
50 expr: *mut AstExpr,
51 expected_type: Option<TypeId>,
52 ) -> Inference {
53 self.check_scope_ptr_ast_expr_optional_type_id_bool_bool(
54 scope,
55 expr,
56 expected_type,
57 false,
58 true,
59 )
60 }
61
62 pub fn check_scope_ptr_ast_expr_optional_type_id_bool(
64 &mut self,
65 scope: &ScopePtr,
66 expr: *mut AstExpr,
67 expected_type: Option<TypeId>,
68 force_singleton: bool,
69 ) -> Inference {
70 self.check_scope_ptr_ast_expr_optional_type_id_bool_bool(
71 scope,
72 expr,
73 expected_type,
74 force_singleton,
75 true,
76 )
77 }
78
79 pub fn check_scope_ptr_ast_expr_optional_type_id_bool_bool(
80 &mut self,
81 scope: &ScopePtr,
82 expr: *mut AstExpr,
83 expected_type: Option<TypeId>,
84 force_singleton: bool,
85 generalize: bool,
86 ) -> Inference {
87 self.recursion_count += 1;
89 let result =
90 self.check_dispatch_impl(scope, expr, expected_type, force_singleton, generalize);
91 self.recursion_count -= 1;
92 result
93 }
94
95 fn check_dispatch_impl(
96 &mut self,
97 scope: &ScopePtr,
98 expr: *mut AstExpr,
99 expected_type: Option<TypeId>,
100 force_singleton: bool,
101 generalize: bool,
102 ) -> Inference {
103 unsafe {
104 if self.recursion_count >= DFInt::LuauConstraintGeneratorRecursionLimit.get() as i32 {
105 self.report_code_too_complex((*expr).base.location);
106 return Inference::inference_type_id_refinement_id(
107 (*self.builtin_types).errorType,
108 core::ptr::null_mut(),
109 );
110 }
111
112 if self.inferred_expr_cache.contains(&expr) {
115 return self.inferred_expr_cache.get_or_insert(expr).clone();
116 }
117
118 let node = expr as *mut AstNode;
119
120 let result: Inference = {
121 let group = ast_node_as::<AstExprGroup>(node);
122 if !group.is_null() {
123 self.check_scope_ptr_ast_expr_optional_type_id_bool_bool(
124 scope,
125 (*group).expr,
126 expected_type,
127 force_singleton,
128 generalize,
129 )
130 } else if !ast_node_as::<AstExprConstantString>(node).is_null() {
131 let string_expr = ast_node_as::<AstExprConstantString>(node);
132 self.check_scope_ptr_ast_expr_constant_string_optional_type_id_bool(
133 scope,
134 string_expr,
135 expected_type,
136 force_singleton,
137 )
138 } else if !ast_node_as::<AstExprConstantNumber>(node).is_null() {
139 Inference::inference_type_id_refinement_id(
140 (*self.builtin_types).numberType,
141 core::ptr::null_mut(),
142 )
143 } else if !ast_node_as::<AstExprConstantInteger>(node).is_null() {
144 Inference::inference_type_id_refinement_id(
145 (*self.builtin_types).integerType,
146 core::ptr::null_mut(),
147 )
148 } else if !ast_node_as::<AstExprConstantBool>(node).is_null() {
149 let bool_expr = ast_node_as::<AstExprConstantBool>(node);
150 self.check_scope_ptr_ast_expr_constant_bool_optional_type_id_bool(
151 scope,
152 bool_expr,
153 expected_type,
154 force_singleton,
155 )
156 } else if !ast_node_as::<AstExprConstantNil>(node).is_null() {
157 Inference::inference_type_id_refinement_id(
158 (*self.builtin_types).nilType,
159 core::ptr::null_mut(),
160 )
161 } else if !ast_node_as::<AstExprLocal>(node).is_null() {
162 let local = ast_node_as::<AstExprLocal>(node);
163 self.check_scope_ptr_ast_expr_local(scope, local)
164 } else if !ast_node_as::<AstExprGlobal>(node).is_null() {
165 let global = ast_node_as::<AstExprGlobal>(node);
166 self.check_scope_ptr_ast_expr_global(scope, global)
167 } else if !ast_node_as::<AstExprVarargs>(node).is_null() {
168 let pack = self.check_pack_scope_ptr_ast_expr_vector_optional_type_id_bool(
169 scope,
170 expr,
171 &alloc::vec::Vec::new(),
172 true,
173 );
174 self.flatten_pack(scope, (*expr).base.location, pack)
175 } else if !ast_node_as::<AstExprCall>(node).is_null() {
176 let call = ast_node_as::<AstExprCall>(node);
177 let pack = self.check_pack_scope_ptr_ast_expr_call(scope, call);
178 self.flatten_pack(scope, (*expr).base.location, pack)
179 } else if !ast_node_as::<AstExprFunction>(node).is_null() {
180 let func = ast_node_as::<AstExprFunction>(node);
181 self.check_scope_ptr_ast_expr_function_optional_type_id_bool(
182 scope,
183 func,
184 expected_type,
185 generalize,
186 )
187 } else if !ast_node_as::<AstExprIndexName>(node).is_null() {
188 let index_name = ast_node_as::<AstExprIndexName>(node);
189 self.check_scope_ptr_ast_expr_index_name(scope, index_name)
190 } else if !ast_node_as::<AstExprIndexExpr>(node).is_null() {
191 let index_expr = ast_node_as::<AstExprIndexExpr>(node);
192 self.check_scope_ptr_ast_expr_index_expr(scope, index_expr)
193 } else if !ast_node_as::<AstExprTable>(node).is_null() {
194 let table = ast_node_as::<AstExprTable>(node);
195 self.check_scope_ptr_ast_expr_table_optional_type_id(
196 scope,
197 table,
198 expected_type,
199 )
200 } else if !ast_node_as::<AstExprUnary>(node).is_null() {
201 let unary = ast_node_as::<AstExprUnary>(node);
202 self.check_scope_ptr_ast_expr_unary(scope, unary)
203 } else if !ast_node_as::<AstExprBinary>(node).is_null() {
204 let binary = ast_node_as::<AstExprBinary>(node);
210 self.check_ast_expr_binary(
211 scope,
212 (*binary).base.base.location,
213 (*binary).op,
214 (*binary).left,
215 (*binary).right,
216 expected_type,
217 )
218 } else if !ast_node_as::<AstExprIfElse>(node).is_null() {
219 let if_else = ast_node_as::<AstExprIfElse>(node);
220 self.check_scope_ptr_ast_expr_if_else_optional_type_id(
221 scope,
222 if_else,
223 expected_type,
224 )
225 } else if !ast_node_as::<AstExprTypeAssertion>(node).is_null() {
226 let type_assert = ast_node_as::<AstExprTypeAssertion>(node);
227 self.check_scope_ptr_ast_expr_type_assertion(scope, type_assert)
228 } else if !ast_node_as::<AstExprInterpString>(node).is_null() {
229 let interp_string = ast_node_as::<AstExprInterpString>(node);
230 self.check_scope_ptr_ast_expr_interp_string(scope, interp_string)
231 } else if !ast_node_as::<AstExprInstantiate>(node).is_null() {
232 let instantiate = ast_node_as::<AstExprInstantiate>(node);
233 self.check_scope_ptr_ast_expr_instantiate(scope, instantiate)
234 } else {
235 let err = ast_node_as::<AstExprError>(node);
236 if !err.is_null() {
237 let expressions = (*err).expressions;
239 for i in 0..expressions.size as usize {
240 let sub_expr = *expressions.data.add(i);
241 self.check_scope_ptr_ast_expr(scope, sub_expr);
242 }
243 Inference::inference_type_id_refinement_id(
244 (*self.builtin_types).errorType,
245 core::ptr::null_mut(),
246 )
247 } else {
248 LUAU_ASSERT!(false);
249 Inference::inference_type_id_refinement_id(
250 self.fresh_type(scope, self.polarity),
251 core::ptr::null_mut(),
252 )
253 }
254 }
255 };
256
257 *self.inferred_expr_cache.get_or_insert(expr) = result.clone();
258
259 LUAU_ASSERT!(!result.ty.is_null());
260
261 if let Some(module) = &self.module {
262 let module_ptr = alloc::sync::Arc::as_ptr(module) as *mut Module;
263 *(*module_ptr)
264 .ast_types
265 .get_or_insert(expr as *const AstExpr) = result.ty;
266 if let Some(et) = expected_type {
267 *(*module_ptr)
268 .ast_expected_types
269 .get_or_insert(expr as *const AstExpr) = et;
270 }
271 }
272
273 result
274 }
275 }
276}
277
278use luaur_ast::records::ast_expr_local::AstExprLocal;