cairo_lang_semantic/expr/inference/
conform.rs

1use std::hash::Hash;
2
3use cairo_lang_defs::ids::{LanguageElementId, TraitConstantId, TraitTypeId};
4use cairo_lang_diagnostics::Maybe;
5use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
6use cairo_lang_utils::Intern;
7use cairo_lang_utils::ordered_hash_map::{Entry, OrderedHashMap};
8use itertools::zip_eq;
9
10use super::canonic::{NoError, ResultNoErrEx};
11use super::{
12    ErrorSet, ImplVarId, ImplVarTraitItemMappings, Inference, InferenceError, InferenceResult,
13    InferenceVar, LocalTypeVarId, TypeVar,
14};
15use crate::corelib::never_ty;
16use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
17use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
18use crate::items::functions::{GenericFunctionId, ImplGenericFunctionId};
19use crate::items::imp::{
20    ImplId, ImplImplId, ImplLongId, ImplLookupContext, ImplSemantic, NegativeImplId,
21    NegativeImplLongId,
22};
23use crate::items::trt::{ConcreteTraitImplId, TraitSemantic};
24use crate::substitution::SemanticRewriter;
25use crate::types::{ClosureTypeLongId, ImplTypeId, peel_snapshots};
26use crate::{
27    ConcreteFunction, ConcreteImplLongId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId,
28    FunctionId, FunctionLongId, GenericArgumentId, TypeId, TypeLongId,
29};
30
31/// Functions for conforming semantic objects with each other.
32pub trait InferenceConform<'db> {
33    fn conform_ty(&mut self, ty0: TypeId<'db>, ty1: TypeId<'db>) -> InferenceResult<TypeId<'db>>;
34    fn conform_ty_ex(
35        &mut self,
36        ty0: TypeId<'db>,
37        ty1: TypeId<'db>,
38        ty0_is_self: bool,
39    ) -> InferenceResult<(TypeId<'db>, usize)>;
40    fn conform_const(
41        &mut self,
42        ty0: ConstValueId<'db>,
43        ty1: ConstValueId<'db>,
44    ) -> InferenceResult<ConstValueId<'db>>;
45    fn maybe_peel_snapshots(
46        &mut self,
47        ty0_is_self: bool,
48        ty1: TypeId<'db>,
49    ) -> (usize, TypeLongId<'db>);
50    fn conform_generic_args(
51        &mut self,
52        gargs0: &[GenericArgumentId<'db>],
53        gargs1: &[GenericArgumentId<'db>],
54    ) -> InferenceResult<Vec<GenericArgumentId<'db>>>;
55    fn conform_generic_arg(
56        &mut self,
57        garg0: GenericArgumentId<'db>,
58        garg1: GenericArgumentId<'db>,
59    ) -> InferenceResult<GenericArgumentId<'db>>;
60    fn conform_impl(
61        &mut self,
62        impl0: ImplId<'db>,
63        impl1: ImplId<'db>,
64    ) -> InferenceResult<ImplId<'db>>;
65    fn conform_neg_impl(
66        &mut self,
67        neg_impl0: NegativeImplId<'db>,
68        neg_impl1: NegativeImplId<'db>,
69    ) -> InferenceResult<NegativeImplId<'db>>;
70    fn conform_traits(
71        &mut self,
72        trt0: ConcreteTraitId<'db>,
73        trt1: ConcreteTraitId<'db>,
74    ) -> InferenceResult<ConcreteTraitId<'db>>;
75    fn conform_generic_function(
76        &mut self,
77        trt0: GenericFunctionId<'db>,
78        trt1: GenericFunctionId<'db>,
79    ) -> InferenceResult<GenericFunctionId<'db>>;
80    fn ty_contains_var(&mut self, ty: TypeId<'db>, var: InferenceVar) -> bool;
81    fn generic_args_contain_var(
82        &mut self,
83        generic_args: &[GenericArgumentId<'db>],
84        var: InferenceVar,
85    ) -> bool;
86    fn impl_contains_var(&mut self, impl_id: ImplId<'db>, var: InferenceVar) -> bool;
87    fn negative_impl_contains_var(
88        &mut self,
89        neg_impl_id: NegativeImplId<'db>,
90        var: InferenceVar,
91    ) -> bool;
92    fn function_contains_var(&mut self, function_id: FunctionId<'db>, var: InferenceVar) -> bool;
93}
94
95impl<'db> InferenceConform<'db> for Inference<'db, '_> {
96    /// Conforms ty0 to ty1. Should be called when ty0 should be coerced to ty1. Not symmetric.
97    /// Returns the reduced type for ty0, or an error if the type is no coercible.
98    fn conform_ty(&mut self, ty0: TypeId<'db>, ty1: TypeId<'db>) -> InferenceResult<TypeId<'db>> {
99        Ok(self.conform_ty_ex(ty0, ty1, false)?.0)
100    }
101
102    /// Same as conform_ty but supports adding snapshots to ty0 if `ty0_is_self` is true.
103    /// Returns the reduced type for ty0 and the number of snapshots that needs to be added
104    /// for the types to conform.
105    fn conform_ty_ex(
106        &mut self,
107        ty0: TypeId<'db>,
108        ty1: TypeId<'db>,
109        ty0_is_self: bool,
110    ) -> InferenceResult<(TypeId<'db>, usize)> {
111        let ty0 = self.rewrite(ty0).no_err();
112        let ty1 = self.rewrite(ty1).no_err();
113        if ty0 == never_ty(self.db) || ty0.is_missing(self.db) {
114            return Ok((ty1, 0));
115        }
116        if ty0 == ty1 {
117            return Ok((ty0, 0));
118        }
119        let long_ty1 = ty1.long(self.db);
120        match long_ty1 {
121            TypeLongId::Var(var) => return Ok((self.assign_ty(*var, ty0)?, 0)),
122            TypeLongId::Missing(_) => return Ok((ty1, 0)),
123            TypeLongId::Snapshot(inner_ty) => {
124                if ty0_is_self {
125                    if *inner_ty == ty0 {
126                        return Ok((ty1, 1));
127                    }
128                    if !matches!(ty0.long(self.db), TypeLongId::Snapshot(_))
129                        && let TypeLongId::Var(var) = inner_ty.long(self.db)
130                    {
131                        return Ok((self.assign_ty(*var, ty0)?, 1));
132                    }
133                }
134            }
135            TypeLongId::ImplType(impl_type) => {
136                if let Some(ty) = self.impl_type_bounds.get(&(*impl_type).into()) {
137                    return self.conform_ty_ex(ty0, *ty, ty0_is_self);
138                }
139            }
140            _ => {}
141        }
142        let long_ty0 = ty0.long(self.db);
143
144        match long_ty0 {
145            TypeLongId::Concrete(concrete0) => {
146                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
147                let TypeLongId::Concrete(concrete1) = long_ty1 else {
148                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
149                };
150                if concrete0.generic_type(self.db) != concrete1.generic_type(self.db) {
151                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
152                }
153                let gargs0 = concrete0.generic_args(self.db);
154                let gargs1 = concrete1.generic_args(self.db);
155                let gargs = self.conform_generic_args(&gargs0, &gargs1)?;
156                let long_ty = TypeLongId::Concrete(ConcreteTypeId::new(
157                    self.db,
158                    concrete0.generic_type(self.db),
159                    gargs,
160                ));
161                Ok((long_ty.intern(self.db), n_snapshots))
162            }
163            TypeLongId::Tuple(tys0) => {
164                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
165                let TypeLongId::Tuple(tys1) = long_ty1 else {
166                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
167                };
168                if tys0.len() != tys1.len() {
169                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
170                }
171                let tys = zip_eq(tys0, tys1)
172                    .map(|(subty0, subty1)| self.conform_ty(*subty0, subty1))
173                    .collect::<Result<Vec<_>, _>>()?;
174                Ok((TypeLongId::Tuple(tys).intern(self.db), n_snapshots))
175            }
176            TypeLongId::Closure(closure0) => {
177                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
178                let TypeLongId::Closure(closure1) = long_ty1 else {
179                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
180                };
181                if closure0.wrapper_location != closure1.wrapper_location {
182                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
183                }
184                let param_tys = zip_eq(closure0.param_tys.clone(), closure1.param_tys)
185                    .map(|(subty0, subty1)| self.conform_ty(subty0, subty1))
186                    .collect::<Result<Vec<_>, _>>()?;
187                let captured_types =
188                    zip_eq(closure0.captured_types.clone(), closure1.captured_types)
189                        .map(|(subty0, subty1)| self.conform_ty(subty0, subty1))
190                        .collect::<Result<Vec<_>, _>>()?;
191                let ret_ty = self.conform_ty(closure0.ret_ty, closure1.ret_ty)?;
192                Ok((
193                    TypeLongId::Closure(ClosureTypeLongId {
194                        param_tys,
195                        ret_ty,
196                        captured_types,
197                        wrapper_location: closure0.wrapper_location,
198                        parent_function: closure0.parent_function,
199                    })
200                    .intern(self.db),
201                    n_snapshots,
202                ))
203            }
204            TypeLongId::FixedSizeArray { type_id, size } => {
205                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
206                let TypeLongId::FixedSizeArray { type_id: type_id1, size: size1 } = long_ty1 else {
207                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
208                };
209                let size = self.conform_const(*size, size1)?;
210                let ty = self.conform_ty(*type_id, type_id1)?;
211                Ok((TypeLongId::FixedSizeArray { type_id: ty, size }.intern(self.db), n_snapshots))
212            }
213            TypeLongId::Snapshot(inner_ty0) => {
214                let TypeLongId::Snapshot(inner_ty1) = long_ty1 else {
215                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
216                };
217                let (ty, n_snapshots) = self.conform_ty_ex(*inner_ty0, *inner_ty1, ty0_is_self)?;
218                Ok((TypeLongId::Snapshot(ty).intern(self.db), n_snapshots))
219            }
220            TypeLongId::GenericParameter(_) => {
221                Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }))
222            }
223            TypeLongId::Var(var) => Ok((self.assign_ty(*var, ty1)?, 0)),
224            TypeLongId::ImplType(impl_type) => {
225                if let Some(ty) = self.impl_type_bounds.get(&(*impl_type).into()) {
226                    return self.conform_ty_ex(*ty, ty1, ty0_is_self);
227                }
228                Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }))
229            }
230            TypeLongId::Missing(_) => Ok((ty0, 0)),
231            TypeLongId::Coupon(function_id0) => {
232                let TypeLongId::Coupon(function_id1) = long_ty1 else {
233                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
234                };
235
236                let func0 = function_id0.long(self.db).function.clone();
237                let func1 = function_id1.long(self.db).function.clone();
238
239                let generic_function =
240                    self.conform_generic_function(func0.generic_function, func1.generic_function)?;
241
242                if func0.generic_args.len() != func1.generic_args.len() {
243                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
244                }
245
246                let generic_args =
247                    self.conform_generic_args(&func0.generic_args, &func1.generic_args)?;
248
249                Ok((
250                    TypeLongId::Coupon(
251                        FunctionLongId {
252                            function: ConcreteFunction { generic_function, generic_args },
253                        }
254                        .intern(self.db),
255                    )
256                    .intern(self.db),
257                    0,
258                ))
259            }
260        }
261    }
262
263    /// Conforms id0 to id1. Should be called when id0 should be coerced to id1. Not symmetric.
264    /// Returns the reduced const for id0, or an error if the const is no coercible.
265    fn conform_const(
266        &mut self,
267        id0: ConstValueId<'db>,
268        id1: ConstValueId<'db>,
269    ) -> InferenceResult<ConstValueId<'db>> {
270        let id0 = self.rewrite(id0).no_err();
271        let id1 = self.rewrite(id1).no_err();
272        self.conform_ty(id0.ty(self.db).unwrap(), id1.ty(self.db).unwrap())?;
273        if id0 == id1 {
274            return Ok(id0);
275        }
276        let const_value0 = id0.long(self.db);
277        if matches!(const_value0, ConstValue::Missing(_)) {
278            return Ok(id1);
279        }
280        match id1.long(self.db) {
281            ConstValue::Missing(_) => return Ok(id1),
282            ConstValue::Var(var, _) => return self.assign_const(*var, id0),
283            _ => {}
284        }
285        match const_value0 {
286            ConstValue::Var(var, _) => Ok(self.assign_const(*var, id1)?),
287            ConstValue::ImplConstant(_) => {
288                Err(self.set_error(InferenceError::ConstKindMismatch { const0: id0, const1: id1 }))
289            }
290            _ => {
291                Err(self.set_error(InferenceError::ConstKindMismatch { const0: id0, const1: id1 }))
292            }
293        }
294    }
295
296    // Conditionally peels snapshots.
297    fn maybe_peel_snapshots(
298        &mut self,
299        ty0_is_self: bool,
300        ty1: TypeId<'db>,
301    ) -> (usize, TypeLongId<'db>) {
302        let (n_snapshots, long_ty1) =
303            if ty0_is_self { peel_snapshots(self.db, ty1) } else { (0, ty1.long(self.db).clone()) };
304        (n_snapshots, long_ty1)
305    }
306
307    /// Conforms generic args. See `conform_ty()`.
308    fn conform_generic_args(
309        &mut self,
310        gargs0: &[GenericArgumentId<'db>],
311        gargs1: &[GenericArgumentId<'db>],
312    ) -> InferenceResult<Vec<GenericArgumentId<'db>>> {
313        zip_eq(gargs0, gargs1)
314            .map(|(garg0, garg1)| self.conform_generic_arg(*garg0, *garg1))
315            .collect::<Result<Vec<_>, _>>()
316    }
317
318    /// Conforms a generic arg. See `conform_ty()`.
319    fn conform_generic_arg(
320        &mut self,
321        garg0: GenericArgumentId<'db>,
322        garg1: GenericArgumentId<'db>,
323    ) -> InferenceResult<GenericArgumentId<'db>> {
324        if garg0 == garg1 {
325            return Ok(garg0);
326        }
327        match garg0 {
328            GenericArgumentId::Type(gty0) => {
329                let GenericArgumentId::Type(gty1) = garg1 else {
330                    return Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }));
331                };
332                Ok(GenericArgumentId::Type(self.conform_ty(gty0, gty1)?))
333            }
334            GenericArgumentId::Constant(gc0) => {
335                let GenericArgumentId::Constant(gc1) = garg1 else {
336                    return Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }));
337                };
338
339                Ok(GenericArgumentId::Constant(self.conform_const(gc0, gc1)?))
340            }
341            GenericArgumentId::Impl(impl0) => {
342                let GenericArgumentId::Impl(impl1) = garg1 else {
343                    return Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }));
344                };
345                Ok(GenericArgumentId::Impl(self.conform_impl(impl0, impl1)?))
346            }
347            GenericArgumentId::NegImpl(neg_impl0) => {
348                let GenericArgumentId::NegImpl(neg_impl1) = garg1 else {
349                    return Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }));
350                };
351                Ok(GenericArgumentId::NegImpl(self.conform_neg_impl(neg_impl0, neg_impl1)?))
352            }
353        }
354    }
355
356    /// Conforms an impl. See `conform_ty()`.
357    fn conform_impl(
358        &mut self,
359        impl0: ImplId<'db>,
360        impl1: ImplId<'db>,
361    ) -> InferenceResult<ImplId<'db>> {
362        let impl0 = self.rewrite(impl0).no_err();
363        let impl1 = self.rewrite(impl1).no_err();
364        let long_impl1 = impl1.long(self.db);
365        if impl0 == impl1 {
366            return Ok(impl0);
367        }
368        if let ImplLongId::ImplVar(var) = long_impl1 {
369            let impl_concrete_trait = self
370                .db
371                .impl_concrete_trait(impl0)
372                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
373            self.conform_traits(var.long(self.db).concrete_trait_id, impl_concrete_trait)?;
374            let impl_id = self.rewrite(impl0).no_err();
375            return self.assign_impl(*var, impl_id);
376        }
377        match impl0.long(self.db) {
378            ImplLongId::ImplVar(var) => {
379                let impl_concrete_trait = self
380                    .db
381                    .impl_concrete_trait(impl1)
382                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
383                self.conform_traits(var.long(self.db).concrete_trait_id, impl_concrete_trait)?;
384                let impl_id = self.rewrite(impl1).no_err();
385                self.assign_impl(*var, impl_id)
386            }
387            ImplLongId::Concrete(concrete0) => {
388                let ImplLongId::Concrete(concrete1) = long_impl1 else {
389                    return Err(self.set_error(InferenceError::ImplKindMismatch { impl0, impl1 }));
390                };
391                let concrete0 = concrete0.long(self.db);
392                let concrete1 = concrete1.long(self.db);
393                if concrete0.impl_def_id != concrete1.impl_def_id {
394                    return Err(self.set_error(InferenceError::ImplKindMismatch { impl0, impl1 }));
395                }
396                let gargs0 = concrete0.generic_args.clone();
397                let gargs1 = concrete1.generic_args.clone();
398                let generic_args = self.conform_generic_args(&gargs0, &gargs1)?;
399                Ok(ImplLongId::Concrete(
400                    ConcreteImplLongId { impl_def_id: concrete0.impl_def_id, generic_args }
401                        .intern(self.db),
402                )
403                .intern(self.db))
404            }
405            ImplLongId::GenericParameter(_)
406            | ImplLongId::ImplImpl(_)
407            | ImplLongId::SelfImpl(_)
408            | ImplLongId::GeneratedImpl(_) => {
409                Err(self.set_error(InferenceError::ImplKindMismatch { impl0, impl1 }))
410            }
411        }
412    }
413
414    fn conform_neg_impl(
415        &mut self,
416        neg_impl0: NegativeImplId<'db>,
417        neg_impl1: NegativeImplId<'db>,
418    ) -> InferenceResult<NegativeImplId<'db>> {
419        let neg_impl0 = self.rewrite(neg_impl0).no_err();
420        let neg_impl1 = self.rewrite(neg_impl1).no_err();
421        let long_neg_impl1 = neg_impl1.long(self.db);
422        if neg_impl0 == neg_impl1 {
423            return Ok(neg_impl0);
424        }
425        if let NegativeImplLongId::NegativeImplVar(var) = long_neg_impl1 {
426            let impl_concrete_trait = neg_impl0
427                .concrete_trait(self.db)
428                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
429            self.conform_traits(var.long(self.db).concrete_trait_id, impl_concrete_trait)?;
430            let neg_impl_id = self.rewrite(neg_impl0).no_err();
431            return self.assign_neg_impl(*var, neg_impl_id);
432        }
433        if let NegativeImplLongId::NegativeImplVar(var) = neg_impl0.long(self.db) {
434            let impl_concrete_trait = neg_impl1
435                .concrete_trait(self.db)
436                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
437            self.conform_traits(var.long(self.db).concrete_trait_id, impl_concrete_trait)?;
438            let neg_impl_id = self.rewrite(neg_impl1).no_err();
439            return self.assign_neg_impl(*var, neg_impl_id);
440        }
441        // we do not care about multiple negative impls, so we do not check they are of the same
442        // variant, we just conform the traits.
443        let trt0 = neg_impl0
444            .concrete_trait(self.db)
445            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
446        let trt1 = neg_impl1
447            .concrete_trait(self.db)
448            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
449        self.conform_traits(trt0, trt1)?;
450        Ok(self.rewrite(neg_impl1).no_err())
451    }
452
453    /// Conforms generic traits. See `conform_ty()`.
454    fn conform_traits(
455        &mut self,
456        trt0: ConcreteTraitId<'db>,
457        trt1: ConcreteTraitId<'db>,
458    ) -> InferenceResult<ConcreteTraitId<'db>> {
459        let trt0 = trt0.long(self.db);
460        let trt1 = trt1.long(self.db);
461        if trt0.trait_id != trt1.trait_id {
462            return Err(self.set_error(InferenceError::TraitMismatch {
463                trt0: trt0.trait_id,
464                trt1: trt1.trait_id,
465            }));
466        }
467        let generic_args = self.conform_generic_args(&trt0.generic_args, &trt1.generic_args)?;
468        Ok(ConcreteTraitLongId { trait_id: trt0.trait_id, generic_args }.intern(self.db))
469    }
470
471    fn conform_generic_function(
472        &mut self,
473        func0: GenericFunctionId<'db>,
474        func1: GenericFunctionId<'db>,
475    ) -> InferenceResult<GenericFunctionId<'db>> {
476        if let (GenericFunctionId::Impl(id0), GenericFunctionId::Impl(id1)) = (func0, func1) {
477            if id0.function != id1.function {
478                return Err(
479                    self.set_error(InferenceError::GenericFunctionMismatch { func0, func1 })
480                );
481            }
482            let function = id0.function;
483            let impl_id = self.conform_impl(id0.impl_id, id1.impl_id)?;
484            return Ok(GenericFunctionId::Impl(ImplGenericFunctionId { impl_id, function }));
485        }
486
487        if func0 != func1 {
488            return Err(self.set_error(InferenceError::GenericFunctionMismatch { func0, func1 }));
489        }
490        Ok(func0)
491    }
492
493    /// Checks if a type tree contains a certain [InferenceVar] somewhere. Used to avoid inference
494    /// cycles.
495    fn ty_contains_var(&mut self, ty: TypeId<'db>, var: InferenceVar) -> bool {
496        let ty = self.rewrite(ty).no_err();
497        self.internal_ty_contains_var(ty, var)
498    }
499
500    /// Checks if a slice of generic arguments contain a certain [InferenceVar] somewhere. Used to
501    /// avoid inference cycles.
502    fn generic_args_contain_var(
503        &mut self,
504        generic_args: &[GenericArgumentId<'db>],
505        var: InferenceVar,
506    ) -> bool {
507        for garg in generic_args {
508            if match *garg {
509                GenericArgumentId::Type(ty) => self.internal_ty_contains_var(ty, var),
510                GenericArgumentId::Constant(_) => false,
511                GenericArgumentId::Impl(impl_id) => self.impl_contains_var(impl_id, var),
512                GenericArgumentId::NegImpl(neg_impl_id) => {
513                    self.negative_impl_contains_var(neg_impl_id, var)
514                }
515            } {
516                return true;
517            }
518        }
519        false
520    }
521
522    /// Checks if an impl contains a certain [InferenceVar] somewhere. Used to avoid inference
523    /// cycles.
524    fn impl_contains_var(&mut self, impl_id: ImplId<'db>, var: InferenceVar) -> bool {
525        match impl_id.long(self.db) {
526            ImplLongId::Concrete(concrete_impl_id) => {
527                self.generic_args_contain_var(&concrete_impl_id.long(self.db).generic_args, var)
528            }
529            ImplLongId::SelfImpl(concrete_trait_id) => {
530                self.generic_args_contain_var(concrete_trait_id.generic_args(self.db), var)
531            }
532            ImplLongId::GenericParameter(_) => false,
533            ImplLongId::ImplVar(new_var) => {
534                let new_var_long_id = new_var.long(self.db);
535                let new_var_local_id = new_var_long_id.id;
536                if InferenceVar::Impl(new_var_local_id) == var {
537                    return true;
538                }
539                if let Some(impl_id) = self.impl_assignment(new_var_local_id) {
540                    return self.impl_contains_var(impl_id, var);
541                }
542                self.generic_args_contain_var(
543                    new_var_long_id.concrete_trait_id.generic_args(self.db),
544                    var,
545                )
546            }
547            ImplLongId::ImplImpl(impl_impl) => self.impl_contains_var(impl_impl.impl_id(), var),
548            ImplLongId::GeneratedImpl(generated_impl) => self.generic_args_contain_var(
549                generated_impl.concrete_trait(self.db).generic_args(self.db),
550                var,
551            ),
552        }
553    }
554
555    /// Checks if a negative impl contains a certain [InferenceVar] somewhere. Used to avoid
556    /// inference cycles.
557    fn negative_impl_contains_var(
558        &mut self,
559        neg_impl_id: NegativeImplId<'db>,
560        var: InferenceVar,
561    ) -> bool {
562        match neg_impl_id.long(self.db) {
563            NegativeImplLongId::Solved(concrete_trait_id) => {
564                self.generic_args_contain_var(&concrete_trait_id.long(self.db).generic_args, var)
565            }
566            NegativeImplLongId::GenericParameter(_) => false,
567            NegativeImplLongId::NegativeImplVar(new_var) => {
568                let new_var_long_id = new_var.long(self.db);
569                let new_var_local_id = new_var_long_id.id;
570                if InferenceVar::NegativeImpl(new_var_local_id) == var {
571                    return true;
572                }
573                if let Some(neg_impl_id) = self.negative_impl_assignment(new_var_local_id) {
574                    return self.negative_impl_contains_var(neg_impl_id, var);
575                }
576                self.generic_args_contain_var(
577                    new_var_long_id.concrete_trait_id.generic_args(self.db),
578                    var,
579                )
580            }
581        }
582    }
583
584    /// Checks if a function contains a certain [InferenceVar] in its generic arguments or in the
585    /// generic arguments of the impl containing the function (in case the function is an impl
586    /// function).
587    ///
588    /// Used to avoid inference cycles.
589    fn function_contains_var(&mut self, function_id: FunctionId<'db>, var: InferenceVar) -> bool {
590        let function = function_id.get_concrete(self.db);
591        let generic_args = function.generic_args;
592        // Look in the generic arguments of the function and in the impl generic arguments.
593        self.generic_args_contain_var(&generic_args, var)
594            || matches!(function.generic_function,
595                GenericFunctionId::Impl(impl_generic_function_id)
596                if self.impl_contains_var(impl_generic_function_id.impl_id, var)
597            )
598    }
599}
600
601impl<'db> Inference<'db, '_> {
602    /// Reduces an impl type to a concrete type.
603    pub fn reduce_impl_ty(
604        &mut self,
605        impl_type_id: ImplTypeId<'db>,
606    ) -> InferenceResult<TypeId<'db>> {
607        let impl_id = impl_type_id.impl_id();
608        let trait_ty = impl_type_id.ty();
609        if let ImplLongId::ImplVar(var) = impl_id.long(self.db) {
610            Ok(self.rewritten_impl_type(*var, trait_ty))
611        } else if let Ok(ty) =
612            self.db.impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_ty, self.db))
613        {
614            Ok(ty)
615        } else {
616            Err(self.set_impl_reduction_error(impl_id))
617        }
618    }
619
620    /// Reduces an impl constant to a concrete const.
621    pub fn reduce_impl_constant(
622        &mut self,
623        impl_const_id: ImplConstantId<'db>,
624    ) -> InferenceResult<ConstValueId<'db>> {
625        let impl_id = impl_const_id.impl_id();
626        let trait_constant = impl_const_id.trait_constant_id();
627        if let ImplLongId::ImplVar(var) = impl_id.long(self.db) {
628            Ok(self.rewritten_impl_constant(*var, trait_constant))
629        } else if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
630            ImplConstantId::new(impl_id, trait_constant, self.db),
631        ) {
632            Ok(constant)
633        } else {
634            Err(self.set_impl_reduction_error(impl_id))
635        }
636    }
637
638    /// Reduces an impl impl to a concrete impl.
639    pub fn reduce_impl_impl(
640        &mut self,
641        impl_impl_id: ImplImplId<'db>,
642    ) -> InferenceResult<ImplId<'db>> {
643        let impl_id = impl_impl_id.impl_id();
644        let concrete_trait_impl = impl_impl_id
645            .concrete_trait_impl_id(self.db)
646            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
647
648        if let ImplLongId::ImplVar(var) = impl_id.long(self.db) {
649            Ok(self.rewritten_impl_impl(*var, concrete_trait_impl))
650        } else if let Ok(imp) = self.db.impl_impl_concrete_implized(ImplImplId::new(
651            impl_id,
652            impl_impl_id.trait_impl_id(),
653            self.db,
654        )) {
655            Ok(imp)
656        } else {
657            Err(self.set_impl_reduction_error(impl_id))
658        }
659    }
660
661    /// Returns the type of an impl var's type item.
662    /// The type may be a variable itself, but it may previously exist, so may be more specific due
663    /// to rewriting.
664    pub fn rewritten_impl_type(
665        &mut self,
666        id: ImplVarId<'db>,
667        trait_type_id: TraitTypeId<'db>,
668    ) -> TypeId<'db> {
669        self.rewritten_impl_item(
670            id,
671            trait_type_id,
672            |m| &mut m.types,
673            |inference, stable_ptr| inference.new_type_var(stable_ptr),
674        )
675    }
676
677    /// Returns the constant value of an impl var's constant item.
678    /// The constant may be a variable itself, but it may previously exist, so may be more specific
679    /// due to rewriting.
680    pub fn rewritten_impl_constant(
681        &mut self,
682        id: ImplVarId<'db>,
683        trait_constant: TraitConstantId<'db>,
684    ) -> ConstValueId<'db> {
685        self.rewritten_impl_item(
686            id,
687            trait_constant,
688            |m| &mut m.constants,
689            |inference, stable_ptr| {
690                inference.new_const_var(
691                    stable_ptr,
692                    inference.db.trait_constant_type(trait_constant).unwrap(),
693                )
694            },
695        )
696    }
697
698    /// Returns the inner_impl value of an impl var's impl item.
699    /// The inner_impl may be a variable itself, but it may previously exist, so may be more
700    /// specific due to rewriting.
701    pub fn rewritten_impl_impl(
702        &mut self,
703        id: ImplVarId<'db>,
704        concrete_trait_impl: ConcreteTraitImplId<'db>,
705    ) -> ImplId<'db> {
706        let trait_crate = concrete_trait_impl
707            .trait_impl(self.db)
708            .trait_id(self.db)
709            .parent_module(self.db)
710            .owning_crate(self.db);
711        self.rewritten_impl_item(
712            id,
713            concrete_trait_impl.trait_impl(self.db),
714            |m| &mut m.impls,
715            |inference, stable_ptr| {
716                inference.new_impl_var(
717                    inference.db.concrete_trait_impl_concrete_trait(concrete_trait_impl).unwrap(),
718                    stable_ptr,
719                    ImplLookupContext::new_from_crate(trait_crate).intern(self.db),
720                )
721            },
722        )
723    }
724
725    /// Helper function for getting an impl vars item ids.
726    /// These ids are likely to be variables, but may have more specific information due to
727    /// rewriting.
728    fn rewritten_impl_item<K: Hash + PartialEq + Eq, V: Copy>(
729        &mut self,
730        id: ImplVarId<'db>,
731        key: K,
732        get_map: impl for<'a> Fn(&'a mut ImplVarTraitItemMappings<'db>) -> &'a mut OrderedHashMap<K, V>,
733        new_var: impl FnOnce(&mut Self, Option<SyntaxStablePtrId<'db>>) -> V,
734    ) -> V
735    where
736        Self: SemanticRewriter<V, NoError>,
737    {
738        let var_id = id.id(self.db);
739        if let Some(value) = self
740            .data
741            .impl_vars_trait_item_mappings
742            .get_mut(&var_id)
743            .and_then(|mappings| get_map(mappings).get(&key))
744        {
745            // Copy the value to allow usage of `self`.
746            let value = *value;
747            // If the value already exists, rewrite it before returning.
748            self.rewrite(value).no_err()
749        } else {
750            let value =
751                new_var(self, self.data.stable_ptrs.get(&InferenceVar::Impl(var_id)).cloned());
752            get_map(self.data.impl_vars_trait_item_mappings.entry(var_id).or_default())
753                .insert(key, value);
754            value
755        }
756    }
757
758    /// Sets an error for an impl reduction failure.
759    fn set_impl_reduction_error(&mut self, impl_id: ImplId<'db>) -> ErrorSet {
760        self.set_error(
761            impl_id
762                .concrete_trait(self.db)
763                .map(InferenceError::NoImplsFound)
764                .unwrap_or_else(InferenceError::Reported),
765        )
766    }
767
768    /// Conforms a type to a type. Returning the reduced types on failure.
769    /// Useful for immediately reporting a diagnostic based on the compared types.
770    pub fn conform_ty_for_diag(
771        &mut self,
772        ty0: TypeId<'db>,
773        ty1: TypeId<'db>,
774        diagnostics: &mut SemanticDiagnostics<'db>,
775        diag_stable_ptr: impl FnOnce() -> SyntaxStablePtrId<'db>,
776        diag_kind: impl FnOnce(TypeId<'db>, TypeId<'db>) -> SemanticDiagnosticKind<'db>,
777    ) -> Maybe<()> {
778        match self.conform_ty(ty0, ty1) {
779            Ok(_ty) => Ok(()),
780            Err(err) => {
781                let ty0 = self.rewrite(ty0).no_err();
782                let ty1 = self.rewrite(ty1).no_err();
783                Err(if ty0 != ty1 {
784                    let diag_added = diagnostics.report(diag_stable_ptr(), diag_kind(ty0, ty1));
785                    self.consume_reported_error(err, diag_added);
786                    diag_added
787                } else {
788                    self.report_on_pending_error(err, diagnostics, diag_stable_ptr())
789                })
790            }
791        }
792    }
793
794    /// helper function for ty_contains_var
795    /// Assumes ty was already rewritten.
796    #[doc(hidden)]
797    fn internal_ty_contains_var(&mut self, ty: TypeId<'db>, var: InferenceVar) -> bool {
798        match ty.long(self.db) {
799            TypeLongId::Concrete(concrete) => {
800                let generic_args = concrete.generic_args(self.db);
801                self.generic_args_contain_var(&generic_args, var)
802            }
803            TypeLongId::Tuple(tys) => tys.iter().any(|ty| self.internal_ty_contains_var(*ty, var)),
804            TypeLongId::Snapshot(ty) => self.internal_ty_contains_var(*ty, var),
805            TypeLongId::Var(new_var) => {
806                if InferenceVar::Type(new_var.id) == var {
807                    return true;
808                }
809                if let Some(ty) = self.type_assignment.get(&new_var.id) {
810                    return self.internal_ty_contains_var(*ty, var);
811                }
812                false
813            }
814            TypeLongId::ImplType(id) => self.impl_contains_var(id.impl_id(), var),
815            TypeLongId::GenericParameter(_) | TypeLongId::Missing(_) => false,
816            TypeLongId::Coupon(function_id) => self.function_contains_var(*function_id, var),
817            TypeLongId::FixedSizeArray { type_id, .. } => {
818                self.internal_ty_contains_var(*type_id, var)
819            }
820            TypeLongId::Closure(closure) => {
821                closure
822                    .param_tys
823                    .clone()
824                    .into_iter()
825                    .any(|ty| self.internal_ty_contains_var(ty, var))
826                    || self.internal_ty_contains_var(closure.ret_ty, var)
827            }
828        }
829    }
830
831    /// Creates a var for each constrained impl_type and conforms the types.
832    pub fn conform_generic_params_type_constraints(
833        &mut self,
834        constraints: &[(TypeId<'db>, TypeId<'db>)],
835    ) {
836        let mut impl_type_bounds = Default::default();
837        for (ty0, ty1) in constraints {
838            let ty0 = if let TypeLongId::ImplType(impl_type) = ty0.long(self.db) {
839                self.impl_type_assignment(*impl_type, &mut impl_type_bounds)
840            } else {
841                *ty0
842            };
843            let ty1 = if let TypeLongId::ImplType(impl_type) = ty1.long(self.db) {
844                self.impl_type_assignment(*impl_type, &mut impl_type_bounds)
845            } else {
846                *ty1
847            };
848            self.conform_ty(ty0, ty1).ok();
849        }
850        self.set_impl_type_bounds(impl_type_bounds);
851    }
852
853    /// A helper function for getting for an impl type assignment.
854    /// Creates a new type var if the impl type is not yet assigned.
855    fn impl_type_assignment(
856        &mut self,
857        impl_type: ImplTypeId<'db>,
858        impl_type_bounds: &mut OrderedHashMap<ImplTypeId<'db>, TypeId<'db>>,
859    ) -> TypeId<'db> {
860        match impl_type_bounds.entry(impl_type) {
861            Entry::Occupied(entry) => *entry.get(),
862            Entry::Vacant(entry) => {
863                let inference_id = self.data.inference_id;
864                let id = LocalTypeVarId(self.data.type_vars.len());
865                let var = TypeVar { inference_id, id };
866                let ty = TypeLongId::Var(var).intern(self.db);
867                entry.insert(ty);
868                self.type_vars.push(var);
869                ty
870            }
871        }
872    }
873}