luaur_analysis/methods/
constraint_generator_check_binary.rs1use crate::enums::type_context::TypeContext;
4use crate::functions::follow_type::follow_type_id;
5use crate::functions::get_type_alt_j::get_type_id;
6use crate::functions::has_tag_type_alt_b::has_tag;
7use crate::functions::match_type_guard::match_type_guard;
8use crate::records::constraint_generator::ConstraintGenerator;
9use crate::records::extern_type::ExternType;
10use crate::records::in_conditional_context::InConditionalContext;
11use crate::records::union_type::UnionType;
12use crate::type_aliases::refinement_id_refinement::RefinementId;
13use crate::type_aliases::scope_ptr_constraint_generator::ScopePtr;
14use crate::type_aliases::type_id::TypeId;
15use luaur_ast::records::ast_expr::AstExpr;
16use luaur_ast::records::ast_expr_binary::{AstExprBinary, AstExprBinaryOp};
17use luaur_ast::records::ast_node::AstNode;
18
19impl ConstraintGenerator {
20 pub fn check_binary(
21 &mut self,
22 scope: &ScopePtr,
23 op: AstExprBinaryOp,
24 left: *mut AstExpr,
25 right: *mut AstExpr,
26 expected_type: Option<TypeId>,
27 ) -> (TypeId, TypeId, RefinementId) {
28 unsafe {
29 let _in_context = if op != AstExprBinary::And
30 && op != AstExprBinary::Or
31 && op != AstExprBinary::CompareEq
32 && op != AstExprBinary::CompareNe
33 {
34 Some(InConditionalContext::new(
35 &mut self.type_context,
36 TypeContext::Default,
37 ))
38 } else {
39 None
40 };
41
42 if op == AstExprBinary::And {
43 let mut relaxed_expected_lhs: Option<TypeId> = None;
44
45 if let Some(exp) = expected_type {
46 relaxed_expected_lhs = Some((*self.arena).add_type(UnionType {
47 options: alloc::vec![(*self.builtin_types).falsyType, exp],
48 }));
49 }
50
51 let left_inf = self.check_scope_ptr_ast_expr_optional_type_id(
52 scope,
53 left,
54 relaxed_expected_lhs,
55 );
56 let left_type = left_inf.ty;
57 let left_refinement = left_inf.refinement;
58
59 let right_scope = self.child_scope(right as *mut AstNode, scope);
60 self.apply_refinements(&right_scope, (*right).base.location, left_refinement);
61 let right_inf = self.check_scope_ptr_ast_expr_optional_type_id(
62 &right_scope,
63 right,
64 expected_type,
65 );
66 let right_type = right_inf.ty;
67 let right_refinement = right_inf.refinement;
68
69 let conj = self
70 .refinement_arena
71 .conjunction_refinement_id_refinement_id(left_refinement, right_refinement);
72 return (left_type, right_type, conj);
73 } else if op == AstExprBinary::Or {
74 let mut relaxed_expected_lhs: Option<TypeId> = None;
75
76 if let Some(exp) = expected_type {
77 relaxed_expected_lhs = Some((*self.arena).add_type(UnionType {
78 options: alloc::vec![(*self.builtin_types).falsyType, exp],
79 }));
80 }
81
82 let left_inf = self.check_scope_ptr_ast_expr_optional_type_id(
83 scope,
84 left,
85 relaxed_expected_lhs,
86 );
87 let left_type = left_inf.ty;
88 let left_refinement = left_inf.refinement;
89
90 let right_scope = self.child_scope(right as *mut AstNode, scope);
91 let negated = self
92 .refinement_arena
93 .negation_refinement_id(left_refinement);
94 self.apply_refinements(&right_scope, (*right).base.location, negated);
95 let right_inf = self.check_scope_ptr_ast_expr_optional_type_id(
96 &right_scope,
97 right,
98 expected_type,
99 );
100 let right_type = right_inf.ty;
101 let right_refinement = right_inf.refinement;
102
103 let disj = self
104 .refinement_arena
105 .disjunction_refinement_id_refinement_id(left_refinement, right_refinement);
106 return (left_type, right_type, disj);
107 } else if let Some(typeguard) = match_type_guard(op as i32, left, right) {
108 let left_type = self.check_scope_ptr_ast_expr(scope, left).ty;
109 let right_type = self.check_scope_ptr_ast_expr(scope, right).ty;
110
111 let key = (*self.dfg).get_refinement_key(typeguard.target as *const AstExpr);
112 if key.is_null() {
113 return (left_type, right_type, core::ptr::null_mut());
114 }
115
116 let mut discriminant_ty: TypeId = (*self.builtin_types).neverType;
117 let guard_type = typeguard.r#type();
118 if guard_type == "nil" {
119 discriminant_ty = (*self.builtin_types).nilType;
120 } else if guard_type == "string" {
121 discriminant_ty = (*self.builtin_types).stringType;
122 } else if guard_type == "number" {
123 discriminant_ty = (*self.builtin_types).numberType;
124 } else if guard_type == "integer" {
125 discriminant_ty = (*self.builtin_types).integerType;
126 } else if guard_type == "boolean" {
127 discriminant_ty = (*self.builtin_types).booleanType;
128 } else if guard_type == "thread" {
129 discriminant_ty = (*self.builtin_types).threadType;
130 } else if guard_type == "buffer" {
131 discriminant_ty = (*self.builtin_types).bufferType;
132 } else if guard_type == "table" {
133 discriminant_ty = (*self.builtin_types).tableType;
134 } else if guard_type == "function" {
135 discriminant_ty = (*self.builtin_types).functionType;
136 } else if guard_type == "userdata" {
137 discriminant_ty = (*self.builtin_types).externType;
139 } else if guard_type == "vector" && !typeguard.isTypeof() {
140 let type_fun = self
142 .global_scope
143 .as_ref()
144 .unwrap()
145 .lookup_type(&alloc::string::String::from("vector"));
146 if let Some(type_fun) = type_fun {
147 discriminant_ty = follow_type_id(type_fun.r#type());
148 }
149 } else if !typeguard.isTypeof() {
150 discriminant_ty = (*self.builtin_types).neverType;
151 } else {
152 let type_fun = self
153 .global_scope
154 .as_ref()
155 .unwrap()
156 .lookup_type(&alloc::string::String::from(guard_type));
157 if let Some(type_fun) = type_fun {
158 if type_fun.type_params().is_empty()
159 && type_fun.type_pack_params().is_empty()
160 {
161 let ty = follow_type_id(type_fun.r#type());
162
163 let etv = get_type_id::<ExternType>(ty);
165 if !etv.is_null()
166 && ((*etv).parent == Some((*self.builtin_types).externType)
167 || has_tag(ty, "typeofRoot"))
168 {
169 discriminant_ty = ty;
170 }
171 }
172 }
173 }
174
175 let proposition = self
176 .refinement_arena
177 .proposition_refinement_key_type_id(key, discriminant_ty);
178 if op == AstExprBinary::CompareEq {
179 return (left_type, right_type, proposition);
180 } else if op == AstExprBinary::CompareNe {
181 let negated = self.refinement_arena.negation_refinement_id(proposition);
182 return (left_type, right_type, negated);
183 } else {
184 (*self.ice)
185 .ice_string("matchTypeGuard should only return a Some under `==` or `~=`!");
186 return (left_type, right_type, core::ptr::null_mut());
187 }
188 } else if op == AstExprBinary::CompareEq || op == AstExprBinary::CompareNe {
189 let left_type = self
192 .check_scope_ptr_ast_expr_optional_type_id_bool(scope, left, None, true)
193 .ty;
194 let right_type = self
195 .check_scope_ptr_ast_expr_optional_type_id_bool(scope, right, None, true)
196 .ty;
197
198 let left_key = (*self.dfg).get_refinement_key(left as *const AstExpr);
199 let right_key = (*self.dfg).get_refinement_key(right as *const AstExpr);
200 let mut left_refinement = self
201 .refinement_arena
202 .proposition_refinement_key_type_id(left_key, right_type);
203 let mut right_refinement = self
204 .refinement_arena
205 .proposition_refinement_key_type_id(right_key, left_type);
206
207 if op == AstExprBinary::CompareNe {
208 left_refinement = self
209 .refinement_arena
210 .negation_refinement_id(left_refinement);
211 right_refinement = self
212 .refinement_arena
213 .negation_refinement_id(right_refinement);
214 }
215
216 let equiv = self
217 .refinement_arena
218 .equivalence_refinement_id_refinement_id(left_refinement, right_refinement);
219 return (left_type, right_type, equiv);
220 } else {
221 let left_type = self.check_scope_ptr_ast_expr(scope, left).ty;
222 let right_type = self.check_scope_ptr_ast_expr(scope, right).ty;
223 return (left_type, right_type, core::ptr::null_mut());
224 }
225 }
226 }
227}