Skip to main content

luaur_analysis/methods/unifier_try_unify_functions.rs

use crate::functions::finite::finite;
use crate::functions::get_mutable_type::get_mutable_type_id;
use crate::functions::has_unification_too_complex::has_unification_too_complex;
use crate::functions::size_type_pack::size;
use crate::records::count_mismatch::CountMismatchContext;
use crate::records::function_type::FunctionType;
use crate::records::instantiation::Instantiation;
use crate::records::type_error::TypeError;
use crate::records::type_mismatch::TypeMismatch;
use crate::records::unifier::Unifier;
use crate::type_aliases::type_error_data::TypeErrorData;
use crate::type_aliases::type_id::TypeId;
use alloc::format;
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::cmp::min;

18impl Unifier {
19    pub fn unifier_try_unify_functions(
20        &mut self,
21        sub_ty: TypeId,
22        super_ty: TypeId,
23        is_function_call: bool,
24    ) {
25        let mut super_function = unsafe { get_mutable_type_id::<FunctionType>(super_ty) };
26        let mut sub_function = unsafe { get_mutable_type_id::<FunctionType>(sub_ty) };
27
28        if super_function.is_null() || sub_function.is_null() {
29            self.ice_string("passed non-function types to unifyFunction");
30            return;
31        }
32
33        let mut num_generics = unsafe { (*super_function).generics.len() };
34        let mut num_generic_packs = unsafe { (*super_function).generic_packs.len() };
35
36        let should_instantiate = unsafe {
37            (num_generics == 0 && !(*sub_function).generics.is_empty())
38                || (num_generic_packs == 0 && !(*sub_function).generic_packs.is_empty())
39        };
40
41        if luaur_common::FFlag::LuauInstantiateInSubtyping.get() && should_instantiate {
42            let mut instantiation = Instantiation::instantiation_new(
43                &self.log as *const _,
44                self.types,
45                self.builtin_types,
46                unsafe { (*self.scope).level },
47                self.scope,
48            );
49
50            if let Some(instantiated) = instantiation.substitute_type_id(sub_ty) {
51                sub_function = unsafe { get_mutable_type_id::<FunctionType>(instantiated) };
52                if sub_function.is_null() {
53                    self.ice_string(
54                        "instantiation made a function type into a non-function type in unifyFunction",
55                    );
56                    return;
57                }
58
59                num_generics = min(unsafe { (*super_function).generics.len() }, unsafe {
60                    (*sub_function).generics.len()
61                });
62                num_generic_packs = min(unsafe { (*super_function).generic_packs.len() }, unsafe {
63                    (*sub_function).generic_packs.len()
64                });
65            } else {
66                self.report_error_location_type_error_data(
67                    self.location,
68                    TypeErrorData::UnificationTooComplex(
69                        crate::records::unification_too_complex::UnificationTooComplex::default(),
70                    ),
71                );
72            }
73        } else if num_generics != unsafe { (*sub_function).generics.len() } {
74            num_generics = min(num_generics, unsafe { (*sub_function).generics.len() });
75            self.report_function_type_mismatch(
76                super_ty,
77                sub_ty,
78                "different number of generic type parameters",
79                None,
80            );
81        }
82
83        if num_generic_packs != unsafe { (*sub_function).generic_packs.len() } {
84            num_generic_packs = min(num_generic_packs, unsafe {
85                (*sub_function).generic_packs.len()
86            });
87            self.report_function_type_mismatch(
88                super_ty,
89                sub_ty,
90                "different number of generic type pack parameters",
91                None,
92            );
93        }
94
95        for i in 0..num_generics {
96            unsafe {
97                let super_generics = &(*super_function).generics;
98                let sub_generics = &(*sub_function).generics;
99                self.log
100                    .push_seen_type_id_type_id(super_generics[i], sub_generics[i]);
101            }
102        }
103
104        for i in 0..num_generic_packs {
105            unsafe {
106                let super_generic_packs = &(*super_function).generic_packs;
107                let sub_generic_packs = &(*sub_function).generic_packs;
108                self.log.push_seen_type_pack_id_type_pack_id(
109                    super_generic_packs[i],
110                    sub_generic_packs[i],
111                );
112            }
113        }
114
115        let context = self.ctx;
116
117        if !is_function_call {
118            let mut inner_state = self.unifier_make_child_unifier();
119
120            inner_state.ctx = CountMismatchContext::Arg;
121            unsafe {
122                inner_state.try_unify_type_pack_id_type_pack_id_bool(
123                    (*super_function).arg_types,
124                    (*sub_function).arg_types,
125                    is_function_call,
126                );
127            }
128
129            let reported = !inner_state.errors.is_empty();
130
131            if let Some(e) = has_unification_too_complex(&inner_state.errors) {
132                self.report_error_type_error(e);
133            } else if !inner_state.errors.is_empty() && inner_state.first_pack_error_pos.is_some() {
134                let reason = format!(
135                    "Argument #{} type is not compatible.",
136                    inner_state.first_pack_error_pos.unwrap()
137                );
138                self.report_function_type_mismatch(
139                    super_ty,
140                    sub_ty,
141                    &reason,
142                    inner_state.errors.first().cloned(),
143                );
144            } else if !inner_state.errors.is_empty() {
145                self.report_function_type_mismatch(
146                    super_ty,
147                    sub_ty,
148                    "",
149                    inner_state.errors.first().cloned(),
150                );
151            }
152
153            inner_state.ctx = CountMismatchContext::FunctionResult;
154            unsafe {
155                inner_state.try_unify_type_pack_id_type_pack_id_bool(
156                    (*sub_function).ret_types,
157                    (*super_function).ret_types,
158                    false,
159                );
160            }
161
162            if !reported {
163                if let Some(e) = has_unification_too_complex(&inner_state.errors) {
164                    self.report_error_type_error(e);
165                } else if !inner_state.errors.is_empty()
166                    && unsafe { size((*super_function).ret_types, core::ptr::null_mut()) == 1 }
167                    && unsafe { finite((*super_function).ret_types, core::ptr::null_mut()) }
168                {
169                    self.report_function_type_mismatch(
170                        super_ty,
171                        sub_ty,
172                        "Return type is not compatible.",
173                        inner_state.errors.first().cloned(),
174                    );
175                } else if !inner_state.errors.is_empty()
176                    && inner_state.first_pack_error_pos.is_some()
177                {
178                    let reason = format!(
179                        "Return #{} type is not compatible.",
180                        inner_state.first_pack_error_pos.unwrap()
181                    );
182                    self.report_function_type_mismatch(
183                        super_ty,
184                        sub_ty,
185                        &reason,
186                        inner_state.errors.first().cloned(),
187                    );
188                } else if !inner_state.errors.is_empty() {
189                    self.report_function_type_mismatch(
190                        super_ty,
191                        sub_ty,
192                        "",
193                        inner_state.errors.first().cloned(),
194                    );
195                }
196            }
197
198            self.log.concat(inner_state.log);
199        } else {
200            self.ctx = CountMismatchContext::Arg;
201            unsafe {
202                self.try_unify_type_pack_id_type_pack_id_bool(
203                    (*super_function).arg_types,
204                    (*sub_function).arg_types,
205                    is_function_call,
206                );
207            }
208
209            self.ctx = CountMismatchContext::FunctionResult;
210            unsafe {
211                self.try_unify_type_pack_id_type_pack_id_bool(
212                    (*sub_function).ret_types,
213                    (*super_function).ret_types,
214                    false,
215                );
216            }
217        }
218
219        super_function = unsafe { get_mutable_type_id::<FunctionType>(super_ty) };
220        sub_function = unsafe { get_mutable_type_id::<FunctionType>(sub_ty) };
221
222        self.ctx = context;
223
224        for i in (0..num_generic_packs).rev() {
225            unsafe {
226                let super_generic_packs = &(*super_function).generic_packs;
227                let sub_generic_packs = &(*sub_function).generic_packs;
228                self.log.pop_seen_type_pack_id_type_pack_id(
229                    super_generic_packs[i],
230                    sub_generic_packs[i],
231                );
232            }
233        }
234
235        for i in (0..num_generics).rev() {
236            unsafe {
237                let super_generics = &(*super_function).generics;
238                let sub_generics = &(*sub_function).generics;
239                self.log
240                    .pop_seen_type_id_type_id(super_generics[i], sub_generics[i]);
241            }
242        }
243    }
244
245    fn report_function_type_mismatch(
246        &mut self,
247        super_ty: TypeId,
248        sub_ty: TypeId,
249        reason: &str,
250        err: Option<TypeError>,
251    ) {
252        let context = self.unifier_mismatch_context();
253        self.report_error_location_type_error_data(
254            self.location,
255            TypeErrorData::TypeMismatch(TypeMismatch {
256                wanted_type: super_ty,
257                given_type: sub_ty,
258                reason: String::from(reason),
259                error: err.map(Arc::new),
260                context,
261            }),
262        );
263    }
264}