Skip to main content

luaur_analysis/methods/
constraint_generator_visit_module_root.rs

1use crate::enums::control_flow::ControlFlow;
2use crate::enums::polarity::Polarity;
3use crate::functions::add_all_as_dependencies::add_all_as_dependencies;
4use crate::functions::as_mutable_type::as_mutable_type_id;
5use crate::functions::checkpoint::checkpoint;
6use crate::functions::follow_type::follow_type_id;
7use crate::functions::for_each_constraint::for_each_constraint;
8use crate::functions::get_mutable_type::getMutable;
9use crate::functions::get_type_alt_j::get_type_id;
10use crate::records::blocked_type::BlockedType;
11use crate::records::constraint::Constraint;
12use crate::records::constraint_generator::ConstraintGenerator;
13use crate::records::function_type::FunctionType;
14use crate::records::generalization_constraint::GeneralizationConstraint;
15use crate::records::interior_free_types::InteriorFreeTypes;
16use crate::records::module::Module;
17use crate::records::pack_subtype_constraint::PackSubtypeConstraint;
18use crate::records::scope::Scope;
19use crate::records::simplify_constraint::SimplifyConstraint;
20use crate::type_aliases::constraint_v::ConstraintV;
21use crate::type_aliases::scope_ptr_constraint_generator::ScopePtr;
22use crate::type_aliases::type_id::TypeId;
23use crate::type_aliases::type_variant::TypeVariant;
24use alloc::rc::Rc;
25use alloc::sync::Arc;
26use alloc::vec::Vec;
27use luaur_ast::records::ast_node::AstNode;
28use luaur_ast::records::ast_stat_block::AstStatBlock;
29use luaur_ast::records::location::Location;
30use luaur_common::macros::luau_assert::LUAU_ASSERT;
31use luaur_common::macros::luau_timetrace_scope::LUAU_TIMETRACE_SCOPE;
32use luaur_common::FFlag;
33
34impl ConstraintGenerator {
35    pub fn visit_module_root(&mut self, block: *mut AstStatBlock) {
36        LUAU_TIMETRACE_SCOPE!("ConstraintGenerator::visitModuleRoot", "Typechecking");
37
38        LUAU_ASSERT!(self.scopes.is_empty());
39        LUAU_ASSERT!(self.root_scope.is_null());
40
41        let scope: ScopePtr = Arc::new(Scope::new(self.global_scope.as_ref().unwrap(), 0));
42        self.root_scope = scope.as_ref() as *const Scope as *mut Scope;
43        self.scopes
44            .push((unsafe { (*block).base.base.location }, scope.clone()));
45        unsafe {
46            (*self.root_scope).location = (*block).base.base.location;
47        }
48        if let Some(module) = &self.module {
49            let module_ptr = Arc::as_ptr(module) as *mut Module;
50            unsafe {
51                *(*module_ptr)
52                    .ast_scopes
53                    .get_or_insert(block as *const AstNode) =
54                    scope.as_ref() as *const Scope as *mut Scope;
55            }
56        }
57
58        self.interior_free_types.push(InteriorFreeTypes::default());
59
60        let local_type_function_scope: ScopePtr =
61            Arc::new(Scope::new(self.type_function_scope.as_ref().unwrap(), 0));
62        unsafe {
63            let lhs = local_type_function_scope.as_ref() as *const Scope as *mut Scope;
64            (*lhs).location = (*block).base.base.location;
65        }
66        unsafe {
67            (*self.type_function_runtime).root_scope = local_type_function_scope;
68        }
69
70        let return_type = self.fresh_type_pack(&scope, Polarity::Positive);
71        unsafe {
72            (*self.root_scope).return_type = return_type;
73        }
74        let module_fn_ty = unsafe {
75            (*self.arena).add_type(FunctionType::function_type_new(
76                (*self.builtin_types).anyTypePack,
77                return_type,
78                None,
79                false,
80            ))
81        };
82
83        self.prepopulate_global_scope(&scope, block);
84
85        let start = checkpoint(self);
86
87        let cf = self.visit_block_without_child_scope(self.root_scope, block);
88        if cf == ControlFlow::None {
89            let empty_type_pack = unsafe { (*self.builtin_types).emptyTypePack };
90            self.add_constraint_scope_ptr_location_constraint_v(
91                &scope,
92                unsafe { (*block).base.base.location },
93                ConstraintV::PackSubtype(PackSubtypeConstraint {
94                    sub_pack: empty_type_pack,
95                    super_pack: return_type,
96                    returns: false,
97                }),
98            );
99        }
100
101        let end = checkpoint(self);
102
103        let result = unsafe { (*self.arena).add_type(BlockedType::default()) };
104        let gen_constraint = self.add_constraint_scope_ptr_location_constraint_v(
105            &scope,
106            unsafe { (*block).base.base.location },
107            ConstraintV::Generalization(GeneralizationConstraint {
108                generalized_type: result,
109                source_type: module_fn_ty,
110                interior_types: Vec::new(),
111                has_deprecated_attribute: false,
112                deprecated_info: Default::default(),
113                no_generics: true,
114            }),
115        );
116
117        unsafe {
118            (*self.root_scope).interior_free_types =
119                Some(self.interior_free_types.last().unwrap().types.clone());
120            (*self.root_scope).interior_free_type_packs =
121                Some(self.interior_free_types.last().unwrap().type_packs.clone());
122        }
123
124        unsafe {
125            let blocked = getMutable::<BlockedType>(result);
126            (*blocked).set_owner(gen_constraint);
127        }
128
129        if FFlag::LuauConstraintGraph.get() {
130            add_all_as_dependencies(start, end, self, gen_constraint);
131        } else {
132            for_each_constraint(start, end, self, |c: *mut Constraint| unsafe {
133                (*gen_constraint).deprecated_dependencies.push(c);
134            });
135        }
136
137        self.interior_free_types.pop();
138
139        self.fill_in_inferred_bindings(&scope, block);
140
141        if !self.logger.is_null() {
142            unsafe {
143                (*self.logger).capture_generation_module(self.module.clone().unwrap());
144            }
145        }
146
147        let local_types_pairs: Vec<(TypeId, Vec<TypeId>)> = self
148            .local_types
149            .iter()
150            .map(|(ty, domain)| (*ty, domain.order.clone()))
151            .collect();
152        for (ty, domain) in local_types_pairs {
153            // FIXME: This isn't the most efficient thing.
154            let mut domain_ty = unsafe { (*self.builtin_types).neverType };
155            for d in domain {
156                let d_followed = unsafe { follow_type_id(d) };
157                if d_followed == ty {
158                    continue;
159                }
160                domain_ty =
161                    self.simplify_union(scope.clone(), Location::default(), domain_ty, d_followed);
162            }
163
164            LUAU_ASSERT!(!unsafe { get_type_id::<BlockedType>(ty) }.is_null());
165            unsafe {
166                (*as_mutable_type_id(ty)).ty = TypeVariant::Bound(domain_ty);
167            }
168        }
169
170        let unions_to_simplify = self.unions_to_simplify.clone();
171        for ty in unions_to_simplify {
172            self.add_constraint_scope_ptr_location_constraint_v(
173                &scope,
174                unsafe { (*block).base.base.location },
175                ConstraintV::Simplify(SimplifyConstraint { ty }),
176            );
177        }
178    }
179}