Skip to main content

luaur_analysis/methods/
unifier_2_occurs_check.rs

1//! Source: `Analysis/src/Unifier2.cpp:825-871` — `Unifier2::occursCheck(DenseHashSet<TypeId>&, TypeId, TypeId)`.
2
3use crate::enums::occurs_check_result::OccursCheckResult;
4use crate::functions::follow_type::follow_type_id;
5use crate::functions::get_type_alt_j::get_type_id;
6use crate::records::error_type::ErrorType;
7use crate::records::free_type::FreeType;
8use crate::records::intersection_type::IntersectionType;
9use crate::records::recursion_limiter::RecursionLimiter;
10use crate::records::unifier_2::Unifier2;
11use crate::records::union_type::UnionType;
12use crate::type_aliases::type_id::TypeId;
13use luaur_common::records::dense_hash_set::DenseHashSet;
14
15impl Unifier2 {
16    pub fn occurs_check(
17        &mut self,
18        seen: &mut DenseHashSet<TypeId>,
19        needle: TypeId,
20        haystack: TypeId,
21    ) -> OccursCheckResult {
22        let mut _ra = RecursionLimiter {
23            base: unsafe { core::mem::zeroed() },
24            native_stack_guard: unsafe { core::mem::zeroed() },
25        };
26        _ra.recursion_limiter_recursion_limiter(
27            "Unifier2::occursCheck",
28            &mut self.recursion_count as *mut i32 as *mut core::ffi::c_int,
29            self.recursion_limit as core::ffi::c_int,
30        );
31
32        let mut occurrence = OccursCheckResult::Pass;
33
34        let needle = unsafe { follow_type_id(needle) };
35        let haystack = unsafe { follow_type_id(haystack) };
36
37        if seen.find(&haystack).is_some() {
38            return OccursCheckResult::Pass;
39        }
40
41        seen.insert(haystack);
42
43        if !unsafe { get_type_id::<ErrorType>(needle) }.is_null() {
44            return OccursCheckResult::Pass;
45        }
46
47        if unsafe { get_type_id::<FreeType>(needle) }.is_null() {
48            unsafe { (*self.ice.as_ptr()).ice_string("Expected needle to be free") };
49        }
50
51        if needle == haystack {
52            return OccursCheckResult::Fail;
53        }
54
55        let haystack_free = unsafe { get_type_id::<FreeType>(haystack) };
56        if !haystack_free.is_null() {
57            let lower = unsafe { (*haystack_free).lower_bound };
58            let upper = unsafe { (*haystack_free).upper_bound };
59            if self.occurs_check(seen, needle, lower) == OccursCheckResult::Fail {
60                occurrence = OccursCheckResult::Fail;
61            }
62            if self.occurs_check(seen, needle, upper) == OccursCheckResult::Fail {
63                occurrence = OccursCheckResult::Fail;
64            }
65        } else {
66            let ut = unsafe { get_type_id::<UnionType>(haystack) };
67            if !ut.is_null() {
68                let options: alloc::vec::Vec<TypeId> = unsafe { (*ut).options.clone() };
69                for ty in options {
70                    if self.occurs_check(seen, needle, ty) == OccursCheckResult::Fail {
71                        occurrence = OccursCheckResult::Fail;
72                    }
73                }
74            } else {
75                let it = unsafe { get_type_id::<IntersectionType>(haystack) };
76                if !it.is_null() {
77                    let parts: alloc::vec::Vec<TypeId> = unsafe { (*it).parts.clone() };
78                    for ty in parts {
79                        if self.occurs_check(seen, needle, ty) == OccursCheckResult::Fail {
80                            occurrence = OccursCheckResult::Fail;
81                        }
82                    }
83                }
84            }
85        }
86
87        occurrence
88    }
89}