Skip to main content

luaur_analysis/methods/
constraint_generator_check_binary.rs

//! Source: `Analysis/src/ConstraintGenerator.cpp:3606-3731` (hand-ported)
//! C++ `std::tuple<TypeId, TypeId, RefinementId> ConstraintGenerator::checkBinary(...)`.
3use crate::enums::type_context::TypeContext;
4use crate::functions::follow_type::follow_type_id;
5use crate::functions::get_type_alt_j::get_type_id;
6use crate::functions::has_tag_type_alt_b::has_tag;
7use crate::functions::match_type_guard::match_type_guard;
8use crate::records::constraint_generator::ConstraintGenerator;
9use crate::records::extern_type::ExternType;
10use crate::records::in_conditional_context::InConditionalContext;
11use crate::records::union_type::UnionType;
12use crate::type_aliases::refinement_id_refinement::RefinementId;
13use crate::type_aliases::scope_ptr_constraint_generator::ScopePtr;
14use crate::type_aliases::type_id::TypeId;
15use luaur_ast::records::ast_expr::AstExpr;
16use luaur_ast::records::ast_expr_binary::{AstExprBinary, AstExprBinaryOp};
17use luaur_ast::records::ast_node::AstNode;
18
19impl ConstraintGenerator {
20    pub fn check_binary(
21        &mut self,
22        scope: &ScopePtr,
23        op: AstExprBinaryOp,
24        left: *mut AstExpr,
25        right: *mut AstExpr,
26        expected_type: Option<TypeId>,
27    ) -> (TypeId, TypeId, RefinementId) {
28        unsafe {
29            let _in_context = if op != AstExprBinary::And
30                && op != AstExprBinary::Or
31                && op != AstExprBinary::CompareEq
32                && op != AstExprBinary::CompareNe
33            {
34                Some(InConditionalContext::new(
35                    &mut self.type_context,
36                    TypeContext::Default,
37                ))
38            } else {
39                None
40            };
41
42            if op == AstExprBinary::And {
43                let mut relaxed_expected_lhs: Option<TypeId> = None;
44
45                if let Some(exp) = expected_type {
46                    relaxed_expected_lhs = Some((*self.arena).add_type(UnionType {
47                        options: alloc::vec![(*self.builtin_types).falsyType, exp],
48                    }));
49                }
50
51                let left_inf = self.check_scope_ptr_ast_expr_optional_type_id(
52                    scope,
53                    left,
54                    relaxed_expected_lhs,
55                );
56                let left_type = left_inf.ty;
57                let left_refinement = left_inf.refinement;
58
59                let right_scope = self.child_scope(right as *mut AstNode, scope);
60                self.apply_refinements(&right_scope, (*right).base.location, left_refinement);
61                let right_inf = self.check_scope_ptr_ast_expr_optional_type_id(
62                    &right_scope,
63                    right,
64                    expected_type,
65                );
66                let right_type = right_inf.ty;
67                let right_refinement = right_inf.refinement;
68
69                let conj = self
70                    .refinement_arena
71                    .conjunction_refinement_id_refinement_id(left_refinement, right_refinement);
72                return (left_type, right_type, conj);
73            } else if op == AstExprBinary::Or {
74                let mut relaxed_expected_lhs: Option<TypeId> = None;
75
76                if let Some(exp) = expected_type {
77                    relaxed_expected_lhs = Some((*self.arena).add_type(UnionType {
78                        options: alloc::vec![(*self.builtin_types).falsyType, exp],
79                    }));
80                }
81
82                let left_inf = self.check_scope_ptr_ast_expr_optional_type_id(
83                    scope,
84                    left,
85                    relaxed_expected_lhs,
86                );
87                let left_type = left_inf.ty;
88                let left_refinement = left_inf.refinement;
89
90                let right_scope = self.child_scope(right as *mut AstNode, scope);
91                let negated = self
92                    .refinement_arena
93                    .negation_refinement_id(left_refinement);
94                self.apply_refinements(&right_scope, (*right).base.location, negated);
95                let right_inf = self.check_scope_ptr_ast_expr_optional_type_id(
96                    &right_scope,
97                    right,
98                    expected_type,
99                );
100                let right_type = right_inf.ty;
101                let right_refinement = right_inf.refinement;
102
103                let disj = self
104                    .refinement_arena
105                    .disjunction_refinement_id_refinement_id(left_refinement, right_refinement);
106                return (left_type, right_type, disj);
107            } else if let Some(typeguard) = match_type_guard(op as i32, left, right) {
108                let left_type = self.check_scope_ptr_ast_expr(scope, left).ty;
109                let right_type = self.check_scope_ptr_ast_expr(scope, right).ty;
110
111                let key = (*self.dfg).get_refinement_key(typeguard.target as *const AstExpr);
112                if key.is_null() {
113                    return (left_type, right_type, core::ptr::null_mut());
114                }
115
116                let mut discriminant_ty: TypeId = (*self.builtin_types).neverType;
117                let guard_type = typeguard.r#type();
118                if guard_type == "nil" {
119                    discriminant_ty = (*self.builtin_types).nilType;
120                } else if guard_type == "string" {
121                    discriminant_ty = (*self.builtin_types).stringType;
122                } else if guard_type == "number" {
123                    discriminant_ty = (*self.builtin_types).numberType;
124                } else if guard_type == "integer" {
125                    discriminant_ty = (*self.builtin_types).integerType;
126                } else if guard_type == "boolean" {
127                    discriminant_ty = (*self.builtin_types).booleanType;
128                } else if guard_type == "thread" {
129                    discriminant_ty = (*self.builtin_types).threadType;
130                } else if guard_type == "buffer" {
131                    discriminant_ty = (*self.builtin_types).bufferType;
132                } else if guard_type == "table" {
133                    discriminant_ty = (*self.builtin_types).tableType;
134                } else if guard_type == "function" {
135                    discriminant_ty = (*self.builtin_types).functionType;
136                } else if guard_type == "userdata" {
137                    // For now, we don't really care about being accurate with userdata if the typeguard was using typeof.
138                    discriminant_ty = (*self.builtin_types).externType;
139                } else if guard_type == "vector" && !typeguard.isTypeof() {
140                    // `vector` is defined in EmbeddedBuiltinDefinitions, not as an actual built-in type
141                    let type_fun = self
142                        .global_scope
143                        .as_ref()
144                        .unwrap()
145                        .lookup_type(&alloc::string::String::from("vector"));
146                    if let Some(type_fun) = type_fun {
147                        discriminant_ty = follow_type_id(type_fun.r#type());
148                    }
149                } else if !typeguard.isTypeof() {
150                    discriminant_ty = (*self.builtin_types).neverType;
151                } else {
152                    let type_fun = self
153                        .global_scope
154                        .as_ref()
155                        .unwrap()
156                        .lookup_type(&alloc::string::String::from(guard_type));
157                    if let Some(type_fun) = type_fun {
158                        if type_fun.type_params().is_empty()
159                            && type_fun.type_pack_params().is_empty()
160                        {
161                            let ty = follow_type_id(type_fun.r#type());
162
163                            // We're only interested in the root type of any extern type.
164                            let etv = get_type_id::<ExternType>(ty);
165                            if !etv.is_null()
166                                && ((*etv).parent == Some((*self.builtin_types).externType)
167                                    || has_tag(ty, "typeofRoot"))
168                            {
169                                discriminant_ty = ty;
170                            }
171                        }
172                    }
173                }
174
175                let proposition = self
176                    .refinement_arena
177                    .proposition_refinement_key_type_id(key, discriminant_ty);
178                if op == AstExprBinary::CompareEq {
179                    return (left_type, right_type, proposition);
180                } else if op == AstExprBinary::CompareNe {
181                    let negated = self.refinement_arena.negation_refinement_id(proposition);
182                    return (left_type, right_type, negated);
183                } else {
184                    (*self.ice)
185                        .ice_string("matchTypeGuard should only return a Some under `==` or `~=`!");
186                    return (left_type, right_type, core::ptr::null_mut());
187                }
188            } else if op == AstExprBinary::CompareEq || op == AstExprBinary::CompareNe {
189                // We are checking a binary expression of the form a op b
190                // Just because a op b is expected to return a bool, doesn't mean a, b are expected to be bools too
191                let left_type = self
192                    .check_scope_ptr_ast_expr_optional_type_id_bool(scope, left, None, true)
193                    .ty;
194                let right_type = self
195                    .check_scope_ptr_ast_expr_optional_type_id_bool(scope, right, None, true)
196                    .ty;
197
198                let left_key = (*self.dfg).get_refinement_key(left as *const AstExpr);
199                let right_key = (*self.dfg).get_refinement_key(right as *const AstExpr);
200                let mut left_refinement = self
201                    .refinement_arena
202                    .proposition_refinement_key_type_id(left_key, right_type);
203                let mut right_refinement = self
204                    .refinement_arena
205                    .proposition_refinement_key_type_id(right_key, left_type);
206
207                if op == AstExprBinary::CompareNe {
208                    left_refinement = self
209                        .refinement_arena
210                        .negation_refinement_id(left_refinement);
211                    right_refinement = self
212                        .refinement_arena
213                        .negation_refinement_id(right_refinement);
214                }
215
216                let equiv = self
217                    .refinement_arena
218                    .equivalence_refinement_id_refinement_id(left_refinement, right_refinement);
219                return (left_type, right_type, equiv);
220            } else {
221                let left_type = self.check_scope_ptr_ast_expr(scope, left).ty;
222                let right_type = self.check_scope_ptr_ast_expr(scope, right).ty;
223                return (left_type, right_type, core::ptr::null_mut());
224            }
225        }
226    }
227}