Skip to main content

luaur_analysis/methods/
normalizer_union_normal_with_ty.rs

1//! Source: `Analysis/src/Normalize.cpp:1796-1967` (hand-ported)
2use crate::enums::normalization_result::NormalizationResult;
3use crate::functions::assert_invariant::assert_invariant;
4use crate::functions::follow_type::follow_type_id;
5use crate::functions::get_singleton_type::get_singleton_type;
6use crate::functions::get_type_alt_j::get_type_id;
7use crate::functions::is_cacheable_normalize_alt_c::is_cacheable_type_id;
8use crate::functions::tyvar_index::tyvar_index;
9use crate::records::any_type::AnyType;
10use crate::records::blocked_type::BlockedType;
11use crate::records::boolean_singleton::BooleanSingleton;
12use crate::records::extern_type::ExternType;
13use crate::records::free_type::FreeType;
14use crate::records::function_type::FunctionType;
15use crate::records::generic_type::GenericType;
16use crate::records::intersection_type::IntersectionType;
17use crate::records::metatable_type::MetatableType;
18use crate::records::negation_type::NegationType;
19use crate::records::never_type::NeverType;
20use crate::records::no_refine_type::NoRefineType;
21use crate::records::normalized_extern_type::NormalizedExternType;
22use crate::records::normalized_function_type::NormalizedFunctionType;
23use crate::records::normalized_string_type::NormalizedStringType;
24use crate::records::normalized_type::NormalizedType;
25use crate::records::normalizer::Normalizer;
26use crate::records::pending_expansion_type::PendingExpansionType;
27use crate::records::primitive_type::{PrimitiveType, Type as PrimType};
28use crate::records::singleton_type::SingletonType;
29use crate::records::string_singleton::StringSingleton;
30use crate::records::table_type::TableType;
31use crate::records::type_function_instance_type::TypeFunctionInstanceType;
32use crate::records::type_ids::TypeIds;
33use crate::records::unknown_type::UnknownType;
34use crate::type_aliases::error_type::ErrorType;
35use crate::type_aliases::seen_table_prop_pairs::SeenTablePropPairs;
36use crate::type_aliases::type_id::TypeId;
37use alloc::boxed::Box;
38use alloc::collections::BTreeMap;
39use luaur_common::macros::luau_assert::LUAU_ASSERT;
40use luaur_common::records::dense_hash_set::DenseHashSet;
41use luaur_common::FFlag;
42
43/// C++ `Set<TypeId>::erase(key)`. The Rust skeleton maps the seen-set parameter
44/// to `DenseHashSet`, which (faithful to `Luau::DenseHashSet`) cannot erase a
45/// single slot, whereas C++ here uses `Luau::Set` which can. We reproduce the
46/// single-element removal by rebuilding the set without `key`; `clear` preserves
47/// the empty-key sentinel, so the rebuilt set stays valid.
48pub(crate) fn erase_seen(seen: &mut DenseHashSet<TypeId>, key: TypeId) {
49    let kept: Vec<TypeId> = seen.iter().copied().filter(|&k| k != key).collect();
50    seen.clear();
51    for k in kept {
52        seen.insert(k);
53    }
54}
55
/// Scope guard for a recursion counter, mirroring C++
/// `RecursionCounter _rc(&sharedState->counters.recursionCount)`.
///
/// Creating the guard adds one to the pointed-to counter; dropping the guard
/// subtracts one again.
struct RcGuard {
    /// Raw pointer to the counter this guard adjusts.
    count: *mut i32,
}

impl RcGuard {
    /// Increments `*count` and returns the guard that undoes the increment.
    ///
    /// `count` is dereferenced here and again in `drop`, so it must point at a
    /// live, writable `i32` for as long as the guard exists. Nothing in this
    /// type checks that; it is the caller's responsibility.
    fn new(count: *mut i32) -> Self {
        // The increment happens before the guard is built, so a failed
        // increment never leaves a guard behind to decrement.
        // SAFETY: relies on the caller-provided pointer being valid (see above).
        unsafe { *count += 1 };
        Self { count }
    }
}

impl Drop for RcGuard {
    /// Reverts the increment performed by [`RcGuard::new`].
    fn drop(&mut self) {
        // SAFETY: same pointer that `new` already wrote through; validity for
        // the guard's lifetime is the caller's obligation.
        unsafe { *self.count -= 1 };
    }
}
78
79fn fresh_normalized_type(
80    builtin_types: *mut crate::records::builtin_types::BuiltinTypes,
81) -> NormalizedType {
82    let never_type = unsafe { (*builtin_types).neverType };
83    NormalizedType {
84        builtin_types,
85        tops: never_type,
86        booleans: never_type,
87        extern_types: NormalizedExternType {
88            extern_types: BTreeMap::new(),
89            shape_extensions: TypeIds::type_ids(),
90            ordering: Vec::new(),
91        },
92        errors: never_type,
93        nils: never_type,
94        numbers: never_type,
95        integers: never_type,
96        strings: NormalizedStringType::never,
97        threads: never_type,
98        buffers: never_type,
99        tables: TypeIds::type_ids(),
100        functions: NormalizedFunctionType {
101            is_top: false,
102            parts: TypeIds::type_ids(),
103        },
104        tyvars: BTreeMap::new(),
105        is_cacheable: true,
106    }
107}
108
109impl Normalizer {
110    // See above for an explanation of `ignoreSmallerTyvars`.
111    pub fn union_normal_with_ty(
112        &mut self,
113        here: &mut NormalizedType,
114        there: TypeId,
115        seen_table_prop_pairs: &mut SeenTablePropPairs,
116        seen_set_types: &mut DenseHashSet<TypeId>,
117        ignore_smaller_tyvars: i32,
118    ) -> NormalizationResult {
119        let _rc = RcGuard::new(unsafe { &mut (*self.shared_state).counters.recursion_count });
120        if !self.within_resource_limits() {
121            return NormalizationResult::HitLimits;
122        }
123
124        self.consume_fuel();
125
126        let there = unsafe { follow_type_id(there) };
127
128        if !unsafe { get_type_id::<AnyType>(there).is_null() }
129            || !unsafe { get_type_id::<UnknownType>(there).is_null() }
130        {
131            let mut tops = self.union_of_tops(here.tops, there);
132            if !unsafe { get_type_id::<UnknownType>(tops).is_null() }
133                && !unsafe { get_type_id::<ErrorType>(here.errors).is_null() }
134            {
135                tops = unsafe { (*self.builtin_types).anyType };
136            }
137            self.clear_normal(here);
138            here.tops = tops;
139            return NormalizationResult::True;
140        } else if !unsafe { get_type_id::<NeverType>(there).is_null() }
141            || !unsafe { get_type_id::<AnyType>(here.tops).is_null() }
142        {
143            return NormalizationResult::True;
144        } else if !unsafe { get_type_id::<ErrorType>(there).is_null() }
145            && !unsafe { get_type_id::<UnknownType>(here.tops).is_null() }
146        {
147            here.tops = unsafe { (*self.builtin_types).anyType };
148            return NormalizationResult::True;
149        } else if !unsafe { get_type_id::<crate::records::union_type::UnionType>(there).is_null() }
150        {
151            if seen_set_types.contains(&there) {
152                return NormalizationResult::True;
153            }
154            seen_set_types.insert(there);
155
156            let options = unsafe {
157                (*get_type_id::<crate::records::union_type::UnionType>(there))
158                    .options
159                    .clone()
160            };
161            for opt in options {
162                let res =
163                    self.union_normal_with_ty(here, opt, seen_table_prop_pairs, seen_set_types, -1);
164                if res != NormalizationResult::True {
165                    erase_seen(seen_set_types, there);
166                    return res;
167                }
168            }
169
170            erase_seen(seen_set_types, there);
171            return NormalizationResult::True;
172        } else if !unsafe { get_type_id::<IntersectionType>(there).is_null() } {
173            if seen_set_types.contains(&there) {
174                return NormalizationResult::True;
175            }
176            seen_set_types.insert(there);
177
178            let mut norm = fresh_normalized_type(self.builtin_types);
179            norm.tops = unsafe { (*self.builtin_types).unknownType };
180            let parts = unsafe { (*get_type_id::<IntersectionType>(there)).parts.clone() };
181            for part in parts {
182                let res = self.intersect_normal_with_ty(
183                    &mut norm,
184                    part,
185                    seen_table_prop_pairs,
186                    seen_set_types,
187                );
188                if res != NormalizationResult::True {
189                    erase_seen(seen_set_types, there);
190                    return res;
191                }
192            }
193
194            erase_seen(seen_set_types, there);
195
196            return self.union_normals(here, &norm, -1);
197        } else if !unsafe { get_type_id::<UnknownType>(here.tops).is_null() } {
198            return NormalizationResult::True;
199        } else if !unsafe { get_type_id::<GenericType>(there).is_null() }
200            || !unsafe { get_type_id::<FreeType>(there).is_null() }
201            || !unsafe { get_type_id::<BlockedType>(there).is_null() }
202            || !unsafe { get_type_id::<PendingExpansionType>(there).is_null() }
203            || !unsafe { get_type_id::<TypeFunctionInstanceType>(there).is_null() }
204        {
205            if tyvar_index(there) <= ignore_smaller_tyvars {
206                return NormalizationResult::True;
207            }
208            let mut inter = fresh_normalized_type(self.builtin_types);
209            inter.tops = unsafe { (*self.builtin_types).unknownType };
210            here.tyvars.insert(there, Box::new(inter));
211
212            if !is_cacheable_type_id(there) {
213                here.is_cacheable = false;
214            }
215        } else if !unsafe { get_type_id::<FunctionType>(there).is_null() } {
216            self.union_functions_with_function(&mut here.functions, there);
217        } else if !unsafe { get_type_id::<TableType>(there).is_null() }
218            || !unsafe { get_type_id::<MetatableType>(there).is_null() }
219        {
220            self.union_tables_with_table(&mut here.tables, there);
221        } else if !unsafe { get_type_id::<ExternType>(there).is_null() } {
222            self.union_extern_types_with_extern_type_normalized_extern_type_type_id(
223                &mut here.extern_types,
224                there,
225            );
226        } else if !unsafe { get_type_id::<ErrorType>(there).is_null() } {
227            here.errors = there;
228        } else if !unsafe { get_type_id::<PrimitiveType>(there).is_null() } {
229            let ptv = unsafe { &*get_type_id::<PrimitiveType>(there) };
230            match ptv.r#type {
231                PrimType::Boolean => here.booleans = there,
232                PrimType::NilType => here.nils = there,
233                PrimType::Number => here.numbers = there,
234                PrimType::Integer if FFlag::LuauIntegerType2.get() => here.integers = there,
235                PrimType::String => {
236                    crate::methods::normalized_string_type_reset_to_string::normalized_string_type_reset_to_string(
237                        &mut here.strings,
238                    )
239                }
240                PrimType::Thread => here.threads = there,
241                PrimType::Buffer => here.buffers = there,
242                PrimType::Function => here.functions.reset_to_top(),
243                PrimType::Table => {
244                    here.tables.clear();
245                    here.tables.insert_type_id(there);
246                }
247                _ => LUAU_ASSERT!(false),
248            }
249        } else if !unsafe { get_type_id::<SingletonType>(there).is_null() } {
250            let stv = unsafe { get_type_id::<SingletonType>(there) };
251            if !get_singleton_type::<BooleanSingleton>(stv).is_null() {
252                here.booleans = self.union_of_bools(here.booleans, there);
253            } else if !get_singleton_type::<StringSingleton>(stv).is_null() {
254                let sstv = unsafe { &*get_singleton_type::<StringSingleton>(stv) };
255                if here.strings.isCofinite {
256                    if here.strings.singletons.contains_key(&sstv.value) {
257                        here.strings.singletons.remove(&sstv.value);
258                    }
259                } else {
260                    here.strings.singletons.insert(sstv.value.clone(), there);
261                }
262            } else {
263                LUAU_ASSERT!(false);
264            }
265        } else if !unsafe { get_type_id::<NegationType>(there).is_null() } {
266            let ntv_ty = unsafe { (*get_type_id::<NegationType>(there)).ty };
267
268            let there_normal = self.normalize(ntv_ty);
269            let tn = self.negate_normal(&there_normal);
270
271            let mut tn = match tn {
272                Some(t) => t,
273                None => return NormalizationResult::False,
274            };
275
276            let res = self.union_normals(here, &mut tn, -1);
277            if res != NormalizationResult::True {
278                return res;
279            }
280        } else if !unsafe { get_type_id::<PendingExpansionType>(there).is_null() }
281            || !unsafe { get_type_id::<TypeFunctionInstanceType>(there).is_null() }
282            || !unsafe { get_type_id::<NoRefineType>(there).is_null() }
283        {
284            // nothing
285        } else {
286            LUAU_ASSERT!(false);
287        }
288
289        let tyvar_keys: Vec<TypeId> = here.tyvars.keys().copied().collect();
290        for tyvar in tyvar_keys {
291            if let Some(mut intersect) = here.tyvars.remove(&tyvar) {
292                let res = self.union_normal_with_ty(
293                    &mut intersect,
294                    there,
295                    seen_table_prop_pairs,
296                    seen_set_types,
297                    tyvar_index(tyvar),
298                );
299                here.tyvars.insert(tyvar, intersect);
300                if res != NormalizationResult::True {
301                    return res;
302                }
303            }
304        }
305
306        assert_invariant(here);
307        NormalizationResult::True
308    }
309}