Skip to main content

luaur_analysis/methods/
constraint_solver_try_dispatch_iterable_table.rs

1use crate::enums::polarity::Polarity;
2use crate::enums::table_state::TableState;
3use crate::functions::extend_type_pack::extend_type_pack;
4use crate::functions::find_metatable_entry::find_metatable_entry;
5use crate::functions::follow_type::follow_type_id;
6use crate::functions::fresh_type::fresh_type;
7use crate::functions::get_mutable_type::getMutable;
8use crate::functions::get_type_alt_j::get_type_id;
9use crate::functions::instantiate::instantiate;
10use crate::functions::track_interior_free_type::track_interior_free_type;
11use crate::records::any_type::AnyType;
12use crate::records::constraint::Constraint;
13use crate::records::constraint_solver::ConstraintSolver;
14use crate::records::free_type::FreeType;
15use crate::records::function_type::FunctionType;
16use crate::records::iterable_constraint::IterableConstraint;
17use crate::records::metatable_type::MetatableType;
18use crate::records::never_type::NeverType;
19use crate::records::primitive_type::PrimitiveType;
20use crate::records::reduce_constraint::ReduceConstraint;
21use crate::records::table_indexer::TableIndexer;
22use crate::records::table_type::TableType;
23use crate::records::type_check_limits::TypeCheckLimits;
24use crate::records::type_level::TypeLevel;
25use crate::records::unification_too_complex::UnificationTooComplex;
26use crate::type_aliases::constraint_v::ConstraintV;
27use crate::type_aliases::props_type::Props;
28use crate::type_aliases::type_error_data::TypeErrorData;
29use crate::type_aliases::type_id::TypeId;
30use core::ptr::NonNull;
31use luaur_ast::records::location::Location;
32use luaur_common::FFlag;
33
34impl ConstraintSolver {
35    pub fn try_dispatch_iterable_table(
36        &mut self,
37        iterator_ty: TypeId,
38        c: &IterableConstraint,
39        constraint: *const Constraint,
40        force: bool,
41    ) -> bool {
42        let iterator_ty = unsafe { follow_type_id(iterator_ty) };
43
44        if unsafe { !get_type_id::<FreeType>(iterator_ty).is_null() } {
45            let scope = unsafe { (*constraint).scope };
46            let key_ty = unsafe {
47                fresh_type(
48                    &mut *self.arena,
49                    &*self.builtin_types,
50                    scope,
51                    Polarity::Mixed,
52                )
53            };
54            let value_ty = unsafe {
55                fresh_type(
56                    &mut *self.arena,
57                    &*self.builtin_types,
58                    scope,
59                    Polarity::Mixed,
60                )
61            };
62            track_interior_free_type(scope, key_ty);
63            track_interior_free_type(scope, value_ty);
64
65            let props = Props::default();
66            let table_ty = unsafe {
67                (*self.arena).add_type(
68                    TableType::table_type_props_optional_table_indexer_type_level_scope_table_state(
69                        &props,
70                        Some(TableIndexer {
71                            index_type: key_ty,
72                            index_result_type: value_ty,
73                            is_read_only: false,
74                        }),
75                        TypeLevel::default(),
76                        scope,
77                        TableState::Sealed,
78                    ),
79                )
80            };
81
82            self.constraint_solver_unify(constraint, iterator_ty, table_ty);
83
84            let mut it = c.variables.iter();
85            if let Some(ty) = it.next() {
86                self.bind_not_null_constraint_type_id_type_id(constraint, *ty, key_ty);
87            }
88            if let Some(ty) = it.next() {
89                self.bind_not_null_constraint_type_id_type_id(constraint, *ty, value_ty);
90            }
91
92            return true;
93        }
94
95        if unsafe { !get_type_id::<AnyType>(iterator_ty).is_null() } {
96            self.unpack_iterable_variables(constraint, c, unsafe { (*self.builtin_types).anyType });
97            return true;
98        }
99
100        if unsafe { !get_type_id::<NeverType>(iterator_ty).is_null() } {
101            self.unpack_iterable_variables(constraint, c, unsafe {
102                (*self.builtin_types).neverType
103            });
104            return true;
105        }
106
107        // Irksome: I don't think we have any way to guarantee that this table
108        // type never has a metatable.
109
110        if let Some(iterator_table) = unsafe { get_type_id::<TableType>(iterator_ty).as_ref() } {
111            if iterator_table.state == TableState::Free && !force {
112                return self.block_type_id_not_null_constraint(iterator_ty, constraint);
113            }
114
115            if let Some(indexer) = iterator_table.indexer {
116                let value_type = if FFlag::LuauRefineNilFromTableIndexerResultType.get() {
117                    let intersection_with_not_nil = unsafe {
118                        (*self.arena).add_type_function_type_function_initializer_list_type_id(
119                            &(*self.builtin_types).typeFunctions.intersect_func,
120                            &[indexer.index_result_type, (*self.builtin_types).notNilType],
121                        )
122                    };
123
124                    unsafe {
125                        self.push_constraint(
126                            NonNull::new((*constraint).scope).unwrap(),
127                            (*constraint).location,
128                            ConstraintV::Reduce(ReduceConstraint {
129                                ty: intersection_with_not_nil,
130                            }),
131                        );
132                    }
133
134                    intersection_with_not_nil
135                } else {
136                    indexer.index_result_type
137                };
138
139                let mut expected_variables = alloc::vec![indexer.index_type, value_type];
140                while expected_variables.len() < c.variables.len() {
141                    expected_variables.push(unsafe { (*self.builtin_types).errorType });
142                }
143
144                for (variable, expected) in c.variables.iter().zip(expected_variables.iter()) {
145                    self.constraint_solver_unify(constraint, *variable, *expected);
146                    self.bind_not_null_constraint_type_id_type_id(constraint, *variable, *expected);
147                }
148            } else {
149                self.unpack_iterable_variables(constraint, c, unsafe {
150                    (*self.builtin_types).errorType
151                });
152            }
153
154            return true;
155        }
156
157        // else if (std::optional<TypeId> iterFn = findMetatableEntry(builtinTypes, errors, iteratorTy, "__iter", Location{}))
158        let iter_fn = find_metatable_entry(
159            self.builtin_types,
160            &mut self.errors,
161            iterator_ty,
162            "__iter",
163            Location::default(),
164        );
165        if let Some(iter_fn) = iter_fn {
166            if self.is_blocked_type_id(iter_fn) {
167                return self.block_type_id_not_null_constraint(iter_fn, constraint);
168            }
169
170            let scope = unsafe { (*constraint).scope };
171            let instantiated_iter_fn = instantiate(
172                self.builtin_types,
173                self.arena,
174                &mut self.limits as *mut TypeCheckLimits,
175                scope,
176                iter_fn,
177            );
178
179            if let Some(instantiated_iter_fn) = instantiated_iter_fn {
180                let iter_ftv = unsafe { get_type_id::<FunctionType>(instantiated_iter_fn) };
181                if !iter_ftv.is_null() {
182                    let iter_ftv = unsafe { &*iter_ftv };
183
184                    let expected_iter_args = unsafe {
185                        (*self.arena).add_type_pack_initializer_list_type_id(&[iterator_ty])
186                    };
187                    self.constraint_solver_unify(
188                        constraint,
189                        iter_ftv.arg_types,
190                        expected_iter_args,
191                    );
192
193                    let iter_rets = extend_type_pack(
194                        unsafe { &mut *self.arena },
195                        self.builtin_types,
196                        iter_ftv.ret_types,
197                        2,
198                        alloc::vec::Vec::new(),
199                    );
200
201                    if iter_rets.head.len() < 1 {
202                        // We've done what we can; this will get reported as an
203                        // error by the type checker.
204                        return true;
205                    }
206
207                    let next_fn_ty = iter_rets.head[0];
208
209                    let instantiated_next_fn = instantiate(
210                        self.builtin_types,
211                        self.arena,
212                        &mut self.limits as *mut TypeCheckLimits,
213                        scope,
214                        next_fn_ty,
215                    );
216
217                    if let Some(instantiated_next_fn) = instantiated_next_fn {
218                        let next_fn = unsafe { get_type_id::<FunctionType>(instantiated_next_fn) };
219
220                        // If nextFn is nullptr, then the iterator function has an improper signature.
221                        if !next_fn.is_null() {
222                            let ret_types = unsafe { (*next_fn).ret_types };
223                            self.unpack_and_assign(
224                                c.variables.clone(),
225                                ret_types,
226                                NonNull::new(constraint as *mut Constraint).unwrap(),
227                            );
228                        }
229
230                        return true;
231                    } else {
232                        let location = unsafe { (*constraint).location };
233                        self.report_error_type_error_data_location(
234                            TypeErrorData::UnificationTooComplex(UnificationTooComplex::default()),
235                            &location,
236                        );
237                    }
238                } else {
239                    // TODO: Support __call and function overloads (what does an overload even mean for this?)
240                }
241            } else {
242                let location = unsafe { (*constraint).location };
243                self.report_error_type_error_data_location(
244                    TypeErrorData::UnificationTooComplex(UnificationTooComplex::default()),
245                    &location,
246                );
247            }
248
249            return true;
250        }
251
252        // else if (auto iteratorMetatable = get<MetatableType>(iteratorTy))
253        if let Some(iterator_metatable) =
254            unsafe { get_type_id::<MetatableType>(iterator_ty).as_ref() }
255        {
256            // If the metatable does not contain a `__iter` metamethod, then we iterate over the table part of the metatable.
257            return self.try_dispatch_iterable_table(
258                iterator_metatable.table,
259                c,
260                constraint,
261                force,
262            );
263        }
264
265        if let Some(primitive_ty) = unsafe { get_type_id::<PrimitiveType>(iterator_ty).as_ref() } {
266            if primitive_ty.r#type == PrimitiveType::Table {
267                self.unpack_iterable_variables(constraint, c, unsafe {
268                    (*self.builtin_types).unknownType
269                });
270                return true;
271            }
272        }
273
274        self.unpack_iterable_variables(constraint, c, unsafe { (*self.builtin_types).errorType });
275
276        true
277    }
278
279    fn unpack_iterable_variables(
280        &mut self,
281        constraint: *const Constraint,
282        c: &IterableConstraint,
283        ty: TypeId,
284    ) {
285        for var_ty in &c.variables {
286            self.bind_not_null_constraint_type_id_type_id(constraint, *var_ty, ty);
287        }
288    }
289}