// luaur_analysis/methods/constraint_generator_check_expr_call.rs

use alloc::sync::Arc;
use alloc::vec::Vec;

use luaur_ast::functions::is_l_value::is_l_value;
use luaur_ast::records::ast_expr::AstExpr;
use luaur_ast::records::ast_expr_call::AstExprCall;
use luaur_ast::records::ast_expr_index_name::AstExprIndexName;
use luaur_ast::records::ast_expr_local::AstExprLocal;
use luaur_ast::records::ast_expr_varargs::AstExprVarargs;
use luaur_ast::records::ast_node::AstNode;
use luaur_ast::records::location::Location;
use luaur_ast::rtti::ast_node_as;
use luaur_common::FFlag;
use luaur_common::LUAU_ASSERT;

use crate::enums::polarity::Polarity;
use crate::enums::type_context::TypeContext;
use crate::enums::value::Value;
use crate::functions::add_all_as_dependencies::add_all_as_dependencies;
use crate::functions::checkpoint::checkpoint;
use crate::functions::extend_type_pack::extend_type_pack;
use crate::functions::follow_type::follow;
use crate::functions::for_each_constraint::for_each_constraint;
use crate::functions::get_mutable_type::get_mutable_type_id;
use crate::functions::get_mutable_type_pack::get_mutable_type_pack_id;
use crate::functions::get_type_alt_j::get_type_id;
use crate::functions::is_table_union::is_table_union;
use crate::functions::match_assert::match_assert;
use crate::functions::match_is_instance_guard::match_is_instance_guard;
use crate::functions::match_set_metatable::match_set_metatable;
use crate::functions::should_suppress_errors_type_utils::should_suppress_errors;
use crate::functions::should_typestate_for_first_argument::should_typestate_for_first_argument;
use crate::records::binding::Binding;
use crate::records::blocked_type::BlockedType;
use crate::records::blocked_type_pack::BlockedTypePack;
use crate::records::checkpoint::Checkpoint;
use crate::records::constraint::Constraint;
use crate::records::constraint_generator::ConstraintGenerator;
use crate::records::function_call_constraint::FunctionCallConstraint;
use crate::records::function_check_constraint::FunctionCheckConstraint;
use crate::records::function_type::FunctionType;
use crate::records::in_conditional_context::InConditionalContext;
use crate::records::inference_pack::InferencePack;
use crate::records::metatable_type::MetatableType;
use crate::records::module::Module;
use crate::records::scope::Scope;
use crate::records::symbol::Symbol;
use crate::records::type_pack::TypePack;
use crate::records::union_builder::UnionBuilder;
use crate::records::union_type::UnionType;
use crate::records::unpack_constraint::UnpackConstraint;
use crate::type_aliases::constraint_v::ConstraintV;
use crate::type_aliases::refinement_id_refinement::RefinementId;
use crate::type_aliases::scope_ptr_constraint_generator::ScopePtr;
use crate::type_aliases::type_id::TypeId;
use crate::type_aliases::type_pack_id::TypePackId;
52impl ConstraintGenerator {
53    pub fn check_expr_call(
54        &mut self,
55        scope: &ScopePtr,
56        call: *mut AstExprCall,
57        fn_type: TypeId,
58        func_begin: Checkpoint,
59        func_end: Checkpoint,
60    ) -> InferencePack {
61        unsafe {
62            let scope_raw: *mut Scope = scope.as_ref() as *const Scope as *mut Scope;
63
64            let mut expr_args: Vec<*mut AstExpr> = Vec::new();
65
66            let mut return_refinements: Vec<RefinementId> = Vec::new();
67            let mut discriminant_types: Vec<Option<TypeId>> = Vec::new();
68
69            if (*call).self_ {
70                let index_expr = ast_node_as::<AstExprIndexName>((*call).func as *mut AstNode);
71                if index_expr.is_null() {
72                    (*self.ice).ice_string("method call expression has no 'self'");
73                }
74
75                expr_args.push((*index_expr).expr);
76
77                let key = (*self.dfg).get_refinement_key((*index_expr).expr as *const AstExpr);
78                if !key.is_null() {
79                    let discriminant_ty = (*self.arena).add_type(BlockedType::default());
80                    return_refinements.push(
81                        self.refinement_arena
82                            .implicit_proposition_refinement_key_type_id(key, discriminant_ty),
83                    );
84                    discriminant_types.push(Some(discriminant_ty));
85                } else {
86                    discriminant_types.push(None);
87                }
88            }
89
90            for &arg in (*call).args.iter() {
91                expr_args.push(arg);
92
93                let key = (*self.dfg).get_refinement_key(arg as *const AstExpr);
94                if !key.is_null() {
95                    let discriminant_ty = (*self.arena).add_type(BlockedType::default());
96                    return_refinements.push(
97                        self.refinement_arena
98                            .implicit_proposition_refinement_key_type_id(key, discriminant_ty),
99                    );
100                    discriminant_types.push(Some(discriminant_ty));
101                } else {
102                    discriminant_types.push(None);
103                }
104            }
105
106            let expected_types_for_call: Vec<Option<TypeId>> =
107                self.get_expected_call_types_for_function_overloads(fn_type);
108
109            if let Some(module) = &self.module {
110                let module_ptr = alloc::sync::Arc::as_ptr(module) as *mut Module;
111                *(*module_ptr)
112                    .ast_original_call_types
113                    .get_or_insert((*call).func as *const AstNode) = fn_type;
114            }
115
116            let arg_begin_checkpoint = checkpoint(self as *const ConstraintGenerator);
117
118            let mut args: Vec<TypeId> = Vec::new();
119            let mut arg_tail: Option<TypePackId> = None;
120            let mut argument_refinements: Vec<RefinementId> = Vec::new();
121
122            for i in 0..expr_args.len() {
123                let arg = expr_args[i];
124
125                if i == 0 && (*call).self_ {
126                    // The self type has already been computed as a side effect of
127                    // computing fnType.  If computing that did not cause us to exceed a
128                    // recursion limit, we can fetch it from astTypes rather than
129                    // recomputing it.
130                    let self_ty: Option<TypeId> = if let Some(module) = &self.module {
131                        let module_ptr = alloc::sync::Arc::as_ptr(module) as *mut Module;
132                        (*module_ptr)
133                            .ast_types
134                            .find(&(expr_args[0] as *const AstExpr))
135                            .copied()
136                    } else {
137                        None
138                    };
139                    if let Some(ty) = self_ty {
140                        args.push(ty);
141                    } else {
142                        args.push(self.fresh_type(scope, Polarity::Negative));
143                    }
144                } else if i < expr_args.len() - 1
145                    || !((*(arg as *mut AstNode)).is::<AstExprCall>()
146                        || (*(arg as *mut AstNode)).is::<AstExprVarargs>())
147                {
148                    let mut expected_type: Option<TypeId> = None;
149                    if i < expected_types_for_call.len() {
150                        expected_type = expected_types_for_call[i];
151                    }
152                    if i == 0 && match_assert(&*call) {
153                        let _flipper = InConditionalContext::new(
154                            &mut self.type_context,
155                            TypeContext::Condition,
156                        );
157                        let inference = self.check_scope_ptr_ast_expr_optional_type_id_bool_bool(
158                            scope,
159                            arg,
160                            expected_type,
161                            false,
162                            false,
163                        );
164                        args.push(inference.ty);
165                        argument_refinements.push(inference.refinement);
166                    } else {
167                        let inference = self.check_scope_ptr_ast_expr_optional_type_id_bool_bool(
168                            scope,
169                            arg,
170                            expected_type,
171                            false,
172                            false,
173                        );
174                        args.push(inference.ty);
175                        argument_refinements.push(inference.refinement);
176                    }
177                } else {
178                    let mut expected_types: Vec<Option<TypeId>> = Vec::new();
179                    if i < expected_types_for_call.len() {
180                        expected_types.extend_from_slice(&expected_types_for_call[i..]);
181                    }
182                    let pack = self.check_pack_scope_ptr_ast_expr_vector_optional_type_id_bool(
183                        scope,
184                        arg,
185                        &expected_types,
186                        true,
187                    );
188                    arg_tail = Some(pack.tp);
189                    argument_refinements.extend(pack.refinements.iter().copied());
190                }
191            }
192
193            let arg_end_checkpoint = checkpoint(self as *const ConstraintGenerator);
194
195            if FFlag::DebugLuauUserDefinedClasses.get() {
196                let instance_guard = match_is_instance_guard(&*call, &*self.dfg);
197                if !instance_guard.is_null() {
198                    if args.len() >= 2 {
199                        // The class type may not be solved yet (e.g. `A.Point` from a
200                        // required module).
201                        let objectof_inst = self.create_type_function_instance(
202                            &(*self.builtin_types).typeFunctions.objectof_func,
203                            alloc::vec![args[1]],
204                            Vec::new(),
205                            scope,
206                            (*call).base.base.location,
207                        );
208                        return_refinements.push(
209                            self.refinement_arena
210                                .implicit_proposition_refinement_key_type_id(
211                                    instance_guard,
212                                    objectof_inst,
213                                ),
214                        );
215                    }
216                }
217            }
218
219            if match_set_metatable(&*call) {
220                let mut arg_tail_pack = crate::records::type_pack::TypePack {
221                    head: Vec::new(),
222                    tail: None,
223                };
224                if arg_tail.is_some() && args.len() < 2 {
225                    arg_tail_pack = extend_type_pack(
226                        &mut *self.arena,
227                        self.builtin_types,
228                        arg_tail.unwrap(),
229                        2 - args.len(),
230                        Vec::new(),
231                    );
232                }
233
234                let mut target: TypeId;
235                let mut mt: TypeId;
236
237                if args.len() + arg_tail_pack.head.len() == 2 {
238                    target = if args.len() > 0 {
239                        args[0]
240                    } else {
241                        arg_tail_pack.head[0]
242                    };
243                    mt = if args.len() > 1 {
244                        args[1]
245                    } else {
246                        arg_tail_pack.head[if args.len() == 0 { 1 } else { 0 }]
247                    };
248                } else {
249                    let mut unpacked_types: Vec<TypeId> = Vec::new();
250                    if args.len() > 0 {
251                        target = follow(args[0]);
252                    } else {
253                        target = (*self.arena).add_type(BlockedType::default());
254                        unpacked_types.push(target);
255                    }
256
257                    mt = (*self.arena).add_type(BlockedType::default());
258                    unpacked_types.push(mt);
259
260                    let c = self.add_constraint_scope_ptr_location_constraint_v(
261                        scope,
262                        (*call).base.base.location,
263                        ConstraintV::Unpack(UnpackConstraint {
264                            result_pack: unpacked_types,
265                            source_pack: arg_tail.unwrap(),
266                        }),
267                    );
268                    (*get_mutable_type_id::<BlockedType>(mt)).setOwner(c as *const Constraint);
269                    let b = get_mutable_type_id::<BlockedType>(target);
270                    if !b.is_null() && (*b).getOwner().is_null() {
271                        (*b).setOwner(c as *const Constraint);
272                    }
273                }
274
275                LUAU_ASSERT!(!target.is_null());
276                LUAU_ASSERT!(!mt.is_null());
277
278                target = follow(target);
279                if should_suppress_errors(self.normalizer, mt).value == Value::Suppress {
280                    mt = (*self.builtin_types).anyType;
281                }
282
283                let target_expr = *(*call).args.data.add(0);
284
285                let result_ty: TypeId;
286
287                if is_table_union(target) {
288                    let target_union = get_type_id::<UnionType>(target);
289                    let mut ub = UnionBuilder::union_builder(self.arena, self.builtin_types);
290
291                    for &ty in &(*target_union).options {
292                        ub.add((*self.arena).add_type(MetatableType {
293                            table: ty,
294                            metatable: mt,
295                            syntheticName: None,
296                        }));
297                    }
298
299                    result_ty = ub.build();
300                } else {
301                    result_ty = (*self.arena).add_type(MetatableType {
302                        table: target,
303                        metatable: mt,
304                        syntheticName: None,
305                    });
306                }
307
308                let target_local = ast_node_as::<AstExprLocal>(target_expr as *mut AstNode);
309                if !target_local.is_null() {
310                    let symbol = Symbol::from_local((*target_local).local);
311                    // C++ `scope->bindings[targetLocal->local].typeId = resultTy` — the
312                    // operator[] default-constructs a Binding when absent.
313                    if let Some(binding) = (*scope_raw).bindings.get_mut(&symbol) {
314                        binding.type_id = result_ty;
315                    } else {
316                        (*scope_raw).bindings.insert(
317                            symbol,
318                            crate::records::binding::Binding {
319                                type_id: result_ty,
320                                location: luaur_ast::records::location::Location::default(),
321                                deprecated: false,
322                                deprecated_suggestion: alloc::string::String::new(),
323                                documentation_symbol: None,
324                            },
325                        );
326                    }
327
328                    let def = (*self.dfg).get_def(target_expr as *const AstExpr);
329                    *(*scope_raw).lvalue_types.get_or_insert(def) = result_ty; // TODO: typestates: track this as an assignment
330                    self.update_r_value_refinements_scope_def_id_type_id(scope_raw, def, result_ty); // TODO: typestates: track this as an assignment
331
332                    // HACK: If we have a targetLocal, it has already been added to the
333                    // inferredBindings table.  We want to replace it so that we don't
334                    // infer a weird union like tbl | { @metatable something, tbl }
335                    let ib_symbol = Symbol::from_local((*target_local).local);
336                    if let Some(ib) = self.inferred_bindings.find_mut(&ib_symbol) {
337                        ib.types.erase_type_id(target);
338                    }
339
340                    self.record_inferred_binding((*target_local).local, result_ty);
341                }
342
343                return InferencePack {
344                    tp: (*self.arena).add_type_pack_initializer_list_type_id(&[result_ty]),
345                    refinements: alloc::vec![self
346                        .refinement_arena
347                        .variadic_refinement_ids(&return_refinements)],
348                };
349            }
350
351            if should_typestate_for_first_argument(&*call)
352                && (*call).args.size > 0
353                && is_l_value(*(*call).args.data.add(0) as *const AstExpr)
354            {
355                let target_expr = *(*call).args.data.add(0);
356                let result_ty = (*self.arena).add_type(BlockedType::default());
357
358                if let Some(def) = (*self.dfg).get_def_optional(target_expr as *const AstExpr) {
359                    *(*scope_raw).lvalue_types.get_or_insert(def) = result_ty;
360                    self.update_r_value_refinements_scope_def_id_type_id(scope_raw, def, result_ty);
361                }
362            }
363
364            if match_assert(&*call) && !argument_refinements.is_empty() {
365                self.apply_refinements(
366                    scope,
367                    (**(*call).args.data.add(0)).base.location,
368                    argument_refinements[0],
369                );
370            }
371
372            // TODO: How do expectedTypes play into this?  Do they?
373            let rets: TypePackId = (*self.arena).add_type_pack_t(BlockedTypePack {
374                index: 0,
375                owner: core::ptr::null_mut(),
376            });
377            let arg_pack: TypePackId = self.add_type_pack(args, arg_tail);
378            let ftv = FunctionType::function_type_new(arg_pack, rets, None, (*call).self_);
379
380            let (explicit_type_ids, explicit_type_pack_ids): (Vec<TypeId>, Vec<TypePackId>) =
381                if FFlag::LuauExplicitTypeInstantiationSupport.get()
382                    && (*call).type_arguments.size != 0
383                {
384                    self.resolve_type_arguments(scope_raw, (*call).type_arguments)
385                } else {
386                    (Vec::new(), Vec::new())
387                };
388
389            // we don't need ftv after building argPack/rets except to keep the FunctionType
390            // shape; the C++ uses `ftv` only via the constraints below which reference
391            // fnType/argPack/rets directly.
392            let _ = ftv;
393
394            /*
395             * To make bidirectional type checking work, we need to solve these constraints in a particular order:
396             *
397             * 1. Solve the function type
398             * 2. Propagate type information from the function type to the argument typeArguments
399             * 3. Solve the argument typeArguments
400             * 4. Solve the call
401             */
402
403            let check_constraint: *mut Constraint = self
404                .add_constraint_scope_ptr_location_constraint_v(
405                    scope,
406                    (*(*call).func).base.location,
407                    ConstraintV::FunctionCheck(FunctionCheckConstraint {
408                        fn_type,
409                        args_pack: arg_pack,
410                        call_site: call,
411                        ast_types: self
412                            .module
413                            .as_ref()
414                            .map(|m| {
415                                let mp = alloc::sync::Arc::as_ptr(m) as *mut Module;
416                                &(*mp).ast_types as *const _
417                            })
418                            .unwrap_or(core::ptr::null()),
419                        ast_expected_types: self
420                            .module
421                            .as_ref()
422                            .map(|m| {
423                                let mp = alloc::sync::Arc::as_ptr(m) as *mut Module;
424                                &(*mp).ast_expected_types as *const _
425                            })
426                            .unwrap_or(core::ptr::null()),
427                    }),
428                );
429
430            if FFlag::LuauConstraintGraph.get() {
431                add_all_as_dependencies(func_begin, func_end, self, check_constraint);
432            } else {
433                for_each_constraint(func_begin, func_end, self, |constraint| {
434                    (*check_constraint).deprecated_dependencies.push(constraint);
435                });
436            }
437
438            let call_constraint: *mut Constraint = self
439                .add_constraint_scope_ptr_location_constraint_v(
440                    scope,
441                    (*(*call).func).base.location,
442                    ConstraintV::FunctionCall(FunctionCallConstraint {
443                        fn_type,
444                        args_pack: arg_pack,
445                        result: rets,
446                        call_site: call,
447                        discriminant_types,
448                        type_arguments: explicit_type_ids,
449                        type_pack_arguments: explicit_type_pack_ids,
450                        ast_overload_resolved_types: self
451                            .module
452                            .as_ref()
453                            .map(|m| {
454                                let mp = alloc::sync::Arc::as_ptr(m) as *mut Module;
455                                &mut (*mp).ast_overload_resolved_types as *mut _
456                            })
457                            .unwrap_or(core::ptr::null_mut()),
458                    }),
459                );
460
461            (*get_mutable_type_pack_id::<BlockedTypePack>(rets)).owner = call_constraint;
462
463            if FFlag::LuauConstraintGraph.get() {
464                (*self.cgraph).add_dependency_of_constraint_constraint(
465                    &mut *check_constraint,
466                    &mut *call_constraint,
467                );
468                for_each_constraint(
469                    arg_begin_checkpoint,
470                    arg_end_checkpoint,
471                    self,
472                    |constraint| {
473                        (*self.cgraph).add_dependency_of_constraint_constraint(
474                            &mut *check_constraint,
475                            &mut *constraint,
476                        );
477                        (*self.cgraph).add_dependency_of_constraint_constraint(
478                            &mut *constraint,
479                            &mut *call_constraint,
480                        );
481                    },
482                );
483            } else {
484                (*call_constraint)
485                    .deprecated_dependencies
486                    .push(check_constraint);
487                for_each_constraint(
488                    arg_begin_checkpoint,
489                    arg_end_checkpoint,
490                    self,
491                    |constraint| {
492                        (*constraint).deprecated_dependencies.push(check_constraint);
493                        (*call_constraint).deprecated_dependencies.push(constraint);
494                    },
495                );
496            }
497
498            InferencePack {
499                tp: rets,
500                refinements: alloc::vec![self
501                    .refinement_arena
502                    .variadic_refinement_ids(&return_refinements)],
503            }
504        }
505    }
506}