Skip to main content

cairo_lang_semantic/expr/inference/
conform.rs

1use std::hash::Hash;
2
3use cairo_lang_defs::ids::{LanguageElementId, TraitConstantId, TraitTypeId};
4use cairo_lang_diagnostics::Maybe;
5use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
6use cairo_lang_utils::Intern;
7use cairo_lang_utils::ordered_hash_map::{Entry, OrderedHashMap};
8use itertools::zip_eq;
9
10use super::canonic::{NoError, ResultNoErrEx};
11use super::{
12    ErrorSet, ImplVarId, ImplVarTraitItemMappings, Inference, InferenceError, InferenceResult,
13    InferenceVar, LocalTypeVarId, TypeVar,
14};
15use crate::corelib::{is_valid_literal_type, never_ty};
16use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
17use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
18use crate::items::functions::{GenericFunctionId, ImplGenericFunctionId};
19use crate::items::imp::{
20    ImplId, ImplImplId, ImplLongId, ImplLookupContext, ImplSemantic, NegativeImplId,
21    NegativeImplLongId,
22};
23use crate::items::trt::{ConcreteTraitImplId, TraitSemantic};
24use crate::substitution::SemanticRewriter;
25use crate::types::{ClosureTypeLongId, ImplTypeId, peel_snapshots};
26use crate::{
27    ConcreteFunction, ConcreteImplLongId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId,
28    FunctionId, FunctionLongId, GenericArgumentId, TypeId, TypeLongId,
29};
30
31/// Functions for conforming semantic objects with each other.
32pub trait InferenceConform<'db> {
33    fn conform_ty(&mut self, ty0: TypeId<'db>, ty1: TypeId<'db>) -> InferenceResult<TypeId<'db>>;
34    fn conform_ty_ex(
35        &mut self,
36        ty0: TypeId<'db>,
37        ty1: TypeId<'db>,
38        ty0_is_self: bool,
39    ) -> InferenceResult<(TypeId<'db>, usize)>;
40    fn conform_const(
41        &mut self,
42        ty0: ConstValueId<'db>,
43        ty1: ConstValueId<'db>,
44    ) -> InferenceResult<ConstValueId<'db>>;
45    fn maybe_peel_snapshots(
46        &mut self,
47        ty0_is_self: bool,
48        ty1: TypeId<'db>,
49    ) -> (usize, TypeLongId<'db>);
50    fn conform_generic_args(
51        &mut self,
52        gargs0: &[GenericArgumentId<'db>],
53        gargs1: &[GenericArgumentId<'db>],
54    ) -> InferenceResult<Vec<GenericArgumentId<'db>>>;
55    fn conform_generic_arg(
56        &mut self,
57        garg0: GenericArgumentId<'db>,
58        garg1: GenericArgumentId<'db>,
59    ) -> InferenceResult<GenericArgumentId<'db>>;
60    fn conform_impl(
61        &mut self,
62        impl0: ImplId<'db>,
63        impl1: ImplId<'db>,
64    ) -> InferenceResult<ImplId<'db>>;
65    fn conform_neg_impl(
66        &mut self,
67        neg_impl0: NegativeImplId<'db>,
68        neg_impl1: NegativeImplId<'db>,
69    ) -> InferenceResult<NegativeImplId<'db>>;
70    fn conform_traits(
71        &mut self,
72        trt0: ConcreteTraitId<'db>,
73        trt1: ConcreteTraitId<'db>,
74    ) -> InferenceResult<ConcreteTraitId<'db>>;
75    fn conform_generic_function(
76        &mut self,
77        trt0: GenericFunctionId<'db>,
78        trt1: GenericFunctionId<'db>,
79    ) -> InferenceResult<GenericFunctionId<'db>>;
80    fn ty_contains_var(&mut self, ty: TypeId<'db>, var: InferenceVar) -> bool;
81    fn generic_args_contain_var(
82        &mut self,
83        generic_args: &[GenericArgumentId<'db>],
84        var: InferenceVar,
85    ) -> bool;
86    fn impl_contains_var(&mut self, impl_id: ImplId<'db>, var: InferenceVar) -> bool;
87    fn negative_impl_contains_var(
88        &mut self,
89        neg_impl_id: NegativeImplId<'db>,
90        var: InferenceVar,
91    ) -> bool;
92    fn function_contains_var(&mut self, function_id: FunctionId<'db>, var: InferenceVar) -> bool;
93}
94
95impl<'db> InferenceConform<'db> for Inference<'db, '_> {
96    /// Conforms ty0 to ty1. Should be called when ty0 should be coerced to ty1. Not symmetric.
97    /// Returns the reduced type for ty0, or an error if the type is no coercible.
98    fn conform_ty(&mut self, ty0: TypeId<'db>, ty1: TypeId<'db>) -> InferenceResult<TypeId<'db>> {
99        Ok(self.conform_ty_ex(ty0, ty1, false)?.0)
100    }
101
    /// Same as conform_ty but supports adding snapshots to ty0 if `ty0_is_self` is true.
    /// Returns the reduced type for ty0 and the number of snapshots that need to be added
    /// for the types to conform.
    ///
    /// The function works in two phases: it first dispatches on the kind of the target `ty1`
    /// (arms that do not `return` fall through), and then dispatches on the kind of `ty0`.
    /// Any kind mismatch is reported as `InferenceError::TypeKindMismatch`.
    fn conform_ty_ex(
        &mut self,
        ty0: TypeId<'db>,
        ty1: TypeId<'db>,
        ty0_is_self: bool,
    ) -> InferenceResult<(TypeId<'db>, usize)> {
        // Rewrite both sides first so the comparisons below see the current inference state.
        let ty0 = self.rewrite(ty0).no_err();
        let ty1 = self.rewrite(ty1).no_err();
        // A `never` or missing `ty0` conforms to anything; the result is simply `ty1`.
        if ty0 == never_ty(self.db) || ty0.is_missing(self.db) {
            return Ok((ty1, 0));
        }
        // Identical types need no work and no snapshots.
        if ty0 == ty1 {
            return Ok((ty0, 0));
        }
        // Phase 1: dispatch on the target type.
        let long_ty1 = ty1.long(self.db);
        match long_ty1 {
            // The target is a type variable: assign `ty0` to it.
            TypeLongId::Var(var) => return Ok((self.assign_ty(*var, ty0)?, 0)),
            // Conforming `ty0 -> NumericLiteral(var)`. If `ty0` is also a NumericLiteral or a
            // Var, fall through so the `long_ty0` match below handles the unification.
            TypeLongId::NumericLiteral(var)
                if !matches!(
                    ty0.long(self.db),
                    TypeLongId::Var(_) | TypeLongId::NumericLiteral(_)
                ) =>
            {
                return self.bind_integer_literal(*var, ty0);
            }
            // A missing target is returned as is.
            TypeLongId::Missing(_) => return Ok((ty1, 0)),
            // The target is a snapshot and `ty0` is allowed to gain snapshots.
            TypeLongId::Snapshot(inner_ty) if ty0_is_self => {
                // `ty1` is exactly one snapshot over `ty0`.
                if *inner_ty == ty0 {
                    return Ok((ty1, 1));
                }
                // The shortcuts below only apply when `ty0` is not itself a snapshot.
                if !matches!(ty0.long(self.db), TypeLongId::Snapshot(_)) {
                    // The snapshotted type is a variable: assign `ty0` to it, one snapshot added.
                    if let TypeLongId::Var(var) = inner_ty.long(self.db) {
                        return Ok((self.assign_ty(*var, ty0)?, 1));
                    }
                    // `ty0` is a numeric literal: bind it to the snapshotted type and count the
                    // extra snapshot on top of whatever the binding reports.
                    if let TypeLongId::NumericLiteral(var) = ty0.long(self.db) {
                        let (bound, n) = self.bind_integer_literal(*var, *inner_ty)?;
                        return Ok((bound, n + 1));
                    }
                }
            }
            // The target is an impl type: if a bound is recorded for it in `impl_type_bounds`,
            // conform against that bound instead.
            TypeLongId::ImplType(impl_type) => {
                if let Some(ty) = self.impl_type_bounds.get(&(*impl_type).into()) {
                    return self.conform_ty_ex(ty0, *ty, ty0_is_self);
                }
            }
            _ => {}
        }
        // Phase 2: dispatch on the kind of `ty0`.
        let long_ty0 = ty0.long(self.db);

        match long_ty0 {
            TypeLongId::Concrete(concrete0) => {
                // Snapshots are peeled from `ty1` only when `ty0_is_self` is set.
                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
                let TypeLongId::Concrete(concrete1) = long_ty1 else {
                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
                };
                // Both sides must instantiate the same generic type.
                if concrete0.generic_type(self.db) != concrete1.generic_type(self.db) {
                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
                }
                let gargs0 = concrete0.generic_args(self.db);
                let gargs1 = concrete1.generic_args(self.db);
                let gargs = self.conform_generic_args(&gargs0, &gargs1)?;
                // Rebuild the concrete type from the conformed generic arguments.
                let long_ty = TypeLongId::Concrete(ConcreteTypeId::new(
                    self.db,
                    concrete0.generic_type(self.db),
                    gargs,
                ));
                Ok((long_ty.intern(self.db), n_snapshots))
            }
            TypeLongId::Tuple(tys0) => {
                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
                let TypeLongId::Tuple(tys1) = long_ty1 else {
                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
                };
                // Checked up front so `zip_eq` below is not handed mismatched lengths.
                if tys0.len() != tys1.len() {
                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
                }
                // Element types are conformed without snapshot support.
                let tys = zip_eq(tys0, tys1)
                    .map(|(subty0, subty1)| self.conform_ty(*subty0, subty1))
                    .collect::<Result<Vec<_>, _>>()?;
                Ok((TypeLongId::Tuple(tys).intern(self.db), n_snapshots))
            }
            TypeLongId::Closure(closure0) => {
                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
                let TypeLongId::Closure(closure1) = long_ty1 else {
                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
                };
                // Closures with different parameter locations never conform.
                if closure0.params_location != closure1.params_location {
                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
                }
                // NOTE(review): `zip_eq` is used without an explicit length check here;
                // presumably equal `params_location` implies equal lengths - confirm.
                let param_tys = zip_eq(closure0.param_tys.clone(), closure1.param_tys)
                    .map(|(subty0, subty1)| self.conform_ty(subty0, subty1))
                    .collect::<Result<Vec<_>, _>>()?;
                let captured_types =
                    zip_eq(closure0.captured_types.clone(), closure1.captured_types)
                        .map(|(subty0, subty1)| self.conform_ty(subty0, subty1))
                        .collect::<Result<Vec<_>, _>>()?;
                let ret_ty = self.conform_ty(closure0.ret_ty, closure1.ret_ty)?;
                // The location and parent function are taken from `closure0`.
                Ok((
                    TypeLongId::Closure(ClosureTypeLongId {
                        param_tys,
                        ret_ty,
                        captured_types,
                        params_location: closure0.params_location,
                        parent_function: closure0.parent_function,
                    })
                    .intern(self.db),
                    n_snapshots,
                ))
            }
            TypeLongId::FixedSizeArray { type_id, size } => {
                let (n_snapshots, long_ty1) = self.maybe_peel_snapshots(ty0_is_self, ty1);
                let TypeLongId::FixedSizeArray { type_id: type_id1, size: size1 } = long_ty1 else {
                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
                };
                // The size is conformed before the element type.
                let size = self.conform_const(*size, size1)?;
                let ty = self.conform_ty(*type_id, type_id1)?;
                Ok((TypeLongId::FixedSizeArray { type_id: ty, size }.intern(self.db), n_snapshots))
            }
            TypeLongId::Snapshot(inner_ty0) => {
                // Uses the unpeeled `long_ty1` from phase 1: both sides must be snapshots, and
                // the inner types are conformed recursively.
                let TypeLongId::Snapshot(inner_ty1) = long_ty1 else {
                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
                };
                let (ty, n_snapshots) = self.conform_ty_ex(*inner_ty0, *inner_ty1, ty0_is_self)?;
                Ok((TypeLongId::Snapshot(ty).intern(self.db), n_snapshots))
            }
            // A generic parameter only conforms to itself, which the equality check at the top
            // already handled.
            TypeLongId::GenericParameter(_) => {
                Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }))
            }
            // `ty0` is a type variable: assign the target to it.
            TypeLongId::Var(var) => Ok((self.assign_ty(*var, ty1)?, 0)),
            TypeLongId::NumericLiteral(var) => {
                // Conforming `NumericLiteral(var) -> ty1`.
                match long_ty1 {
                    // Same kind: unify the two numeric-literal placeholders by binding their
                    // underlying type vars.
                    TypeLongId::NumericLiteral(var2) => Ok((
                        self.assign_ty(*var, TypeLongId::NumericLiteral(*var2).intern(self.db))?,
                        0,
                    )),
                    // `ty1` is a bare Var: let it bind to the NumericLiteral.
                    TypeLongId::Var(other) => Ok((self.assign_ty(*other, ty0)?, 0)),
                    // Any other target goes through the literal-type validity check.
                    _ => self.bind_integer_literal(*var, ty1),
                }
            }
            // `ty0` is an impl type: conform its recorded bound (if any) against `ty1`.
            TypeLongId::ImplType(impl_type) => {
                if let Some(ty) = self.impl_type_bounds.get(&(*impl_type).into()) {
                    return self.conform_ty_ex(*ty, ty1, ty0_is_self);
                }
                Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }))
            }
            TypeLongId::Missing(_) => Ok((ty0, 0)),
            TypeLongId::Coupon(function_id0) => {
                // No snapshot peeling for coupons: the target must be a coupon as is.
                let TypeLongId::Coupon(function_id1) = long_ty1 else {
                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
                };

                let func0 = function_id0.long(self.db).function.clone();
                let func1 = function_id1.long(self.db).function.clone();

                let generic_function =
                    self.conform_generic_function(func0.generic_function, func1.generic_function)?;

                // Checked up front so `conform_generic_args` is not handed mismatched lengths.
                if func0.generic_args.len() != func1.generic_args.len() {
                    return Err(self.set_error(InferenceError::TypeKindMismatch { ty0, ty1 }));
                }

                let generic_args =
                    self.conform_generic_args(&func0.generic_args, &func1.generic_args)?;

                // Rebuild the coupon type from the conformed function.
                Ok((
                    TypeLongId::Coupon(
                        FunctionLongId {
                            function: ConcreteFunction { generic_function, generic_args },
                        }
                        .intern(self.db),
                    )
                    .intern(self.db),
                    0,
                ))
            }
        }
    }
288
289    /// Conforms id0 to id1. Should be called when id0 should be coerced to id1. Not symmetric.
290    /// Returns the reduced const for id0, or an error if the const is no coercible.
291    fn conform_const(
292        &mut self,
293        id0: ConstValueId<'db>,
294        id1: ConstValueId<'db>,
295    ) -> InferenceResult<ConstValueId<'db>> {
296        let id0 = self.rewrite(id0).no_err();
297        let id1 = self.rewrite(id1).no_err();
298        self.conform_ty(id0.ty(self.db).unwrap(), id1.ty(self.db).unwrap())?;
299        if id0 == id1 {
300            return Ok(id0);
301        }
302        let const_value0 = id0.long(self.db);
303        if matches!(const_value0, ConstValue::Missing(_)) {
304            return Ok(id1);
305        }
306        match id1.long(self.db) {
307            ConstValue::Missing(_) => return Ok(id1),
308            ConstValue::Var(var, _) => return self.assign_const(*var, id0),
309            _ => {}
310        }
311        match const_value0 {
312            ConstValue::Var(var, _) => Ok(self.assign_const(*var, id1)?),
313            ConstValue::ImplConstant(_) => {
314                Err(self.set_error(InferenceError::ConstKindMismatch { const0: id0, const1: id1 }))
315            }
316            _ => {
317                Err(self.set_error(InferenceError::ConstKindMismatch { const0: id0, const1: id1 }))
318            }
319        }
320    }
321
322    // Conditionally peels snapshots.
323    fn maybe_peel_snapshots(
324        &mut self,
325        ty0_is_self: bool,
326        ty1: TypeId<'db>,
327    ) -> (usize, TypeLongId<'db>) {
328        let (n_snapshots, long_ty1) =
329            if ty0_is_self { peel_snapshots(self.db, ty1) } else { (0, ty1.long(self.db).clone()) };
330        (n_snapshots, long_ty1)
331    }
332
333    /// Conforms generic args. See `conform_ty()`.
334    fn conform_generic_args(
335        &mut self,
336        gargs0: &[GenericArgumentId<'db>],
337        gargs1: &[GenericArgumentId<'db>],
338    ) -> InferenceResult<Vec<GenericArgumentId<'db>>> {
339        zip_eq(gargs0, gargs1)
340            .map(|(garg0, garg1)| self.conform_generic_arg(*garg0, *garg1))
341            .collect::<Result<Vec<_>, _>>()
342    }
343
344    /// Conforms a generic arg. See `conform_ty()`.
345    fn conform_generic_arg(
346        &mut self,
347        garg0: GenericArgumentId<'db>,
348        garg1: GenericArgumentId<'db>,
349    ) -> InferenceResult<GenericArgumentId<'db>> {
350        if garg0 == garg1 {
351            return Ok(garg0);
352        }
353        match garg0 {
354            GenericArgumentId::Type(gty0) => {
355                let GenericArgumentId::Type(gty1) = garg1 else {
356                    return Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }));
357                };
358                Ok(GenericArgumentId::Type(self.conform_ty(gty0, gty1)?))
359            }
360            GenericArgumentId::Constant(gc0) => {
361                let GenericArgumentId::Constant(gc1) = garg1 else {
362                    return Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }));
363                };
364
365                Ok(GenericArgumentId::Constant(self.conform_const(gc0, gc1)?))
366            }
367            GenericArgumentId::Impl(impl0) => {
368                let GenericArgumentId::Impl(impl1) = garg1 else {
369                    return Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }));
370                };
371                Ok(GenericArgumentId::Impl(self.conform_impl(impl0, impl1)?))
372            }
373            GenericArgumentId::NegImpl(neg_impl0) => {
374                let GenericArgumentId::NegImpl(neg_impl1) = garg1 else {
375                    return Err(self.set_error(InferenceError::GenericArgMismatch { garg0, garg1 }));
376                };
377                Ok(GenericArgumentId::NegImpl(self.conform_neg_impl(neg_impl0, neg_impl1)?))
378            }
379        }
380    }
381
382    /// Conforms an impl. See `conform_ty()`.
383    fn conform_impl(
384        &mut self,
385        impl0: ImplId<'db>,
386        impl1: ImplId<'db>,
387    ) -> InferenceResult<ImplId<'db>> {
388        let impl0 = self.rewrite(impl0).no_err();
389        let impl1 = self.rewrite(impl1).no_err();
390        let long_impl1 = impl1.long(self.db);
391        if impl0 == impl1 {
392            return Ok(impl0);
393        }
394        if let ImplLongId::ImplVar(var) = long_impl1 {
395            let impl_concrete_trait = self
396                .db
397                .impl_concrete_trait(impl0)
398                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
399            self.conform_traits(var.long(self.db).concrete_trait_id, impl_concrete_trait)?;
400            let impl_id = self.rewrite(impl0).no_err();
401            return self.assign_impl(*var, impl_id);
402        }
403        match impl0.long(self.db) {
404            ImplLongId::ImplVar(var) => {
405                let impl_concrete_trait = self
406                    .db
407                    .impl_concrete_trait(impl1)
408                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
409                self.conform_traits(var.long(self.db).concrete_trait_id, impl_concrete_trait)?;
410                let impl_id = self.rewrite(impl1).no_err();
411                self.assign_impl(*var, impl_id)
412            }
413            ImplLongId::Concrete(concrete0) => {
414                let ImplLongId::Concrete(concrete1) = long_impl1 else {
415                    return Err(self.set_error(InferenceError::ImplKindMismatch { impl0, impl1 }));
416                };
417                let concrete0 = concrete0.long(self.db);
418                let concrete1 = concrete1.long(self.db);
419                if concrete0.impl_def_id != concrete1.impl_def_id {
420                    return Err(self.set_error(InferenceError::ImplKindMismatch { impl0, impl1 }));
421                }
422                let gargs0 = concrete0.generic_args.clone();
423                let gargs1 = concrete1.generic_args.clone();
424                let generic_args = self.conform_generic_args(&gargs0, &gargs1)?;
425                Ok(ImplLongId::Concrete(
426                    ConcreteImplLongId { impl_def_id: concrete0.impl_def_id, generic_args }
427                        .intern(self.db),
428                )
429                .intern(self.db))
430            }
431            ImplLongId::GenericParameter(_)
432            | ImplLongId::ImplImpl(_)
433            | ImplLongId::SelfImpl(_)
434            | ImplLongId::GeneratedImpl(_) => {
435                Err(self.set_error(InferenceError::ImplKindMismatch { impl0, impl1 }))
436            }
437        }
438    }
439
440    fn conform_neg_impl(
441        &mut self,
442        neg_impl0: NegativeImplId<'db>,
443        neg_impl1: NegativeImplId<'db>,
444    ) -> InferenceResult<NegativeImplId<'db>> {
445        let neg_impl0 = self.rewrite(neg_impl0).no_err();
446        let neg_impl1 = self.rewrite(neg_impl1).no_err();
447        let long_neg_impl1 = neg_impl1.long(self.db);
448        if neg_impl0 == neg_impl1 {
449            return Ok(neg_impl0);
450        }
451        if let NegativeImplLongId::NegativeImplVar(var) = long_neg_impl1 {
452            let impl_concrete_trait = neg_impl0
453                .concrete_trait(self.db)
454                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
455            self.conform_traits(var.long(self.db).concrete_trait_id, impl_concrete_trait)?;
456            let neg_impl_id = self.rewrite(neg_impl0).no_err();
457            return self.assign_neg_impl(*var, neg_impl_id);
458        }
459        if let NegativeImplLongId::NegativeImplVar(var) = neg_impl0.long(self.db) {
460            let impl_concrete_trait = neg_impl1
461                .concrete_trait(self.db)
462                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
463            self.conform_traits(var.long(self.db).concrete_trait_id, impl_concrete_trait)?;
464            let neg_impl_id = self.rewrite(neg_impl1).no_err();
465            return self.assign_neg_impl(*var, neg_impl_id);
466        }
467        // we do not care about multiple negative impls, so we do not check they are of the same
468        // variant, we just conform the traits.
469        let trt0 = neg_impl0
470            .concrete_trait(self.db)
471            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
472        let trt1 = neg_impl1
473            .concrete_trait(self.db)
474            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
475        self.conform_traits(trt0, trt1)?;
476        Ok(self.rewrite(neg_impl1).no_err())
477    }
478
479    /// Conforms generic traits. See `conform_ty()`.
480    fn conform_traits(
481        &mut self,
482        trt0: ConcreteTraitId<'db>,
483        trt1: ConcreteTraitId<'db>,
484    ) -> InferenceResult<ConcreteTraitId<'db>> {
485        let trt0 = trt0.long(self.db);
486        let trt1 = trt1.long(self.db);
487        if trt0.trait_id != trt1.trait_id {
488            return Err(self.set_error(InferenceError::TraitMismatch {
489                trt0: trt0.trait_id,
490                trt1: trt1.trait_id,
491            }));
492        }
493        let generic_args = self.conform_generic_args(&trt0.generic_args, &trt1.generic_args)?;
494        Ok(ConcreteTraitLongId { trait_id: trt0.trait_id, generic_args }.intern(self.db))
495    }
496
497    fn conform_generic_function(
498        &mut self,
499        func0: GenericFunctionId<'db>,
500        func1: GenericFunctionId<'db>,
501    ) -> InferenceResult<GenericFunctionId<'db>> {
502        if let (GenericFunctionId::Impl(id0), GenericFunctionId::Impl(id1)) = (func0, func1) {
503            if id0.function != id1.function {
504                return Err(
505                    self.set_error(InferenceError::GenericFunctionMismatch { func0, func1 })
506                );
507            }
508            let function = id0.function;
509            let impl_id = self.conform_impl(id0.impl_id, id1.impl_id)?;
510            return Ok(GenericFunctionId::Impl(ImplGenericFunctionId { impl_id, function }));
511        }
512
513        if func0 != func1 {
514            return Err(self.set_error(InferenceError::GenericFunctionMismatch { func0, func1 }));
515        }
516        Ok(func0)
517    }
518
519    /// Checks if a type tree contains a certain [InferenceVar] somewhere. Used to avoid inference
520    /// cycles.
521    fn ty_contains_var(&mut self, ty: TypeId<'db>, var: InferenceVar) -> bool {
522        let ty = self.rewrite(ty).no_err();
523        self.internal_ty_contains_var(ty, var)
524    }
525
526    /// Checks if a slice of generic arguments contain a certain [InferenceVar] somewhere. Used to
527    /// avoid inference cycles.
528    fn generic_args_contain_var(
529        &mut self,
530        generic_args: &[GenericArgumentId<'db>],
531        var: InferenceVar,
532    ) -> bool {
533        for garg in generic_args {
534            if match *garg {
535                GenericArgumentId::Type(ty) => self.internal_ty_contains_var(ty, var),
536                GenericArgumentId::Constant(_) => false,
537                GenericArgumentId::Impl(impl_id) => self.impl_contains_var(impl_id, var),
538                GenericArgumentId::NegImpl(neg_impl_id) => {
539                    self.negative_impl_contains_var(neg_impl_id, var)
540                }
541            } {
542                return true;
543            }
544        }
545        false
546    }
547
548    /// Checks if an impl contains a certain [InferenceVar] somewhere. Used to avoid inference
549    /// cycles.
550    fn impl_contains_var(&mut self, impl_id: ImplId<'db>, var: InferenceVar) -> bool {
551        match impl_id.long(self.db) {
552            ImplLongId::Concrete(concrete_impl_id) => {
553                self.generic_args_contain_var(&concrete_impl_id.long(self.db).generic_args, var)
554            }
555            ImplLongId::SelfImpl(concrete_trait_id) => {
556                self.generic_args_contain_var(concrete_trait_id.generic_args(self.db), var)
557            }
558            ImplLongId::GenericParameter(_) => false,
559            ImplLongId::ImplVar(new_var) => {
560                let new_var_long_id = new_var.long(self.db);
561                let new_var_local_id = new_var_long_id.id;
562                if InferenceVar::Impl(new_var_local_id) == var {
563                    return true;
564                }
565                if let Some(impl_id) = self.impl_assignment(new_var_local_id) {
566                    return self.impl_contains_var(impl_id, var);
567                }
568                self.generic_args_contain_var(
569                    new_var_long_id.concrete_trait_id.generic_args(self.db),
570                    var,
571                )
572            }
573            ImplLongId::ImplImpl(impl_impl) => self.impl_contains_var(impl_impl.impl_id(), var),
574            ImplLongId::GeneratedImpl(generated_impl) => self.generic_args_contain_var(
575                generated_impl.concrete_trait(self.db).generic_args(self.db),
576                var,
577            ),
578        }
579    }
580
581    /// Checks if a negative impl contains a certain [InferenceVar] somewhere. Used to avoid
582    /// inference cycles.
583    fn negative_impl_contains_var(
584        &mut self,
585        neg_impl_id: NegativeImplId<'db>,
586        var: InferenceVar,
587    ) -> bool {
588        match neg_impl_id.long(self.db) {
589            NegativeImplLongId::Solved(concrete_trait_id) => {
590                self.generic_args_contain_var(&concrete_trait_id.long(self.db).generic_args, var)
591            }
592            NegativeImplLongId::GenericParameter(_) => false,
593            NegativeImplLongId::NegativeImplVar(new_var) => {
594                let new_var_long_id = new_var.long(self.db);
595                let new_var_local_id = new_var_long_id.id;
596                if InferenceVar::NegativeImpl(new_var_local_id) == var {
597                    return true;
598                }
599                if let Some(neg_impl_id) = self.negative_impl_assignment(new_var_local_id) {
600                    return self.negative_impl_contains_var(neg_impl_id, var);
601                }
602                self.generic_args_contain_var(
603                    new_var_long_id.concrete_trait_id.generic_args(self.db),
604                    var,
605                )
606            }
607        }
608    }
609
610    /// Checks if a function contains a certain [InferenceVar] in its generic arguments or in the
611    /// generic arguments of the impl containing the function (in case the function is an impl
612    /// function).
613    ///
614    /// Used to avoid inference cycles.
615    fn function_contains_var(&mut self, function_id: FunctionId<'db>, var: InferenceVar) -> bool {
616        let function = function_id.get_concrete(self.db);
617        let generic_args = function.generic_args;
618        // Look in the generic arguments of the function and in the impl generic arguments.
619        self.generic_args_contain_var(&generic_args, var)
620            || matches!(function.generic_function,
621                GenericFunctionId::Impl(impl_generic_function_id)
622                if self.impl_contains_var(impl_generic_function_id.impl_id, var)
623            )
624    }
625}
626
627impl<'db> Inference<'db, '_> {
628    /// Binds an `IntegerLiteral`'s inner type variable to `ty`. If `ty` is not a type that
629    /// accepts numeric literals (felt252, primitive integers, `NonZero<…>`, bounded ints …),
630    /// raises `InvalidLiteralType` so the conform-site caller — not the literal site — reports
631    /// the diagnostic.
632    fn bind_integer_literal(
633        &mut self,
634        var: TypeVar<'db>,
635        ty: TypeId<'db>,
636    ) -> InferenceResult<(TypeId<'db>, usize)> {
637        let bound = self.assign_ty(var, ty)?;
638        let rewritten = self.rewrite(bound).no_err();
639        if !rewritten.is_var_free(self.db) || is_valid_literal_type(self.db, rewritten) {
640            return Ok((bound, 0));
641        }
642        Err(self.set_error(InferenceError::InvalidLiteralType(rewritten)))
643    }
644
645    /// Reduces an impl type to a concrete type.
646    pub fn reduce_impl_ty(
647        &mut self,
648        impl_type_id: ImplTypeId<'db>,
649    ) -> InferenceResult<TypeId<'db>> {
650        let impl_id = impl_type_id.impl_id();
651        let trait_ty = impl_type_id.ty();
652        if let ImplLongId::ImplVar(var) = impl_id.long(self.db) {
653            Ok(self.rewritten_impl_type(*var, trait_ty))
654        } else if let Ok(ty) =
655            self.db.impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_ty, self.db))
656        {
657            Ok(ty)
658        } else {
659            Err(self.set_impl_reduction_error(impl_id))
660        }
661    }
662
663    /// Reduces an impl constant to a concrete const.
664    pub fn reduce_impl_constant(
665        &mut self,
666        impl_const_id: ImplConstantId<'db>,
667    ) -> InferenceResult<ConstValueId<'db>> {
668        let impl_id = impl_const_id.impl_id();
669        let trait_constant = impl_const_id.trait_constant_id();
670        if let ImplLongId::ImplVar(var) = impl_id.long(self.db) {
671            Ok(self.rewritten_impl_constant(*var, trait_constant))
672        } else if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
673            ImplConstantId::new(impl_id, trait_constant, self.db),
674        ) {
675            Ok(constant)
676        } else {
677            Err(self.set_impl_reduction_error(impl_id))
678        }
679    }
680
681    /// Reduces an impl impl to a concrete impl.
682    pub fn reduce_impl_impl(
683        &mut self,
684        impl_impl_id: ImplImplId<'db>,
685    ) -> InferenceResult<ImplId<'db>> {
686        let impl_id = impl_impl_id.impl_id();
687        let concrete_trait_impl = impl_impl_id
688            .concrete_trait_impl_id(self.db)
689            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
690
691        if let ImplLongId::ImplVar(var) = impl_id.long(self.db) {
692            Ok(self.rewritten_impl_impl(*var, concrete_trait_impl))
693        } else if let Ok(imp) = self.db.impl_impl_concrete_implized(ImplImplId::new(
694            impl_id,
695            impl_impl_id.trait_impl_id(),
696            self.db,
697        )) {
698            Ok(imp)
699        } else {
700            Err(self.set_impl_reduction_error(impl_id))
701        }
702    }
703
704    /// Returns the type of an impl var's type item.
705    /// The type may be a variable itself, but it may previously exist, so may be more specific due
706    /// to rewriting.
707    pub fn rewritten_impl_type(
708        &mut self,
709        id: ImplVarId<'db>,
710        trait_type_id: TraitTypeId<'db>,
711    ) -> TypeId<'db> {
712        self.rewritten_impl_item(
713            id,
714            trait_type_id,
715            |m| &mut m.types,
716            |inference, stable_ptr| inference.new_type_var(stable_ptr),
717        )
718    }
719
720    /// Returns the constant value of an impl var's constant item.
721    /// The constant may be a variable itself, but it may previously exist, so may be more specific
722    /// due to rewriting.
723    pub fn rewritten_impl_constant(
724        &mut self,
725        id: ImplVarId<'db>,
726        trait_constant: TraitConstantId<'db>,
727    ) -> ConstValueId<'db> {
728        self.rewritten_impl_item(
729            id,
730            trait_constant,
731            |m| &mut m.constants,
732            |inference, stable_ptr| {
733                inference.new_const_var(
734                    stable_ptr,
735                    inference.db.trait_constant_type(trait_constant).unwrap(),
736                )
737            },
738        )
739    }
740
741    /// Returns the inner_impl value of an impl var's impl item.
742    /// The inner_impl may be a variable itself, but it may previously exist, so may be more
743    /// specific due to rewriting.
744    pub fn rewritten_impl_impl(
745        &mut self,
746        id: ImplVarId<'db>,
747        concrete_trait_impl: ConcreteTraitImplId<'db>,
748    ) -> ImplId<'db> {
749        let trait_crate = concrete_trait_impl
750            .trait_impl(self.db)
751            .trait_id(self.db)
752            .parent_module(self.db)
753            .owning_crate(self.db);
754        self.rewritten_impl_item(
755            id,
756            concrete_trait_impl.trait_impl(self.db),
757            |m| &mut m.impls,
758            |inference, stable_ptr| {
759                inference.new_impl_var(
760                    inference.db.concrete_trait_impl_concrete_trait(concrete_trait_impl).unwrap(),
761                    stable_ptr,
762                    ImplLookupContext::new_from_crate(trait_crate).intern(self.db),
763                )
764            },
765        )
766    }
767
768    /// Helper function for getting an impl vars item ids.
769    /// These ids are likely to be variables, but may have more specific information due to
770    /// rewriting.
771    fn rewritten_impl_item<K: Hash + PartialEq + Eq, V: Copy>(
772        &mut self,
773        id: ImplVarId<'db>,
774        key: K,
775        get_map: impl for<'a> Fn(&'a mut ImplVarTraitItemMappings<'db>) -> &'a mut OrderedHashMap<K, V>,
776        new_var: impl FnOnce(&mut Self, Option<SyntaxStablePtrId<'db>>) -> V,
777    ) -> V
778    where
779        Self: SemanticRewriter<V, NoError>,
780    {
781        let var_id = id.id(self.db);
782        if let Some(value) = self
783            .data
784            .impl_vars_trait_item_mappings
785            .get_mut(&var_id)
786            .and_then(|mappings| get_map(mappings).get(&key))
787        {
788            // Copy the value to allow usage of `self`.
789            let value = *value;
790            // If the value already exists, rewrite it before returning.
791            self.rewrite(value).no_err()
792        } else {
793            let value =
794                new_var(self, self.data.stable_ptrs.get(&InferenceVar::Impl(var_id)).cloned());
795            get_map(self.data.impl_vars_trait_item_mappings.entry(var_id).or_default())
796                .insert(key, value);
797            value
798        }
799    }
800
801    /// Sets an error for an impl reduction failure.
802    fn set_impl_reduction_error(&mut self, impl_id: ImplId<'db>) -> ErrorSet {
803        self.set_error(
804            impl_id
805                .concrete_trait(self.db)
806                .map(InferenceError::NoImplsFound)
807                .unwrap_or_else(InferenceError::Reported),
808        )
809    }
810
811    /// Conforms a type to a type. Returning the reduced types on failure.
812    /// Useful for immediately reporting a diagnostic based on the compared types.
813    pub fn conform_ty_for_diag(
814        &mut self,
815        ty0: TypeId<'db>,
816        ty1: TypeId<'db>,
817        diagnostics: &mut SemanticDiagnostics<'db>,
818        diag_stable_ptr: impl FnOnce() -> SyntaxStablePtrId<'db>,
819        diag_kind: impl FnOnce(TypeId<'db>, TypeId<'db>) -> SemanticDiagnosticKind<'db>,
820    ) -> Maybe<()> {
821        match self.conform_ty(ty0, ty1) {
822            Ok(_ty) => Ok(()),
823            Err(err) => {
824                let ty0 = self.rewrite(ty0).no_err();
825                let ty1 = self.rewrite(ty1).no_err();
826                Err(if ty0 != ty1 {
827                    let diag_added = diagnostics.report(diag_stable_ptr(), diag_kind(ty0, ty1));
828                    self.consume_reported_error(err, diag_added);
829                    diag_added
830                } else {
831                    self.report_on_pending_error(err, diagnostics, diag_stable_ptr())
832                })
833            }
834        }
835    }
836
837    /// helper function for ty_contains_var
838    /// Assumes ty was already rewritten.
839    #[doc(hidden)]
840    fn internal_ty_contains_var(&mut self, ty: TypeId<'db>, var: InferenceVar) -> bool {
841        match ty.long(self.db) {
842            TypeLongId::Concrete(concrete) => {
843                let generic_args = concrete.generic_args(self.db);
844                self.generic_args_contain_var(&generic_args, var)
845            }
846            TypeLongId::Tuple(tys) => tys.iter().any(|ty| self.internal_ty_contains_var(*ty, var)),
847            TypeLongId::Snapshot(ty) => self.internal_ty_contains_var(*ty, var),
848            TypeLongId::Var(ty_var) | TypeLongId::NumericLiteral(ty_var) => {
849                if InferenceVar::Type(ty_var.id) == var {
850                    return true;
851                }
852                if let Some(ty) = self.type_assignment.get(&ty_var.id) {
853                    return self.internal_ty_contains_var(*ty, var);
854                }
855                false
856            }
857            TypeLongId::ImplType(id) => self.impl_contains_var(id.impl_id(), var),
858            TypeLongId::GenericParameter(_) | TypeLongId::Missing(_) => false,
859            TypeLongId::Coupon(function_id) => self.function_contains_var(*function_id, var),
860            TypeLongId::FixedSizeArray { type_id, .. } => {
861                self.internal_ty_contains_var(*type_id, var)
862            }
863            TypeLongId::Closure(closure) => {
864                closure.param_tys.iter().any(|ty| self.internal_ty_contains_var(*ty, var))
865                    || self.internal_ty_contains_var(closure.ret_ty, var)
866            }
867        }
868    }
869
870    /// Creates a var for each constrained impl_type and conforms the types.
871    pub fn conform_generic_params_type_constraints(
872        &mut self,
873        constraints: &[(TypeId<'db>, TypeId<'db>)],
874    ) {
875        let mut impl_type_bounds = Default::default();
876        for (ty0, ty1) in constraints {
877            let ty0 = if let TypeLongId::ImplType(impl_type) = ty0.long(self.db) {
878                self.impl_type_assignment(*impl_type, &mut impl_type_bounds)
879            } else {
880                *ty0
881            };
882            let ty1 = if let TypeLongId::ImplType(impl_type) = ty1.long(self.db) {
883                self.impl_type_assignment(*impl_type, &mut impl_type_bounds)
884            } else {
885                *ty1
886            };
887            self.conform_ty(ty0, ty1).ok();
888        }
889        self.set_impl_type_bounds(impl_type_bounds);
890    }
891
892    /// A helper function for getting for an impl type assignment.
893    /// Creates a new type var if the impl type is not yet assigned.
894    fn impl_type_assignment(
895        &mut self,
896        impl_type: ImplTypeId<'db>,
897        impl_type_bounds: &mut OrderedHashMap<ImplTypeId<'db>, TypeId<'db>>,
898    ) -> TypeId<'db> {
899        match impl_type_bounds.entry(impl_type) {
900            Entry::Occupied(entry) => *entry.get(),
901            Entry::Vacant(entry) => {
902                let inference_id = self.data.inference_id;
903                let id = LocalTypeVarId(self.data.type_vars.len());
904                let var = TypeVar { inference_id, id };
905                let ty = TypeLongId::Var(var).intern(self.db);
906                entry.insert(ty);
907                self.type_vars.push(var);
908                ty
909            }
910        }
911    }
912}