// luaur_analysis/methods/constraint_generator_union_refinements.rs

use crate::records::constraint_generator::ConstraintGenerator;
use crate::records::intersection_type::IntersectionType;
use crate::records::refinement_partition::RefinementPartition;
use crate::records::scope::Scope;
use crate::type_aliases::constraint_v::ConstraintV;
use crate::type_aliases::refinement_context::RefinementContext;
use crate::type_aliases::scope_ptr_constraint_generator::ScopePtr;
use crate::type_aliases::type_id::TypeId;
use alloc::vec::Vec;
use luaur_ast::records::location::Location;
use luaur_common::macros::luau_assert::LUAU_ASSERT;

13impl ConstraintGenerator {
14    pub fn union_refinements(
15        &mut self,
16        scope: &ScopePtr,
17        location: Location,
18        lhs: &RefinementContext,
19        rhs: &RefinementContext,
20        dest: *mut RefinementContext,
21        _constraints: *mut Vec<ConstraintV>,
22    ) {
23        let scope_raw = scope.as_ref() as *const Scope as *mut Scope;
24
25        for (def, partition) in lhs.iter() {
26            let rhs_partition = match rhs.get(def) {
27                Some(p) => p,
28                None => continue,
29            };
30
31            LUAU_ASSERT!(!partition.discriminant_types.is_empty());
32            LUAU_ASSERT!(!rhs_partition.discriminant_types.is_empty());
33
34            // C++ `intersect(types)`: 1 -> the sole type, 2 -> makeIntersect, more -> an IntersectionType.
35            let left_discriminant_ty = {
36                let types = &partition.discriminant_types;
37                if types.len() == 1 {
38                    types[0]
39                } else if types.len() == 2 {
40                    self.make_intersect(scope, location, types[0], types[1])
41                } else {
42                    unsafe {
43                        (*self.arena).add_type(IntersectionType {
44                            parts: types.clone(),
45                        })
46                    }
47                }
48            };
49
50            let right_discriminant_ty = {
51                let types = &rhs_partition.discriminant_types;
52                if types.len() == 1 {
53                    types[0]
54                } else if types.len() == 2 {
55                    self.make_intersect(scope, location, types[0], types[1])
56                } else {
57                    unsafe {
58                        (*self.arena).add_type(IntersectionType {
59                            parts: types.clone(),
60                        })
61                    }
62                }
63            };
64
65            let union_ty = self.make_union_scope_ptr_location_type_id_type_id(
66                scope_raw,
67                location,
68                left_discriminant_ty,
69                right_discriminant_ty,
70            );
71
72            let should_append_nil =
73                partition.should_append_nil_type || rhs_partition.should_append_nil_type;
74
75            unsafe {
76                (*dest).insert(*def, RefinementPartition::default());
77                let dest_partition = (*dest).get_mut(def).unwrap();
78                dest_partition.discriminant_types.push(union_ty);
79                dest_partition.should_append_nil_type |= should_append_nil;
80            }
81        }
82    }
83}