Skip to main content

cairo_lang_semantic/expr/inference/conform.rs

1use std::hash::Hash;
2
3use cairo_lang_defs::ids::{LanguageElementId, TraitConstantId, TraitTypeId};
4use cairo_lang_diagnostics::Maybe;
5use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
6use cairo_lang_utils::Intern;
7use cairo_lang_utils::ordered_hash_map::{Entry, OrderedHashMap};
8use itertools::zip_eq;
9
10use super::canonic::{NoError, ResultNoErrEx};
11use super::{
12    ErrorSet, ImplVarId, ImplVarTraitItemMappings, Inference, InferenceError, InferenceResult,
13    InferenceVar, LocalTypeVarId, TypeVar,
14};
15use crate::corelib::never_ty;
16use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
17use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
18use crate::items::functions::{GenericFunctionId, ImplGenericFunctionId};
19use crate::items::imp::{
20    ImplId, ImplImplId, ImplLongId, ImplLookupContext, ImplSemantic, NegativeImplId,
21    NegativeImplLongId,
22};
23use crate::items::trt::{ConcreteTraitImplId, TraitSemantic};
24use crate::substitution::SemanticRewriter;
25use crate::types::{ClosureTypeLongId, ImplTypeId, peel_snapshots};
26use crate::{
27    ConcreteFunction, ConcreteImplLongId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId,
28    FunctionId, FunctionLongId, GenericArgumentId, TypeId, TypeLongId,
29};
30
31/// Functions for conforming semantic objects with each other.
32pub trait InferenceConform<'db> {
33    fn conform_ty(&mut self, ty0: TypeId<'db>, ty1: TypeId<'db>) -> InferenceResult<TypeId<'db>>;
34    fn conform_ty_ex(
35        &mut self,
36        ty0: TypeId<'db>,
37        ty1: TypeId<'db>,
38        ty0_is_self: bool,
39    ) -> InferenceResult<(TypeId<'db>, usize)>;
40    fn conform_const(
41        &mut self,
42        ty0: ConstValueId<'db>,
43        ty1: ConstValueId<'db>,
44    ) -> InferenceResult<ConstValueId<'db>>;
45    fn maybe_peel_snapshots(
46        &mut self,
47        ty0_is_self: bool,
48        ty1: TypeId<'db>,
49    ) -> (usize, TypeLongId<'db>);
50    fn conform_generic_args(
51        &mut self,
52        gargs0: &[GenericArgumentId<'db>],
53        gargs1: &[GenericArgumentId<'db>],
54    ) -> InferenceResult<Vec<GenericArgumentId<'db>>>;
55    fn conform_generic_arg(
56        &mut self,
57        garg0: GenericArgumentId<'db>,
58        garg1: GenericArgumentId<'db>,
59    ) -> InferenceResult<GenericArgumentId<'db>>;
60    fn conform_impl(
61        &mut self,
62        impl0: ImplId<'db>,
63        impl1: ImplId<'db>,
64    ) -> InferenceResult<ImplId<'db>>;
65    fn conform_neg_impl(
66        &mut self,
67        neg_impl0: NegativeImplId<'db>,
68        neg_impl1: NegativeImplId<'db>,
69    ) -> InferenceResult<NegativeImplId<'db>>;
70    fn conform_traits(
71        &mut self,
72        trt0: ConcreteTraitId<'db>,
73        trt1: ConcreteTraitId<'db>,
74    ) -> InferenceResult<ConcreteTraitId<'db>>;
75    fn conform_generic_function(
76        &mut self,
77        trt0: GenericFunctionId<'db>,
78        trt1: GenericFunctionId<'db>,
79    ) -> InferenceResult<GenericFunctionId<'db>>;
80    fn ty_contains_var(&mut self, ty: TypeId<'db>, var: InferenceVar) -> bool;
81    fn generic_args_contain_var(
82        &mut self,
83        generic_args: &[GenericArgumentId<'db>],
84        var: InferenceVar,
85    ) -> bool;
86    fn impl_contains_var(&mut self, impl_id: ImplId<'db>, var: InferenceVar) -> bool;
87    fn negative_impl_contains_var(
88        &mut self,
89        neg_impl_id: NegativeImplId<'db>,
90        var: InferenceVar,
91    ) -> bool;
92    fn function_contains_var(&mut self, function_id: FunctionId<'db>, var: InferenceVar) -> bool;
93}
94
95impl<'db> InferenceConform<'db> for Inference<'db, '_> {
96    /// Conforms ty0 to ty1. Should be called when ty0 should be coerced to ty1. Not symmetric.
97    /// Returns the reduced type for ty0, or an error if the type is no coercible.
98    fn conform_ty(&mut self, ty0: TypeId<'db>, ty1: TypeId<'db>) -> InferenceResult<TypeId<'db>> {
99        Ok(self.conform_ty_ex(ty0, ty1, false)?.0)
100    }
101
102    /// Same as conform_ty but supports adding snapshots to ty0 if `ty0_is_self` is true.
103    /// Returns the reduced type for ty0 and the number of snapshots that needs to be added
104    /// for the types to conform.
105    fn conform_ty_ex(
106        &mut self,
107        ty0: TypeId<'db>,
108        ty1: TypeId<'db>,
109        ty0_is_self: bool,
110    ) -> InferenceResult<(TypeId<'db>, usize)> {
111        let ty0 = self.rewrite(ty0).no_err();
112        let ty1 = self.rewrite(ty1).no_err();
113        if ty0 == never_ty(self.db) || ty0.is_missing(self.db) {
114            return Ok((ty1, 0));
115        }
116        if ty0 == ty1 {
117            return Ok((ty0, 0));
118        }
119        let long_ty1 = ty1.long(self.db);
120        match long_ty1 {
121            TypeLongId::Var(var) => return Ok((self.assign_ty(*var, ty0)?, 0)),
122            TypeLongId::Missing(_) => return Ok((ty1, 0)),
123            TypeLongId::Snapshot(inner_ty) if ty0_is_self => {
124                if *inner_ty == ty0 {
125                    return Ok((ty1, 1));
126                }
127                if !matches!(ty0.long(self.db), TypeLongId::Snapshot(_))
128                    && let TypeLongId::Var(var) = inner_ty.long(self.db)
129                {
130                    return Ok((self.assign_ty(*var, ty0)?, 1));
131                }
132            }
133            TypeLongId::ImplType(impl_type) => {
134                if let Some(ty) = self.impl_type_bounds.get(&(*impl_type).into()) {
135                    return self.conform_ty_ex(ty0, *ty, ty0_is_self);
136                }
137            }
138            _ => {}
139        }
140        let long_ty0 = ty0.long(self.db);
141
142        match long_ty0 {
143            TypeLongId::Concrete(concrete0) => {
144                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
145                let TypeLongId::Concrete(concrete1) = long_ty1 else {
146                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
147                };
148                if concrete0.generic_type(self.db) != concrete1.generic_type(self.db) {
149                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
150                }
151                let gargs0 = concrete0.generic_args(self.db);
152                let gargs1 = concrete1.generic_args(self.db);
153                let gargs = self.conform_generic_args(&gargs0, &gargs1)?;
154                let long_ty = TypeLongId::Concrete(ConcreteTypeId::new(
155                    self.db,
156                    concrete0.generic_type(self.db),
157                    gargs,
158                ));
159                Ok((long_ty.intern(self.db), n_snapshots))
160            }
161            TypeLongId::Tuple(tys0) => {
162                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
163                let TypeLongId::Tuple(tys1) = long_ty1 else {
164                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
165                };
166                if tys0.len() != tys1.len() {
167                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
168                }
169                let tys = zip_eq(tys0, tys1)
170                    .map(|(subty0, subty1)| self.conform_ty(*subty0, subty1))
171                    .collect::<Result<Vec<_>, _>>()?;
172                Ok((TypeLongId::Tuple(tys).intern(self.db), n_snapshots))
173            }
174            TypeLongId::Closure(closure0) => {
175                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
176                let TypeLongId::Closure(closure1) = long_ty1 else {
177                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
178                };
179                if closure0.params_location != closure1.params_location {
180                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
181                }
182                let param_tys = zip_eq(closure0.param_tys.clone(), closure1.param_tys)
183                    .map(|(subty0, subty1)| self.conform_ty(subty0, subty1))
184                    .collect::<Result<Vec<_>, _>>()?;
185                let captured_types =
186                    zip_eq(closure0.captured_types.clone(), closure1.captured_types)
187                        .map(|(subty0, subty1)| self.conform_ty(subty0, subty1))
188                        .collect::<Result<Vec<_>, _>>()?;
189                let ret_ty = self.conform_ty(closure0.ret_ty, closure1.ret_ty)?;
190                Ok((
191                    TypeLongId::Closure(ClosureTypeLongId {
192                        param_tys,
193                        ret_ty,
194                        captured_types,
195                        params_location: closure0.params_location,
196                        parent_function: closure0.parent_function,
197                    })
198                    .intern(self.db),
199                    n_snapshots,
200                ))
201            }
202            TypeLongId::FixedSizeArray { type_id, size } => {
203                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
204                let TypeLongId::FixedSizeArray { type_id: type_id1, size: size1 } = long_ty1 else {
205                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
206                };
207                let size = self.conform_const(*size, size1)?;
208                let ty = self.conform_ty(*type_id, type_id1)?;
209                Ok((TypeLongId::FixedSizeArray { type_id: ty, size }.intern(self.db), n_snapshots))
210            }
211            TypeLongId::Snapshot(inner_ty0) => {
212                let TypeLongId::Snapshot(inner_ty1) = long_ty1 else {
213                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
214                };
215                let (ty, n_snapshots) = self.conform_ty_ex(*inner_ty0, *inner_ty1, ty0_is_self)?;
216                Ok((TypeLongId::Snapshot(ty).intern(self.db), n_snapshots))
217            }
218            TypeLongId::GenericParameter(_) => {
219                Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }))
220            }
221            TypeLongId::Var(var) => Ok((self.assign_ty(*var, ty1)?, 0)),
222            TypeLongId::ImplType(impl_type) => {
223                if let Some(ty) = self.impl_type_bounds.get(&(*impl_type).into()) {
224                    return self.conform_ty_ex(*ty, ty1, ty0_is_self);
225                }
226                Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }))
227            }
228            TypeLongId::Missing(_) => Ok((ty0, 0)),
229            TypeLongId::Coupon(function_id0) => {
230                let TypeLongId::Coupon(function_id1) = long_ty1 else {
231                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
232                };
233
234                let func0 = function_id0.long(self.db).function.clone();
235                let func1 = function_id1.long(self.db).function.clone();
236
237                let generic_function =
238                    self.conform_generic_function(func0.generic_function, func1.generic_function)?;
239
240                if func0.generic_args.len() != func1.generic_args.len() {
241                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
242                }
243
244                let generic_args =
245                    self.conform_generic_args(&func0.generic_args, &func1.generic_args)?;
246
247                Ok((
248                    TypeLongId::Coupon(
249                        FunctionLongId {
250                            function: ConcreteFunction { generic_function, generic_args },
251                        }
252                        .intern(self.db),
253                    )
254                    .intern(self.db),
255                    0,
256                ))
257            }
258        }
259    }
260
261    /// Conforms id0 to id1. Should be called when id0 should be coerced to id1. Not symmetric.
262    /// Returns the reduced const for id0, or an error if the const is no coercible.
263    fn conform_const(
264        &mut self,
265        id0: ConstValueId<'db>,
266        id1: ConstValueId<'db>,
267    ) -> InferenceResult<ConstValueId<'db>> {
268        let id0 = self.rewrite(id0).no_err();
269        let id1 = self.rewrite(id1).no_err();
270        self.conform_ty(id0.ty(self.db).unwrap(), id1.ty(self.db).unwrap())?;
271        if id0 == id1 {
272            return Ok(id0);
273        }
274        let const_value0 = id0.long(self.db);
275        if matches!(const_value0, ConstValue::Missing(_)) {
276            return Ok(id1);
277        }
278        match id1.long(self.db) {
279            ConstValue::Missing(_) => return Ok(id1),
280            ConstValue::Var(var, _) => return self.assign_const(*var, id0),
281            _ => {}
282        }
283        match const_value0 {
284            ConstValue::Var(var, _) => Ok(self.assign_const(*var, id1)?),
285            ConstValue::ImplConstant(_) => {
286                Err(self.set_error(InferenceError::ConstKindMismatch { const0: id0, const1: id1 }))
287            }
288            _ => {
289                Err(self.set_error(InferenceError::ConstKindMismatch { const0: id0, const1: id1 }))
290            }
291        }
292    }
293
294    // Conditionally peels snapshots.
295    fn maybe_peel_snapshots(
296        &mut self,
297        ty0_is_self: bool,
298        ty1: TypeId<'db>,
299    ) -> (usize, TypeLongId<'db>) {
300        let (n_snapshots, long_ty1) =
301            if ty0_is_self { peel_snapshots(self.db, ty1) } else { (0, ty1.long(self.db).clone()) };
302        (n_snapshots, long_ty1)
303    }
304
305    /// Conforms generic args. See `conform_ty()`.
306    fn conform_generic_args(
307        &mut self,
308        gargs0: &[GenericArgumentId<'db>],
309        gargs1: &[GenericArgumentId<'db>],
310    ) -> InferenceResult<Vec<GenericArgumentId<'db>>> {
311        zip_eq(gargs0, gargs1)
312            .map(|(garg0, garg1)| self.conform_generic_arg(*garg0, *garg1))
313            .collect::<Result<Vec<_>, _>>()
314    }
315
316    /// Conforms a generic arg. See `conform_ty()`.
317    fn conform_generic_arg(
318        &mut self,
319        garg0: GenericArgumentId<'db>,
320        garg1: GenericArgumentId<'db>,
321    ) -> InferenceResult<GenericArgumentId<'db>> {
322        if garg0 == garg1 {
323            return Ok(garg0);
324        }
325        match garg0 {
326            GenericArgumentId::Type(gty0) => {
327                let GenericArgumentId::Type(gty1) = garg1 else {
328                    return Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }));
329                };
330                Ok(GenericArgumentId::Type(self.conform_ty(gty0, gty1)?))
331            }
332            GenericArgumentId::Constant(gc0) => {
333                let GenericArgumentId::Constant(gc1) = garg1 else {
334                    return Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }));
335                };
336
337                Ok(GenericArgumentId::Constant(self.conform_const(gc0, gc1)?))
338            }
339            GenericArgumentId::Impl(impl0) => {
340                let GenericArgumentId::Impl(impl1) = garg1 else {
341                    return Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }));
342                };
343                Ok(GenericArgumentId::Impl(self.conform_impl(impl0, impl1)?))
344            }
345            GenericArgumentId::NegImpl(neg_impl0) => {
346                let GenericArgumentId::NegImpl(neg_impl1) = garg1 else {
347                    return Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }));
348                };
349                Ok(GenericArgumentId::NegImpl(self.conform_neg_impl(neg_impl0, neg_impl1)?))
350            }
351        }
352    }
353
354    /// Conforms an impl. See `conform_ty()`.
355    fn conform_impl(
356        &mut self,
357        impl0: ImplId<'db>,
358        impl1: ImplId<'db>,
359    ) -> InferenceResult<ImplId<'db>> {
360        let impl0 = self.rewrite(impl0).no_err();
361        let impl1 = self.rewrite(impl1).no_err();
362        let long_impl1 = impl1.long(self.db);
363        if impl0 == impl1 {
364            return Ok(impl0);
365        }
366        if let ImplLongId::ImplVar(var) = long_impl1 {
367            let impl_concrete_trait = self
368                .db
369                .impl_concrete_trait(impl0)
370                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
371            self.conform_traits(var.long(self.db).concrete_trait_id, impl_concrete_trait)?;
372            let impl_id = self.rewrite(impl0).no_err();
373            return self.assign_impl(*var, impl_id);
374        }
375        match impl0.long(self.db) {
376            ImplLongId::ImplVar(var) => {
377                let impl_concrete_trait = self
378                    .db
379                    .impl_concrete_trait(impl1)
380                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
381                self.conform_traits(var.long(self.db).concrete_trait_id, impl_concrete_trait)?;
382                let impl_id = self.rewrite(impl1).no_err();
383                self.assign_impl(*var, impl_id)
384            }
385            ImplLongId::Concrete(concrete0) => {
386                let ImplLongId::Concrete(concrete1) = long_impl1 else {
387                    return Err(self.set_error(InferenceError::ImplKindMismatch { impl0, impl1 }));
388                };
389                let concrete0 = concrete0.long(self.db);
390                let concrete1 = concrete1.long(self.db);
391                if concrete0.impl_def_id != concrete1.impl_def_id {
392                    return Err(self.set_error(InferenceError::ImplKindMismatch { impl0, impl1 }));
393                }
394                let gargs0 = concrete0.generic_args.clone();
395                let gargs1 = concrete1.generic_args.clone();
396                let generic_args = self.conform_generic_args(&gargs0, &gargs1)?;
397                Ok(ImplLongId::Concrete(
398                    ConcreteImplLongId { impl_def_id: concrete0.impl_def_id, generic_args }
399                        .intern(self.db),
400                )
401                .intern(self.db))
402            }
403            ImplLongId::GenericParameter(_)
404            | ImplLongId::ImplImpl(_)
405            | ImplLongId::SelfImpl(_)
406            | ImplLongId::GeneratedImpl(_) => {
407                Err(self.set_error(InferenceError::ImplKindMismatch { impl0, impl1 }))
408            }
409        }
410    }
411
412    fn conform_neg_impl(
413        &mut self,
414        neg_impl0: NegativeImplId<'db>,
415        neg_impl1: NegativeImplId<'db>,
416    ) -> InferenceResult<NegativeImplId<'db>> {
417        let neg_impl0 = self.rewrite(neg_impl0).no_err();
418        let neg_impl1 = self.rewrite(neg_impl1).no_err();
419        let long_neg_impl1 = neg_impl1.long(self.db);
420        if neg_impl0 == neg_impl1 {
421            return Ok(neg_impl0);
422        }
423        if let NegativeImplLongId::NegativeImplVar(var) = long_neg_impl1 {
424            let impl_concrete_trait = neg_impl0
425                .concrete_trait(self.db)
426                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
427            self.conform_traits(var.long(self.db).concrete_trait_id, impl_concrete_trait)?;
428            let neg_impl_id = self.rewrite(neg_impl0).no_err();
429            return self.assign_neg_impl(*var, neg_impl_id);
430        }
431        if let NegativeImplLongId::NegativeImplVar(var) = neg_impl0.long(self.db) {
432            let impl_concrete_trait = neg_impl1
433                .concrete_trait(self.db)
434                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
435            self.conform_traits(var.long(self.db).concrete_trait_id, impl_concrete_trait)?;
436            let neg_impl_id = self.rewrite(neg_impl1).no_err();
437            return self.assign_neg_impl(*var, neg_impl_id);
438        }
439        // we do not care about multiple negative impls, so we do not check they are of the same
440        // variant, we just conform the traits.
441        let trt0 = neg_impl0
442            .concrete_trait(self.db)
443            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
444        let trt1 = neg_impl1
445            .concrete_trait(self.db)
446            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
447        self.conform_traits(trt0, trt1)?;
448        Ok(self.rewrite(neg_impl1).no_err())
449    }
450
451    /// Conforms generic traits. See `conform_ty()`.
452    fn conform_traits(
453        &mut self,
454        trt0: ConcreteTraitId<'db>,
455        trt1: ConcreteTraitId<'db>,
456    ) -> InferenceResult<ConcreteTraitId<'db>> {
457        let trt0 = trt0.long(self.db);
458        let trt1 = trt1.long(self.db);
459        if trt0.trait_id != trt1.trait_id {
460            return Err(self.set_error(InferenceError::TraitMismatch {
461                trt0: trt0.trait_id,
462                trt1: trt1.trait_id,
463            }));
464        }
465        let generic_args = self.conform_generic_args(&trt0.generic_args, &trt1.generic_args)?;
466        Ok(ConcreteTraitLongId { trait_id: trt0.trait_id, generic_args }.intern(self.db))
467    }
468
469    fn conform_generic_function(
470        &mut self,
471        func0: GenericFunctionId<'db>,
472        func1: GenericFunctionId<'db>,
473    ) -> InferenceResult<GenericFunctionId<'db>> {
474        if let (GenericFunctionId::Impl(id0), GenericFunctionId::Impl(id1)) = (func0, func1) {
475            if id0.function != id1.function {
476                return Err(
477                    self.set_error(InferenceError::GenericFunctionMismatch { func0, func1 })
478                );
479            }
480            let function = id0.function;
481            let impl_id = self.conform_impl(id0.impl_id, id1.impl_id)?;
482            return Ok(GenericFunctionId::Impl(ImplGenericFunctionId { impl_id, function }));
483        }
484
485        if func0 != func1 {
486            return Err(self.set_error(InferenceError::GenericFunctionMismatch { func0, func1 }));
487        }
488        Ok(func0)
489    }
490
491    /// Checks if a type tree contains a certain [InferenceVar] somewhere. Used to avoid inference
492    /// cycles.
493    fn ty_contains_var(&mut self, ty: TypeId<'db>, var: InferenceVar) -> bool {
494        let ty = self.rewrite(ty).no_err();
495        self.internal_ty_contains_var(ty, var)
496    }
497
498    /// Checks if a slice of generic arguments contain a certain [InferenceVar] somewhere. Used to
499    /// avoid inference cycles.
500    fn generic_args_contain_var(
501        &mut self,
502        generic_args: &[GenericArgumentId<'db>],
503        var: InferenceVar,
504    ) -> bool {
505        for garg in generic_args {
506            if match *garg {
507                GenericArgumentId::Type(ty) => self.internal_ty_contains_var(ty, var),
508                GenericArgumentId::Constant(_) => false,
509                GenericArgumentId::Impl(impl_id) => self.impl_contains_var(impl_id, var),
510                GenericArgumentId::NegImpl(neg_impl_id) => {
511                    self.negative_impl_contains_var(neg_impl_id, var)
512                }
513            } {
514                return true;
515            }
516        }
517        false
518    }
519
520    /// Checks if an impl contains a certain [InferenceVar] somewhere. Used to avoid inference
521    /// cycles.
522    fn impl_contains_var(&mut self, impl_id: ImplId<'db>, var: InferenceVar) -> bool {
523        match impl_id.long(self.db) {
524            ImplLongId::Concrete(concrete_impl_id) => {
525                self.generic_args_contain_var(&concrete_impl_id.long(self.db).generic_args, var)
526            }
527            ImplLongId::SelfImpl(concrete_trait_id) => {
528                self.generic_args_contain_var(concrete_trait_id.generic_args(self.db), var)
529            }
530            ImplLongId::GenericParameter(_) => false,
531            ImplLongId::ImplVar(new_var) => {
532                let new_var_long_id = new_var.long(self.db);
533                let new_var_local_id = new_var_long_id.id;
534                if InferenceVar::Impl(new_var_local_id) == var {
535                    return true;
536                }
537                if let Some(impl_id) = self.impl_assignment(new_var_local_id) {
538                    return self.impl_contains_var(impl_id, var);
539                }
540                self.generic_args_contain_var(
541                    new_var_long_id.concrete_trait_id.generic_args(self.db),
542                    var,
543                )
544            }
545            ImplLongId::ImplImpl(impl_impl) => self.impl_contains_var(impl_impl.impl_id(), var),
546            ImplLongId::GeneratedImpl(generated_impl) => self.generic_args_contain_var(
547                generated_impl.concrete_trait(self.db).generic_args(self.db),
548                var,
549            ),
550        }
551    }
552
553    /// Checks if a negative impl contains a certain [InferenceVar] somewhere. Used to avoid
554    /// inference cycles.
555    fn negative_impl_contains_var(
556        &mut self,
557        neg_impl_id: NegativeImplId<'db>,
558        var: InferenceVar,
559    ) -> bool {
560        match neg_impl_id.long(self.db) {
561            NegativeImplLongId::Solved(concrete_trait_id) => {
562                self.generic_args_contain_var(&concrete_trait_id.long(self.db).generic_args, var)
563            }
564            NegativeImplLongId::GenericParameter(_) => false,
565            NegativeImplLongId::NegativeImplVar(new_var) => {
566                let new_var_long_id = new_var.long(self.db);
567                let new_var_local_id = new_var_long_id.id;
568                if InferenceVar::NegativeImpl(new_var_local_id) == var {
569                    return true;
570                }
571                if let Some(neg_impl_id) = self.negative_impl_assignment(new_var_local_id) {
572                    return self.negative_impl_contains_var(neg_impl_id, var);
573                }
574                self.generic_args_contain_var(
575                    new_var_long_id.concrete_trait_id.generic_args(self.db),
576                    var,
577                )
578            }
579        }
580    }
581
582    /// Checks if a function contains a certain [InferenceVar] in its generic arguments or in the
583    /// generic arguments of the impl containing the function (in case the function is an impl
584    /// function).
585    ///
586    /// Used to avoid inference cycles.
587    fn function_contains_var(&mut self, function_id: FunctionId<'db>, var: InferenceVar) -> bool {
588        let function = function_id.get_concrete(self.db);
589        let generic_args = function.generic_args;
590        // Look in the generic arguments of the function and in the impl generic arguments.
591        self.generic_args_contain_var(&generic_args, var)
592            || matches!(function.generic_function,
593                GenericFunctionId::Impl(impl_generic_function_id)
594                if self.impl_contains_var(impl_generic_function_id.impl_id, var)
595            )
596    }
597}
598
599impl<'db> Inference<'db, '_> {
600    /// Reduces an impl type to a concrete type.
601    pub fn reduce_impl_ty(
602        &mut self,
603        impl_type_id: ImplTypeId<'db>,
604    ) -> InferenceResult<TypeId<'db>> {
605        let impl_id = impl_type_id.impl_id();
606        let trait_ty = impl_type_id.ty();
607        if let ImplLongId::ImplVar(var) = impl_id.long(self.db) {
608            Ok(self.rewritten_impl_type(*var, trait_ty))
609        } else if let Ok(ty) =
610            self.db.impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_ty, self.db))
611        {
612            Ok(ty)
613        } else {
614            Err(self.set_impl_reduction_error(impl_id))
615        }
616    }
617
618    /// Reduces an impl constant to a concrete const.
619    pub fn reduce_impl_constant(
620        &mut self,
621        impl_const_id: ImplConstantId<'db>,
622    ) -> InferenceResult<ConstValueId<'db>> {
623        let impl_id = impl_const_id.impl_id();
624        let trait_constant = impl_const_id.trait_constant_id();
625        if let ImplLongId::ImplVar(var) = impl_id.long(self.db) {
626            Ok(self.rewritten_impl_constant(*var, trait_constant))
627        } else if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
628            ImplConstantId::new(impl_id, trait_constant, self.db),
629        ) {
630            Ok(constant)
631        } else {
632            Err(self.set_impl_reduction_error(impl_id))
633        }
634    }
635
636    /// Reduces an impl impl to a concrete impl.
637    pub fn reduce_impl_impl(
638        &mut self,
639        impl_impl_id: ImplImplId<'db>,
640    ) -> InferenceResult<ImplId<'db>> {
641        let impl_id = impl_impl_id.impl_id();
642        let concrete_trait_impl = impl_impl_id
643            .concrete_trait_impl_id(self.db)
644            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
645
646        if let ImplLongId::ImplVar(var) = impl_id.long(self.db) {
647            Ok(self.rewritten_impl_impl(*var, concrete_trait_impl))
648        } else if let Ok(imp) = self.db.impl_impl_concrete_implized(ImplImplId::new(
649            impl_id,
650            impl_impl_id.trait_impl_id(),
651            self.db,
652        )) {
653            Ok(imp)
654        } else {
655            Err(self.set_impl_reduction_error(impl_id))
656        }
657    }
658
659    /// Returns the type of an impl var's type item.
660    /// The type may be a variable itself, but it may previously exist, so may be more specific due
661    /// to rewriting.
662    pub fn rewritten_impl_type(
663        &mut self,
664        id: ImplVarId<'db>,
665        trait_type_id: TraitTypeId<'db>,
666    ) -> TypeId<'db> {
667        self.rewritten_impl_item(
668            id,
669            trait_type_id,
670            |m| &mut m.types,
671            |inference, stable_ptr| inference.new_type_var(stable_ptr),
672        )
673    }
674
675    /// Returns the constant value of an impl var's constant item.
676    /// The constant may be a variable itself, but it may previously exist, so may be more specific
677    /// due to rewriting.
678    pub fn rewritten_impl_constant(
679        &mut self,
680        id: ImplVarId<'db>,
681        trait_constant: TraitConstantId<'db>,
682    ) -> ConstValueId<'db> {
683        self.rewritten_impl_item(
684            id,
685            trait_constant,
686            |m| &mut m.constants,
687            |inference, stable_ptr| {
688                inference.new_const_var(
689                    stable_ptr,
690                    inference.db.trait_constant_type(trait_constant).unwrap(),
691                )
692            },
693        )
694    }
695
696    /// Returns the inner_impl value of an impl var's impl item.
697    /// The inner_impl may be a variable itself, but it may previously exist, so may be more
698    /// specific due to rewriting.
699    pub fn rewritten_impl_impl(
700        &mut self,
701        id: ImplVarId<'db>,
702        concrete_trait_impl: ConcreteTraitImplId<'db>,
703    ) -> ImplId<'db> {
704        let trait_crate = concrete_trait_impl
705            .trait_impl(self.db)
706            .trait_id(self.db)
707            .parent_module(self.db)
708            .owning_crate(self.db);
709        self.rewritten_impl_item(
710            id,
711            concrete_trait_impl.trait_impl(self.db),
712            |m| &mut m.impls,
713            |inference, stable_ptr| {
714                inference.new_impl_var(
715                    inference.db.concrete_trait_impl_concrete_trait(concrete_trait_impl).unwrap(),
716                    stable_ptr,
717                    ImplLookupContext::new_from_crate(trait_crate).intern(self.db),
718                )
719            },
720        )
721    }
722
723    /// Helper function for getting an impl vars item ids.
724    /// These ids are likely to be variables, but may have more specific information due to
725    /// rewriting.
726    fn rewritten_impl_item<K: Hash + PartialEq + Eq, V: Copy>(
727        &mut self,
728        id: ImplVarId<'db>,
729        key: K,
730        get_map: impl for<'a> Fn(&'a mut ImplVarTraitItemMappings<'db>) -> &'a mut OrderedHashMap<K, V>,
731        new_var: impl FnOnce(&mut Self, Option<SyntaxStablePtrId<'db>>) -> V,
732    ) -> V
733    where
734        Self: SemanticRewriter<V, NoError>,
735    {
736        let var_id = id.id(self.db);
737        if let Some(value) = self
738            .data
739            .impl_vars_trait_item_mappings
740            .get_mut(&var_id)
741            .and_then(|mappings| get_map(mappings).get(&key))
742        {
743            // Copy the value to allow usage of `self`.
744            let value = *value;
745            // If the value already exists, rewrite it before returning.
746            self.rewrite(value).no_err()
747        } else {
748            let value =
749                new_var(self, self.data.stable_ptrs.get(&InferenceVar::Impl(var_id)).cloned());
750            get_map(self.data.impl_vars_trait_item_mappings.entry(var_id).or_default())
751                .insert(key, value);
752            value
753        }
754    }
755
756    /// Sets an error for an impl reduction failure.
757    fn set_impl_reduction_error(&mut self, impl_id: ImplId<'db>) -> ErrorSet {
758        self.set_error(
759            impl_id
760                .concrete_trait(self.db)
761                .map(InferenceError::NoImplsFound)
762                .unwrap_or_else(InferenceError::Reported),
763        )
764    }
765
766    /// Conforms a type to a type. Returning the reduced types on failure.
767    /// Useful for immediately reporting a diagnostic based on the compared types.
768    pub fn conform_ty_for_diag(
769        &mut self,
770        ty0: TypeId<'db>,
771        ty1: TypeId<'db>,
772        diagnostics: &mut SemanticDiagnostics<'db>,
773        diag_stable_ptr: impl FnOnce() -> SyntaxStablePtrId<'db>,
774        diag_kind: impl FnOnce(TypeId<'db>, TypeId<'db>) -> SemanticDiagnosticKind<'db>,
775    ) -> Maybe<()> {
776        match self.conform_ty(ty0, ty1) {
777            Ok(_ty) => Ok(()),
778            Err(err) => {
779                let ty0 = self.rewrite(ty0).no_err();
780                let ty1 = self.rewrite(ty1).no_err();
781                Err(if ty0 != ty1 {
782                    let diag_added = diagnostics.report(diag_stable_ptr(), diag_kind(ty0, ty1));
783                    self.consume_reported_error(err, diag_added);
784                    diag_added
785                } else {
786                    self.report_on_pending_error(err, diagnostics, diag_stable_ptr())
787                })
788            }
789        }
790    }
791
792    /// helper function for ty_contains_var
793    /// Assumes ty was already rewritten.
794    #[doc(hidden)]
795    fn internal_ty_contains_var(&mut self, ty: TypeId<'db>, var: InferenceVar) -> bool {
796        match ty.long(self.db) {
797            TypeLongId::Concrete(concrete) => {
798                let generic_args = concrete.generic_args(self.db);
799                self.generic_args_contain_var(&generic_args, var)
800            }
801            TypeLongId::Tuple(tys) => tys.iter().any(|ty| self.internal_ty_contains_var(*ty, var)),
802            TypeLongId::Snapshot(ty) => self.internal_ty_contains_var(*ty, var),
803            TypeLongId::Var(ty_var) => {
804                if InferenceVar::Type(ty_var.id) == var {
805                    return true;
806                }
807                if let Some(ty) = self.type_assignment.get(&ty_var.id) {
808                    return self.internal_ty_contains_var(*ty, var);
809                }
810                false
811            }
812            TypeLongId::ImplType(id) => self.impl_contains_var(id.impl_id(), var),
813            TypeLongId::GenericParameter(_) | TypeLongId::Missing(_) => false,
814            TypeLongId::Coupon(function_id) => self.function_contains_var(*function_id, var),
815            TypeLongId::FixedSizeArray { type_id, .. } => {
816                self.internal_ty_contains_var(*type_id, var)
817            }
818            TypeLongId::Closure(closure) => {
819                closure.param_tys.iter().any(|ty| self.internal_ty_contains_var(*ty, var))
820                    || self.internal_ty_contains_var(closure.ret_ty, var)
821            }
822        }
823    }
824
825    /// Creates a var for each constrained impl_type and conforms the types.
826    pub fn conform_generic_params_type_constraints(
827        &mut self,
828        constraints: &[(TypeId<'db>, TypeId<'db>)],
829    ) {
830        let mut impl_type_bounds = Default::default();
831        for (ty0, ty1) in constraints {
832            let ty0 = if let TypeLongId::ImplType(impl_type) = ty0.long(self.db) {
833                self.impl_type_assignment(*impl_type, &mut impl_type_bounds)
834            } else {
835                *ty0
836            };
837            let ty1 = if let TypeLongId::ImplType(impl_type) = ty1.long(self.db) {
838                self.impl_type_assignment(*impl_type, &mut impl_type_bounds)
839            } else {
840                *ty1
841            };
842            self.conform_ty(ty0, ty1).ok();
843        }
844        self.set_impl_type_bounds(impl_type_bounds);
845    }
846
847    /// A helper function for getting for an impl type assignment.
848    /// Creates a new type var if the impl type is not yet assigned.
849    fn impl_type_assignment(
850        &mut self,
851        impl_type: ImplTypeId<'db>,
852        impl_type_bounds: &mut OrderedHashMap<ImplTypeId<'db>, TypeId<'db>>,
853    ) -> TypeId<'db> {
854        match impl_type_bounds.entry(impl_type) {
855            Entry::Occupied(entry) => *entry.get(),
856            Entry::Vacant(entry) => {
857                let inference_id = self.data.inference_id;
858                let id = LocalTypeVarId(self.data.type_vars.len());
859                let var = TypeVar { inference_id, id };
860                let ty = TypeLongId::Var(var).intern(self.db);
861                entry.insert(ty);
862                self.type_vars.push(var);
863                ty
864            }
865        }
866    }
867}