Skip to main content

luaur_analysis/methods/
normalizer_intersection_of_tables.rs

1//! Source: `Analysis/src/Normalize.cpp:2711-2967` (hand-ported)
2use crate::enums::table_state::TableState;
3use crate::functions::follow_type::follow_type_id;
4use crate::functions::get_type_alt_j::get_type_id;
5use crate::functions::is_prim::is_prim;
6use crate::functions::simplify_intersection_simplify::simplify_intersection;
7use crate::records::any_type::AnyType;
8use crate::records::metatable_type::MetatableType;
9use crate::records::never_type::NeverType;
10use crate::records::normalizer::Normalizer;
11use crate::records::primitive_type::Type as PrimType;
12use crate::records::property_type::Property;
13use crate::records::table_indexer::TableIndexer;
14use crate::records::table_type::TableType;
15use crate::type_aliases::seen_table_prop_pairs::SeenTablePropPairs;
16use crate::type_aliases::type_id::TypeId;
17use luaur_common::records::dense_hash_set::DenseHashSet;
18use luaur_common::FFlag;
19
/// RAII guard mirroring C++ `RecursionCounter _rc(&sharedState->counters.recursionCount)`.
///
/// Constructing the guard increments the counter behind `count`; dropping it
/// decrements the same counter, so the counter reflects how many guards for it
/// are currently alive.
///
/// `new` is callable from safe code but dereferences `count` both in `new` and in
/// `Drop`: callers must pass a non-null, aligned pointer whose pointee stays valid
/// (and is not accessed concurrently) for the whole lifetime of the guard.
#[must_use = "dropping an RcGuard immediately undoes its increment"]
struct RcGuard {
    /// Counter incremented in `new` and decremented again in `drop`.
    count: *mut i32,
}

impl RcGuard {
    /// Increments `*count` and returns a guard that decrements it when dropped.
    fn new(count: *mut i32) -> Self {
        debug_assert!(!count.is_null(), "RcGuard::new called with a null counter");
        // SAFETY: the caller guarantees `count` is valid for reads and writes for
        // as long as the returned guard lives (see the type-level docs).
        unsafe {
            *count += 1;
        }
        RcGuard { count }
    }
}

impl Drop for RcGuard {
    fn drop(&mut self) {
        // SAFETY: this is the same pointer `new` incremented through, and the caller
        // guarantees it outlives the guard.
        unsafe {
            *self.count -= 1;
        }
    }
}
41
42impl Normalizer {
43    pub fn intersection_of_tables(
44        &mut self,
45        here: TypeId,
46        there: TypeId,
47        seen_table_prop_pairs: &mut SeenTablePropPairs,
48        seen_set: &mut DenseHashSet<TypeId>,
49    ) -> Option<TypeId> {
50        self.consume_fuel();
51
52        if here == there {
53            return Some(here);
54        }
55
56        let _rc = RcGuard::new(unsafe { &mut (*self.shared_state).counters.recursion_count });
57        let recursion_limit = unsafe { (*self.shared_state).counters.recursion_limit };
58        let recursion_count = unsafe { (*self.shared_state).counters.recursion_count };
59        if recursion_limit > 0 && recursion_limit < recursion_count {
60            return None;
61        }
62
63        if is_prim(here, PrimType::Table) {
64            return Some(there);
65        } else if is_prim(there, PrimType::Table) {
66            return Some(here);
67        }
68
69        if !unsafe { get_type_id::<NeverType>(here).is_null() } {
70            return Some(there);
71        } else if !unsafe { get_type_id::<NeverType>(there).is_null() } {
72            return Some(here);
73        } else if !unsafe { get_type_id::<AnyType>(here).is_null() } {
74            return Some(there);
75        } else if !unsafe { get_type_id::<AnyType>(there).is_null() } {
76            return Some(here);
77        }
78
79        let mut htable = here;
80        let mut hmtable: TypeId = core::ptr::null();
81        if let Some(hmtv) = unsafe { get_type_id::<MetatableType>(here).as_ref() } {
82            htable = unsafe { follow_type_id(hmtv.table()) };
83            hmtable = unsafe { follow_type_id(hmtv.metatable()) };
84        }
85        let mut ttable = there;
86        let mut tmtable: TypeId = core::ptr::null();
87        if let Some(tmtv) = unsafe { get_type_id::<MetatableType>(there).as_ref() } {
88            ttable = unsafe { follow_type_id(tmtv.table()) };
89            tmtable = unsafe { follow_type_id(tmtv.metatable()) };
90        }
91
92        let httv = unsafe { get_type_id::<TableType>(htable) };
93        if httv.is_null() {
94            return None;
95        }
96        let tttv = unsafe { get_type_id::<TableType>(ttable) };
97        if tttv.is_null() {
98            return None;
99        }
100        let httv = unsafe { &*httv };
101        let tttv = unsafe { &*tttv };
102
103        if httv.state == TableState::Free || tttv.state == TableState::Free {
104            return None;
105        }
106        if httv.state == TableState::Generic || tttv.state == TableState::Generic {
107            return None;
108        }
109
110        let mut state = httv.state;
111        if tttv.state == TableState::Unsealed {
112            state = tttv.state;
113        }
114
115        // TypeLevel max(a, b) == if a.subsumes(b) { b } else { a }   (Unifiable.h:62)
116        let level = if httv.level.subsumes(&tttv.level) {
117            tttv.level
118        } else {
119            httv.level
120        };
121        // Scope* max(a, b)
122        let scope = crate::functions::max_scope::max(httv.scope, tttv.scope);
123
124        let mut result: Option<TableType> = None;
125        let mut here_sub_there = true;
126        let mut there_sub_here = true;
127
128        for (name, hprop) in httv.props.iter() {
129            let mut prop: Property = hprop.clone();
130            let tfound = tttv.props.get(name);
131            match tfound {
132                None => {
133                    there_sub_here = false;
134                }
135                Some(tprop) => {
136                    // TODO: variance issues here, which can't be fixed until we have read/write property types
137                    if self.use_new_luau_solver() {
138                        if let Some(hread) = hprop.read_ty {
139                            if let Some(tread) = tprop.read_ty {
140                                let ty = simplify_intersection(
141                                    self.builtin_types,
142                                    self.arena,
143                                    hread,
144                                    tread,
145                                )
146                                .result;
147
148                                // If any property is going to get mapped to `never`, we can just call the entire table `never`.
149                                if !unsafe { get_type_id::<NeverType>(ty).is_null() } {
150                                    return Some(unsafe { (*self.builtin_types).neverType });
151                                }
152
153                                prop.read_ty = Some(ty);
154                                here_sub_there &= ty == hread;
155                                there_sub_here &= ty == tread;
156                            } else {
157                                prop.read_ty = Some(hread);
158                                there_sub_here = false;
159                            }
160                        } else if let Some(tread) = tprop.read_ty {
161                            prop.read_ty = Some(tread);
162                            here_sub_there = false;
163                        }
164
165                        if let Some(hwrite) = hprop.write_ty {
166                            if let Some(twrite) = tprop.write_ty {
167                                let w = simplify_intersection(
168                                    self.builtin_types,
169                                    self.arena,
170                                    hwrite,
171                                    twrite,
172                                )
173                                .result;
174                                prop.write_ty = Some(w);
175                                here_sub_there &= w == hwrite;
176                                there_sub_here &= w == twrite;
177                            } else {
178                                prop.write_ty = Some(hwrite);
179                                there_sub_here = false;
180                            }
181                        } else if let Some(twrite) = tprop.write_ty {
182                            prop.write_ty = Some(twrite);
183                            here_sub_there = false;
184                        }
185                    } else {
186                        let h_dep = hprop.type_deprecated();
187                        let t_dep = tprop.type_deprecated();
188                        let inter = self.intersection_type(h_dep, t_dep);
189                        prop.set_type(inter);
190                        here_sub_there &= prop.type_deprecated() == h_dep;
191                        there_sub_here &= prop.type_deprecated() == t_dep;
192                    }
193                }
194            }
195
196            // TODO: string indexers
197
198            if prop.read_ty.is_some() || prop.write_ty.is_some() {
199                if result.is_none() {
200                    result = Some(TableType::table_type_table_state_type_level_scope(
201                        state, level, scope,
202                    ));
203                }
204                result.as_mut().unwrap().props.insert(name.clone(), prop);
205            }
206        }
207
208        for (name, tprop) in tttv.props.iter() {
209            if !httv.props.contains_key(name) {
210                if result.is_none() {
211                    result = Some(TableType::table_type_table_state_type_level_scope(
212                        state, level, scope,
213                    ));
214                }
215                result
216                    .as_mut()
217                    .unwrap()
218                    .props
219                    .insert(name.clone(), tprop.clone());
220                here_sub_there = false;
221            }
222        }
223
224        if httv.indexer.is_some() && tttv.indexer.is_some() {
225            let hindexer = httv.indexer.as_ref().unwrap();
226            let tindexer = tttv.indexer.as_ref().unwrap();
227            if FFlag::LuauReadOnlyIndexers.get() {
228                let index = self.union_type(hindexer.index_type, tindexer.index_type);
229                let mut idx = TableIndexer {
230                    index_type: index,
231                    index_result_type: core::ptr::null(),
232                    is_read_only: false,
233                };
234
235                if hindexer.is_read_only && tindexer.is_read_only {
236                    // Both read-only: covariant -> intersect values, keep read-only.
237                    idx.index_result_type = self
238                        .intersection_type(hindexer.index_result_type, tindexer.index_result_type);
239                    idx.is_read_only = true;
240                } else {
241                    idx.index_result_type = self
242                        .intersection_type(hindexer.index_result_type, tindexer.index_result_type);
243                }
244
245                let here_mode_match = hindexer.is_read_only == idx.is_read_only;
246                let there_mode_match = tindexer.is_read_only == idx.is_read_only;
247                here_sub_there &= here_mode_match
248                    && (hindexer.index_type == index)
249                    && (hindexer.index_result_type == idx.index_result_type);
250                there_sub_here &= there_mode_match
251                    && (tindexer.index_type == index)
252                    && (tindexer.index_result_type == idx.index_result_type);
253
254                if result.is_none() {
255                    result = Some(TableType::table_type_table_state_type_level_scope(
256                        state, level, scope,
257                    ));
258                }
259                result.as_mut().unwrap().indexer = Some(idx);
260            } else {
261                // TODO: What should intersection of indexes be?
262                let index = self.union_type(hindexer.index_type, tindexer.index_type);
263                let index_result =
264                    self.intersection_type(hindexer.index_result_type, tindexer.index_result_type);
265                if result.is_none() {
266                    result = Some(TableType::table_type_table_state_type_level_scope(
267                        state, level, scope,
268                    ));
269                }
270                result.as_mut().unwrap().indexer = Some(TableIndexer {
271                    index_type: index,
272                    index_result_type: index_result,
273                    is_read_only: false,
274                });
275                here_sub_there &=
276                    (hindexer.index_type == index) && (hindexer.index_result_type == index_result);
277                there_sub_here &=
278                    (tindexer.index_type == index) && (tindexer.index_result_type == index_result);
279            }
280        } else if httv.indexer.is_some() {
281            if result.is_none() {
282                result = Some(TableType::table_type_table_state_type_level_scope(
283                    state, level, scope,
284                ));
285            }
286            result.as_mut().unwrap().indexer = httv.indexer;
287            there_sub_here = false;
288        } else if tttv.indexer.is_some() {
289            if result.is_none() {
290                result = Some(TableType::table_type_table_state_type_level_scope(
291                    state, level, scope,
292                ));
293            }
294            result.as_mut().unwrap().indexer = tttv.indexer;
295            here_sub_there = false;
296        }
297
298        let table: TypeId;
299        if here_sub_there {
300            table = htable;
301        } else if there_sub_here {
302            table = ttable;
303        } else if let Some(tt) = result {
304            table = unsafe { (*self.arena).add_type(tt) };
305        } else {
306            table = unsafe {
307                (*self.arena).add_type(TableType::table_type_table_state_type_level_scope(
308                    state, level, scope,
309                ))
310            };
311        }
312
313        if !tmtable.is_null() && !hmtable.is_null() {
314            // NOTE: this assumes metatables are ivariant
315            match self.intersection_of_tables(hmtable, tmtable, seen_table_prop_pairs, seen_set) {
316                Some(mtable) => {
317                    if table == htable && mtable == hmtable {
318                        Some(here)
319                    } else if table == ttable && mtable == tmtable {
320                        Some(there)
321                    } else {
322                        Some(unsafe {
323                            (*self.arena).add_type(MetatableType {
324                                table,
325                                metatable: mtable,
326                                syntheticName: None,
327                            })
328                        })
329                    }
330                }
331                None => None,
332            }
333        } else if !hmtable.is_null() {
334            if table == htable {
335                Some(here)
336            } else {
337                Some(unsafe {
338                    (*self.arena).add_type(MetatableType {
339                        table,
340                        metatable: hmtable,
341                        syntheticName: None,
342                    })
343                })
344            }
345        } else if !tmtable.is_null() {
346            if table == ttable {
347                Some(there)
348            } else {
349                Some(unsafe {
350                    (*self.arena).add_type(MetatableType {
351                        table,
352                        metatable: tmtable,
353                        syntheticName: None,
354                    })
355                })
356            }
357        } else {
358            Some(table)
359        }
360    }
361}