cairo_lang_semantic/expr/inference.rs

1//! Bidirectional type inference.
2
3use std::collections::{HashMap, VecDeque};
4use std::hash::Hash;
5use std::mem;
6use std::ops::{Deref, DerefMut};
7
8use cairo_lang_debug::DebugWithDb;
9use cairo_lang_defs::ids::{
10    ConstantId, EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId,
11    GlobalUseId, ImplAliasId, ImplDefId, ImplFunctionId, ImplImplDefId, LanguageElementId,
12    LocalVarId, LookupItemId, MemberId, ParamId, StructId, TraitConstantId, TraitFunctionId,
13    TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
14};
15use cairo_lang_diagnostics::{DiagnosticAdded, Maybe, skip_diagnostic};
16use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
17use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
18use cairo_lang_utils::ordered_hash_map::{Entry, OrderedHashMap};
19use cairo_lang_utils::{
20    Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches,
21};
22
23use self::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, NoError};
24use self::solver::{Ambiguity, SolutionSet, enrich_lookup_context};
25use crate::corelib::{CoreTraitContext, core_felt252_ty, get_core_trait, numeric_literal_trait};
26use crate::db::SemanticGroup;
27use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
28use crate::expr::inference::canonic::ResultNoErrEx;
29use crate::expr::inference::conform::InferenceConform;
30use crate::expr::objects::*;
31use crate::expr::pattern::*;
32use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
33use crate::items::functions::{
34    ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
35    GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
36    ImplGenericFunctionWithBodyId,
37};
38use crate::items::generics::{GenericParamConst, GenericParamImpl, GenericParamType};
39use crate::items::imp::{
40    GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
41    ImplLookupContext, UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
42};
43use crate::items::trt::{ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId};
44use crate::substitution::{HasDb, RewriteResult, SemanticRewriter, SubstitutionRewriter};
45use crate::types::{
46    ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
47    ImplTypeId,
48};
49use crate::{
50    ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
51    ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
52    FunctionId, FunctionLongId, GenericArgumentId, GenericParam, LocalVariable, MatchArmSelector,
53    Member, Parameter, SemanticObject, Signature, TypeId, TypeLongId, ValueSelectorArm,
54    add_basic_rewrites, add_expr_rewrites, add_rewrite, semantic_object_for_id,
55};
56
57pub mod canonic;
58pub mod conform;
59pub mod infers;
60pub mod solver;
61
62/// A type variable, created when a generic type argument is not passed, and thus is not known
63/// yet and needs to be inferred.
64#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
65pub struct TypeVar {
66    pub inference_id: InferenceId,
67    pub id: LocalTypeVarId,
68}
69
70/// A const variable, created when a generic const argument is not passed, and thus is not known
71/// yet and needs to be inferred.
72#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
73pub struct ConstVar {
74    pub inference_id: InferenceId,
75    pub id: LocalConstVarId,
76}
77
78/// An id for an inference context. Each inference variable is associated with an inference id.
79#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
80#[debug_db(dyn SemanticGroup + 'static)]
81pub enum InferenceId {
82    LookupItemDeclaration(LookupItemId),
83    LookupItemGenerics(LookupItemId),
84    LookupItemDefinition(LookupItemId),
85    ImplDefTrait(ImplDefId),
86    ImplAliasImplDef(ImplAliasId),
87    GenericParam(GenericParamId),
88    GenericImplParamTrait(GenericParamId),
89    GlobalUseStar(GlobalUseId),
90    Canonical,
91    /// For resolving that will not be used anywhere in the semantic model.
92    NoContext,
93}
94
95/// An impl variable, created when a generic type argument is not passed, and thus is not known
96/// yet and needs to be inferred.
97#[derive(Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
98#[debug_db(dyn SemanticGroup + 'static)]
99pub struct ImplVar {
100    pub inference_id: InferenceId,
101    #[dont_rewrite]
102    pub id: LocalImplVarId,
103    pub concrete_trait_id: ConcreteTraitId,
104    #[dont_rewrite]
105    pub lookup_context: ImplLookupContext,
106}
107impl ImplVar {
108    pub fn intern(&self, db: &dyn SemanticGroup) -> ImplVarId {
109        self.clone().intern(db)
110    }
111}
112
113#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
114pub struct LocalTypeVarId(pub usize);
115#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
116pub struct LocalImplVarId(pub usize);
117
118#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
119pub struct LocalConstVarId(pub usize);
120
121define_short_id!(ImplVarId, ImplVar, SemanticGroup, lookup_intern_impl_var, intern_impl_var);
122impl ImplVarId {
123    pub fn id(&self, db: &dyn SemanticGroup) -> LocalImplVarId {
124        self.lookup_intern(db).id
125    }
126    pub fn concrete_trait_id(&self, db: &dyn SemanticGroup) -> ConcreteTraitId {
127        self.lookup_intern(db).concrete_trait_id
128    }
129    pub fn lookup_context(&self, db: &dyn SemanticGroup) -> ImplLookupContext {
130        self.lookup_intern(db).lookup_context
131    }
132}
133semantic_object_for_id!(ImplVarId, lookup_intern_impl_var, intern_impl_var, ImplVar);
134
135#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, SemanticObject)]
136pub enum InferenceVar {
137    Type(LocalTypeVarId),
138    Const(LocalConstVarId),
139    Impl(LocalImplVarId),
140}
141
142// TODO(spapini): Add to diagnostics.
143#[derive(Clone, Debug, Eq, Hash, PartialEq, DebugWithDb)]
144#[debug_db(dyn SemanticGroup + 'static)]
145pub enum InferenceError {
146    /// An inference error wrapping a previously reported error.
147    Reported(DiagnosticAdded),
148    Cycle(InferenceVar),
149    TypeKindMismatch {
150        ty0: TypeId,
151        ty1: TypeId,
152    },
153    ConstKindMismatch {
154        const0: ConstValueId,
155        const1: ConstValueId,
156    },
157    ImplKindMismatch {
158        impl0: ImplId,
159        impl1: ImplId,
160    },
161    GenericArgMismatch {
162        garg0: GenericArgumentId,
163        garg1: GenericArgumentId,
164    },
165    TraitMismatch {
166        trt0: TraitId,
167        trt1: TraitId,
168    },
169    GenericFunctionMismatch {
170        func0: GenericFunctionId,
171        func1: GenericFunctionId,
172    },
173    ConstInferenceNotSupported,
174
175    // TODO(spapini): These are only used for external interface. Separate them along with the
176    // finalize() function to a wrapper.
177    NoImplsFound(ConcreteTraitId),
178    Ambiguity(Ambiguity),
179    TypeNotInferred(TypeId),
180}
181impl InferenceError {
182    pub fn format(&self, db: &(dyn SemanticGroup + 'static)) -> String {
183        match self {
184            InferenceError::Reported(_) => "Inference error occurred.".into(),
185            InferenceError::Cycle(_var) => "Inference cycle detected".into(),
186            InferenceError::TypeKindMismatch { ty0, ty1 } => {
187                format!("Type mismatch: `{:?}` and `{:?}`.", ty0.debug(db), ty1.debug(db))
188            }
189            InferenceError::ConstKindMismatch { const0, const1 } => {
190                format!("Const mismatch: `{:?}` and `{:?}`.", const0.debug(db), const1.debug(db))
191            }
192            InferenceError::ImplKindMismatch { impl0, impl1 } => {
193                format!("Impl mismatch: `{:?}` and `{:?}`.", impl0.debug(db), impl1.debug(db))
194            }
195            InferenceError::GenericArgMismatch { garg0, garg1 } => {
196                format!(
197                    "Generic arg mismatch: `{:?}` and `{:?}`.",
198                    garg0.debug(db),
199                    garg1.debug(db)
200                )
201            }
202            InferenceError::TraitMismatch { trt0, trt1 } => {
203                format!("Trait mismatch: `{:?}` and `{:?}`.", trt0.debug(db), trt1.debug(db))
204            }
205            InferenceError::ConstInferenceNotSupported => {
206                "Const generic inference not yet supported.".into()
207            }
208            InferenceError::NoImplsFound(concrete_trait_id) => {
209                let trait_id = concrete_trait_id.trait_id(db);
210                if trait_id == numeric_literal_trait(db) {
211                    let generic_type = extract_matches!(
212                        concrete_trait_id.generic_args(db)[0],
213                        GenericArgumentId::Type
214                    );
215                    return format!(
216                        "Mismatched types. The type `{:?}` cannot be created from a numeric \
217                         literal.",
218                        generic_type.debug(db)
219                    );
220                } else if trait_id
221                    == get_core_trait(db, CoreTraitContext::TopLevel, "StringLiteral".into())
222                {
223                    let generic_type = extract_matches!(
224                        concrete_trait_id.generic_args(db)[0],
225                        GenericArgumentId::Type
226                    );
227                    return format!(
228                        "Mismatched types. The type `{:?}` cannot be created from a string \
229                         literal.",
230                        generic_type.debug(db)
231                    );
232                }
233                format!(
234                    "Trait has no implementation in context: {:?}.",
235                    concrete_trait_id.debug(db)
236                )
237            }
238            InferenceError::Ambiguity(ambiguity) => ambiguity.format(db),
239            InferenceError::TypeNotInferred(ty) => {
240                format!("Type annotations needed. Failed to infer {:?}.", ty.debug(db))
241            }
242            InferenceError::GenericFunctionMismatch { func0, func1 } => {
243                format!("Function mismatch: `{}` and `{}`.", func0.format(db), func1.format(db))
244            }
245        }
246    }
247}
248
249impl InferenceError {
250    pub fn report(
251        &self,
252        diagnostics: &mut SemanticDiagnostics,
253        stable_ptr: SyntaxStablePtrId,
254    ) -> DiagnosticAdded {
255        match self {
256            InferenceError::Reported(diagnostic_added) => *diagnostic_added,
257            _ => diagnostics
258                .report(stable_ptr, SemanticDiagnosticKind::InternalInferenceError(self.clone())),
259        }
260    }
261}
262
/// Marker proving that an inference error was recorded in the `Inference` object. Its purpose
/// is to ensure that the recorded error is properly set and later consumed.
///
/// It must not be constructed directly. Instead, it is returned by [Inference::set_error].
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct ErrorSet;

/// Result of a fallible inference operation. On failure, the error itself lives in the
/// `Inference` object and only the [ErrorSet] marker is returned.
pub type InferenceResult<T> = Result<T, ErrorSet>;
271
/// Status of the error stored in an inference state.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum InferenceErrorStatus {
    /// An error was set and has not been consumed yet.
    Pending,
    /// The error was consumed; `InferenceData::consumed_error` holds its [DiagnosticAdded].
    Consumed,
}
277
278/// A mapping of an impl var's trait items to concrete items
279#[derive(Debug, Default, PartialEq, Eq, Clone, SemanticObject)]
280pub struct ImplVarTraitItemMappings {
281    /// The trait types of the impl var.
282    types: OrderedHashMap<TraitTypeId, TypeId>,
283    /// The trait constants of the impl var.
284    constants: OrderedHashMap<TraitConstantId, ConstValueId>,
285    /// The trait impls of the impl var.
286    impls: OrderedHashMap<TraitImplId, ImplId>,
287}
288impl Hash for ImplVarTraitItemMappings {
289    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
290        self.types.iter().for_each(|(trait_type_id, type_id)| {
291            trait_type_id.hash(state);
292            type_id.hash(state);
293        });
294        self.constants.iter().for_each(|(trait_const_id, const_id)| {
295            trait_const_id.hash(state);
296            const_id.hash(state);
297        });
298        self.impls.iter().for_each(|(trait_impl_id, impl_id)| {
299            trait_impl_id.hash(state);
300            impl_id.hash(state);
301        });
302    }
303}
304
305/// State of inference.
306#[derive(Debug, DebugWithDb, PartialEq, Eq)]
307#[debug_db(dyn SemanticGroup + 'static)]
308pub struct InferenceData {
309    pub inference_id: InferenceId,
310    /// Current inferred assignment for type variables.
311    pub type_assignment: OrderedHashMap<LocalTypeVarId, TypeId>,
312    /// Current inferred assignment for const variables.
313    pub const_assignment: OrderedHashMap<LocalConstVarId, ConstValueId>,
314    /// Current inferred assignment for impl variables.
315    pub impl_assignment: OrderedHashMap<LocalImplVarId, ImplId>,
316    /// Unsolved impl variables mapping to a maps of trait items to a corresponding item variable.
317    /// Upon solution of the trait conforms the fully known item to the variable.
318    pub impl_vars_trait_item_mappings: HashMap<LocalImplVarId, ImplVarTraitItemMappings>,
319    /// Type variables.
320    pub type_vars: Vec<TypeVar>,
321    /// Const variables.
322    pub const_vars: Vec<ConstVar>,
323    /// Impl variables.
324    pub impl_vars: Vec<ImplVar>,
325    /// Mapping from variables to stable pointers, if exist.
326    pub stable_ptrs: HashMap<InferenceVar, SyntaxStablePtrId>,
327    /// Inference variables that are pending to be solved.
328    pending: VecDeque<LocalImplVarId>,
329    /// Inference variables that have been refuted - no solutions exist.
330    refuted: Vec<LocalImplVarId>,
331    /// Inference variables that have been solved.
332    solved: Vec<LocalImplVarId>,
333    /// Inference variables that are currently ambiguous. May be solved later.
334    ambiguous: Vec<(LocalImplVarId, Ambiguity)>,
335    /// Mapping from impl types to type variables.
336    pub impl_type_bounds: OrderedHashMap<ImplTypeId, TypeId>,
337
338    // Error handling members.
339    /// The current error status.
340    pub error_status: Result<(), InferenceErrorStatus>,
341    /// `Some` only when error_state is Err(Pending).
342    error: Option<InferenceError>,
343    /// `Some` only when error_state is Err(Consumed).
344    consumed_error: Option<DiagnosticAdded>,
345}
346impl InferenceData {
347    pub fn new(inference_id: InferenceId) -> Self {
348        Self {
349            inference_id,
350            type_assignment: OrderedHashMap::default(),
351            impl_assignment: OrderedHashMap::default(),
352            const_assignment: OrderedHashMap::default(),
353            impl_vars_trait_item_mappings: HashMap::new(),
354            type_vars: Vec::new(),
355            impl_vars: Vec::new(),
356            const_vars: Vec::new(),
357            stable_ptrs: HashMap::new(),
358            pending: VecDeque::new(),
359            refuted: Vec::new(),
360            solved: Vec::new(),
361            ambiguous: Vec::new(),
362            impl_type_bounds: OrderedHashMap::default(),
363            error_status: Ok(()),
364            error: None,
365            consumed_error: None,
366        }
367    }
368    pub fn inference<'db, 'b: 'db>(&'db mut self, db: &'b dyn SemanticGroup) -> Inference<'db> {
369        Inference::new(db, self)
370    }
371    pub fn clone_with_inference_id(
372        &self,
373        db: &dyn SemanticGroup,
374        inference_id: InferenceId,
375    ) -> InferenceData {
376        let mut inference_id_replacer =
377            InferenceIdReplacer::new(db, self.inference_id, inference_id);
378        Self {
379            inference_id,
380            type_assignment: self
381                .type_assignment
382                .iter()
383                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
384                .collect(),
385            const_assignment: self
386                .const_assignment
387                .iter()
388                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
389                .collect(),
390            impl_assignment: self
391                .impl_assignment
392                .iter()
393                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
394                .collect(),
395            impl_vars_trait_item_mappings: self
396                .impl_vars_trait_item_mappings
397                .iter()
398                .map(|(k, mappings)| {
399                    (*k, ImplVarTraitItemMappings {
400                        types: mappings
401                            .types
402                            .iter()
403                            .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
404                            .collect(),
405                        constants: mappings
406                            .constants
407                            .iter()
408                            .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
409                            .collect(),
410                        impls: mappings
411                            .impls
412                            .iter()
413                            .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
414                            .collect(),
415                    })
416                })
417                .collect(),
418            type_vars: inference_id_replacer.rewrite(self.type_vars.clone()).no_err(),
419            const_vars: inference_id_replacer.rewrite(self.const_vars.clone()).no_err(),
420            impl_vars: inference_id_replacer.rewrite(self.impl_vars.clone()).no_err(),
421            stable_ptrs: self.stable_ptrs.clone(),
422            pending: inference_id_replacer.rewrite(self.pending.clone()).no_err(),
423            refuted: inference_id_replacer.rewrite(self.refuted.clone()).no_err(),
424            solved: inference_id_replacer.rewrite(self.solved.clone()).no_err(),
425            ambiguous: inference_id_replacer.rewrite(self.ambiguous.clone()).no_err(),
426            impl_type_bounds: self
427                .impl_type_bounds
428                .iter()
429                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
430                .collect(),
431            error_status: self.error_status,
432            error: self.error.clone(),
433            consumed_error: self.consumed_error,
434        }
435    }
436    pub fn temporary_clone(&self) -> InferenceData {
437        Self {
438            inference_id: self.inference_id,
439            type_assignment: self.type_assignment.clone(),
440            const_assignment: self.const_assignment.clone(),
441            impl_assignment: self.impl_assignment.clone(),
442            impl_vars_trait_item_mappings: self.impl_vars_trait_item_mappings.clone(),
443            type_vars: self.type_vars.clone(),
444            const_vars: self.const_vars.clone(),
445            impl_vars: self.impl_vars.clone(),
446            stable_ptrs: self.stable_ptrs.clone(),
447            pending: self.pending.clone(),
448            refuted: self.refuted.clone(),
449            solved: self.solved.clone(),
450            ambiguous: self.ambiguous.clone(),
451            impl_type_bounds: self.impl_type_bounds.clone(),
452            error_status: self.error_status,
453            error: self.error.clone(),
454            consumed_error: self.consumed_error,
455        }
456    }
457}
458
459/// State of inference. A system of inference constraints.
460pub struct Inference<'db> {
461    db: &'db dyn SemanticGroup,
462    pub data: &'db mut InferenceData,
463}
464
465impl Deref for Inference<'_> {
466    type Target = InferenceData;
467
468    fn deref(&self) -> &Self::Target {
469        self.data
470    }
471}
472impl DerefMut for Inference<'_> {
473    fn deref_mut(&mut self) -> &mut Self::Target {
474        self.data
475    }
476}
477
478impl std::fmt::Debug for Inference<'_> {
479    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
480        let x = self.data.debug(self.db.elongate());
481        write!(f, "{x:?}")
482    }
483}
484
485impl<'db> Inference<'db> {
486    fn new(db: &'db dyn SemanticGroup, data: &'db mut InferenceData) -> Self {
487        Self { db, data }
488    }
489
490    /// Getter for an [ImplVar].
491    fn impl_var(&self, var_id: LocalImplVarId) -> &ImplVar {
492        &self.impl_vars[var_id.0]
493    }
494
495    /// Getter for an impl var assignment.
496    pub fn impl_assignment(&self, var_id: LocalImplVarId) -> Option<ImplId> {
497        self.impl_assignment.get(&var_id).copied()
498    }
499
500    /// Getter for a type var assignment.
501    fn type_assignment(&self, var_id: LocalTypeVarId) -> Option<TypeId> {
502        self.type_assignment.get(&var_id).copied()
503    }
504
505    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
506    /// Returns a wrapping TypeId.
507    pub fn new_type_var(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeId {
508        let var = self.new_type_var_raw(stable_ptr);
509
510        TypeLongId::Var(var).intern(self.db)
511    }
512
513    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
514    /// Returns the variable id.
515    pub fn new_type_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeVar {
516        let var =
517            TypeVar { inference_id: self.inference_id, id: LocalTypeVarId(self.type_vars.len()) };
518        if let Some(stable_ptr) = stable_ptr {
519            self.stable_ptrs.insert(InferenceVar::Type(var.id), stable_ptr);
520        }
521        self.type_vars.push(var);
522        var
523    }
524
525    /// Getter for an impl type assignment.
526    /// Creates a new type var if the impl type is not yet assigned.
527    pub fn impl_type_assignment(&mut self, impl_type: ImplTypeId) -> TypeId {
528        match self.data.impl_type_bounds.entry(impl_type) {
529            Entry::Occupied(entry) => *entry.get(),
530            Entry::Vacant(entry) => {
531                let inference_id = self.data.inference_id;
532                let id = LocalTypeVarId(self.data.type_vars.len());
533                let var = TypeVar { inference_id, id };
534                let ty = TypeLongId::Var(var).intern(self.db);
535                entry.insert(ty);
536                self.type_vars.push(var);
537                ty
538            }
539        }
540    }
541
542    /// If an impl type is not fully inferred, we assign it's var to the original type
543    pub fn finalize_impl_type_bounds(&mut self) {
544        let mut impl_type_bounds = std::mem::take(&mut self.data.impl_type_bounds);
545        impl_type_bounds.retain(|impl_type, ty| {
546            if !matches!(self.rewrite(ty.lookup_intern(self.db)).no_err(), TypeLongId::Var(_)) {
547                return true;
548            }
549
550            self.conform_ty(*ty, TypeLongId::ImplType(*impl_type).intern(self.db)).ok();
551            false
552        });
553        self.data.impl_type_bounds = impl_type_bounds;
554    }
555
556    /// Allocates a new [ConstVar] for an unknown consts that needs to be inferred.
557    /// Returns a wrapping [ConstValueId].
558    pub fn new_const_var(
559        &mut self,
560        stable_ptr: Option<SyntaxStablePtrId>,
561        ty: TypeId,
562    ) -> ConstValueId {
563        let var = self.new_const_var_raw(stable_ptr);
564        ConstValue::Var(var, ty).intern(self.db)
565    }
566
567    /// Allocates a new [ConstVar] for an unknown type that needs to be inferred.
568    /// Returns the variable id.
569    pub fn new_const_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> ConstVar {
570        let var = ConstVar {
571            inference_id: self.inference_id,
572            id: LocalConstVarId(self.const_vars.len()),
573        };
574        if let Some(stable_ptr) = stable_ptr {
575            self.stable_ptrs.insert(InferenceVar::Const(var.id), stable_ptr);
576        }
577        self.const_vars.push(var);
578        var
579    }
580
581    /// Allocates a new [ImplVar] for an unknown type that needs to be inferred.
582    /// Returns a wrapping ImplId.
583    pub fn new_impl_var(
584        &mut self,
585        concrete_trait_id: ConcreteTraitId,
586        stable_ptr: Option<SyntaxStablePtrId>,
587        lookup_context: ImplLookupContext,
588    ) -> ImplId {
589        let var = self.new_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
590        ImplLongId::ImplVar(self.impl_var(var).intern(self.db)).intern(self.db)
591    }
592
593    /// Allocates a new [ImplVar] for an unknown type that needs to be inferred.
594    /// Returns the variable id.
595    fn new_impl_var_raw(
596        &mut self,
597        lookup_context: ImplLookupContext,
598        concrete_trait_id: ConcreteTraitId,
599        stable_ptr: Option<SyntaxStablePtrId>,
600    ) -> LocalImplVarId {
601        let mut lookup_context = lookup_context;
602        lookup_context
603            .insert_module(concrete_trait_id.trait_id(self.db).module_file_id(self.db.upcast()).0);
604
605        let id = LocalImplVarId(self.impl_vars.len());
606        if let Some(stable_ptr) = stable_ptr {
607            self.stable_ptrs.insert(InferenceVar::Impl(id), stable_ptr);
608        }
609        let var =
610            ImplVar { inference_id: self.inference_id, id, concrete_trait_id, lookup_context };
611        self.impl_vars.push(var);
612        self.pending.push_back(id);
613        id
614    }
615
616    /// Solves the inference system. After a successful solve, there are no more pending impl
617    /// inferences.
618    /// Returns whether the inference was successful. If not, the error may be found by
619    /// `.error_state()`.
620    pub fn solve(&mut self) -> InferenceResult<()> {
621        self.solve_ex().map_err(|(err_set, _)| err_set)
622    }
623
624    /// Same as `solve`, but returns the error stable pointer if an error occurred.
625    fn solve_ex(&mut self) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
626        let mut ambiguous = std::mem::take(&mut self.ambiguous);
627        self.pending.extend(ambiguous.drain(..).map(|(var, _)| var));
628        while let Some(var) = self.pending.pop_front() {
629            // First inference error stops inference.
630            self.solve_single_pending(var).map_err(|err_set| {
631                (err_set, self.stable_ptrs.get(&InferenceVar::Impl(var)).copied())
632            })?;
633        }
634        Ok(())
635    }
636
637    fn solve_single_pending(&mut self, var: LocalImplVarId) -> InferenceResult<()> {
638        if self.impl_assignment.contains_key(&var) {
639            return Ok(());
640        }
641        let solution = match self.impl_var_solution_set(var)? {
642            SolutionSet::None => {
643                self.refuted.push(var);
644                return Ok(());
645            }
646            SolutionSet::Ambiguous(ambiguity) => {
647                self.ambiguous.push((var, ambiguity));
648                return Ok(());
649            }
650            SolutionSet::Unique(solution) => solution,
651        };
652
653        // Solution found. Assign it.
654        self.assign_local_impl(var, solution)?;
655
656        // Something changed.
657        self.solved.push(var);
658        let mut ambiguous = std::mem::take(&mut self.ambiguous);
659        self.pending.extend(ambiguous.drain(..).map(|(var, _)| var));
660
661        Ok(())
662    }
663
664    /// Returns the solution set status for the inference:
665    /// Whether there is a unique solution, multiple solutions, no solutions or an error.
666    pub fn solution_set(&mut self) -> InferenceResult<SolutionSet<()>> {
667        self.solve()?;
668        if !self.refuted.is_empty() {
669            return Ok(SolutionSet::None);
670        }
671        if let Some((_, ambiguity)) = self.ambiguous.first() {
672            return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
673        }
674        assert!(self.pending.is_empty(), "solution() called on an unsolved solver");
675        Ok(SolutionSet::Unique(()))
676    }
677
678    /// Finalizes the inference by inferring uninferred numeric literals as felt252.
679    /// Returns an error and does not report it.
680    pub fn finalize_without_reporting(
681        &mut self,
682    ) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
683        if self.error_status.is_err() {
684            // TODO(yuval): consider adding error location to the set error.
685            return Err((ErrorSet, None));
686        }
687
688        let numeric_trait_id = numeric_literal_trait(self.db);
689        let felt_ty = core_felt252_ty(self.db);
690
691        // Conform all uninferred numeric literals to felt252.
692        loop {
693            let mut changed = false;
694            self.solve_ex()?;
695            for (var, _) in self.ambiguous.clone() {
696                let impl_var = self.impl_var(var).clone();
697                if impl_var.concrete_trait_id.trait_id(self.db) != numeric_trait_id {
698                    continue;
699                }
700                // Uninferred numeric trait. Resolve as felt252.
701                let ty = extract_matches!(
702                    impl_var.concrete_trait_id.generic_args(self.db)[0],
703                    GenericArgumentId::Type
704                );
705                if self.rewrite(ty).no_err() == felt_ty {
706                    continue;
707                }
708                self.conform_ty(ty, felt_ty).map_err(|err_set| {
709                    (err_set, self.stable_ptrs.get(&InferenceVar::Impl(impl_var.id)).copied())
710                })?;
711                changed = true;
712                break;
713            }
714            if !changed {
715                break;
716            }
717        }
718        assert!(
719            self.pending.is_empty(),
720            "pending should all be solved by this point. Guaranteed by solve()."
721        );
722
723        let Some((var, err)) = self.first_undetermined_variable() else {
724            return Ok(());
725        };
726        Err((self.set_error(err), self.stable_ptrs.get(&var).copied()))
727    }
728
729    /// Finalizes the inference and report diagnostics if there are any errors.
730    /// All the remaining type vars are mapped to the `missing` type, to prevent additional
731    /// diagnostics.
732    pub fn finalize(
733        &mut self,
734        diagnostics: &mut SemanticDiagnostics,
735        stable_ptr: SyntaxStablePtrId,
736    ) {
737        if let Err((err_set, err_stable_ptr)) = self.finalize_without_reporting() {
738            let diag = self.report_on_pending_error(
739                err_set,
740                diagnostics,
741                err_stable_ptr.unwrap_or(stable_ptr),
742            );
743
744            let ty_missing = TypeId::missing(self.db, diag);
745            for var in &self.data.type_vars {
746                self.data.type_assignment.entry(var.id).or_insert(ty_missing);
747            }
748        }
749    }
750
751    /// Retrieves the first variable that is still not inferred, or None, if everything is
752    /// inferred.
753    /// Does not set the error but return it, which is ok as this is a private helper function.
754    fn first_undetermined_variable(&mut self) -> Option<(InferenceVar, InferenceError)> {
755        for (id, var) in self.type_vars.iter().enumerate() {
756            if self.type_assignment(LocalTypeVarId(id)).is_none() {
757                let ty = TypeLongId::Var(*var).intern(self.db);
758                return Some((InferenceVar::Type(var.id), InferenceError::TypeNotInferred(ty)));
759            }
760        }
761        if let Some(var) = self.refuted.first().copied() {
762            let impl_var = self.impl_var(var).clone();
763            let concrete_trait_id = impl_var.concrete_trait_id;
764            let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
765            return Some((
766                InferenceVar::Impl(var),
767                InferenceError::NoImplsFound(concrete_trait_id),
768            ));
769        }
770        if let Some((var, ambiguity)) = self.ambiguous.first() {
771            let var = *var;
772            // Note: do not rewrite `ambiguity`, since it is expressed in canonical variables.
773            return Some((InferenceVar::Impl(var), InferenceError::Ambiguity(ambiguity.clone())));
774        }
775        None
776    }
777
778    /// Assigns a value to a local impl variable id. See assign_impl().
779    fn assign_local_impl(
780        &mut self,
781        var: LocalImplVarId,
782        impl_id: ImplId,
783    ) -> InferenceResult<ImplId> {
784        let concrete_trait = impl_id
785            .concrete_trait(self.db)
786            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
787        self.conform_traits(self.impl_var(var).concrete_trait_id, concrete_trait)?;
788        if let Some(other_impl) = self.impl_assignment(var) {
789            return self.conform_impl(impl_id, other_impl);
790        }
791        if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
792        {
793            return Err(self.set_error(InferenceError::Cycle(InferenceVar::Impl(var))));
794        }
795        self.impl_assignment.insert(var, impl_id);
796        if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
797            for (trait_ty, ty) in mappings.types {
798                self.conform_ty(
799                    ty,
800                    self.db
801                        .impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_ty, self.db))
802                        .map_err(|_| ErrorSet)?,
803                )?;
804            }
805            for (trait_constant, constant_id) in mappings.constants {
806                self.conform_const(
807                    constant_id,
808                    self.db
809                        .impl_constant_concrete_implized_value(ImplConstantId::new(
810                            impl_id,
811                            trait_constant,
812                            self.db,
813                        ))
814                        .map_err(|_| ErrorSet)?,
815                )?;
816            }
817            for (trait_impl, inner_impl_id) in mappings.impls {
818                self.conform_impl(
819                    inner_impl_id,
820                    self.db
821                        .impl_impl_concrete_implized(ImplImplId::new(impl_id, trait_impl, self.db))
822                        .map_err(|_| ErrorSet)?,
823                )?;
824            }
825        }
826        Ok(impl_id)
827    }
828
829    /// Tries to assigns value to an [ImplVarId]. Return the assigned impl, or an error.
830    fn assign_impl(&mut self, var_id: ImplVarId, impl_id: ImplId) -> InferenceResult<ImplId> {
831        let var = var_id.lookup_intern(self.db);
832        if var.inference_id != self.inference_id {
833            return Err(self.set_error(InferenceError::ImplKindMismatch {
834                impl0: ImplLongId::ImplVar(var_id).intern(self.db),
835                impl1: impl_id,
836            }));
837        }
838        self.assign_local_impl(var.id, impl_id)
839    }
840
841    /// Assigns a value to a [TypeVar]. Return the assigned type, or an error.
842    /// Assumes the variable is not already assigned.
843    fn assign_ty(&mut self, var: TypeVar, ty: TypeId) -> InferenceResult<TypeId> {
844        if var.inference_id != self.inference_id {
845            return Err(self.set_error(InferenceError::TypeKindMismatch {
846                ty0: TypeLongId::Var(var).intern(self.db),
847                ty1: ty,
848            }));
849        }
850        assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
851        let inference_var = InferenceVar::Type(var.id);
852        if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
853            return Err(self.set_error(InferenceError::Cycle(inference_var)));
854        }
855        // If assigning var to var - making sure assigning to the lower id for proper canonization.
856        if let TypeLongId::Var(other) = ty.lookup_intern(self.db) {
857            if other.inference_id == self.inference_id && other.id.0 > var.id.0 {
858                let var_ty = TypeLongId::Var(var).intern(self.db);
859                self.type_assignment.insert(other.id, var_ty);
860                return Ok(var_ty);
861            }
862        }
863        self.type_assignment.insert(var.id, ty);
864        Ok(ty)
865    }
866
867    /// Assigns a value to a [ConstVar]. Return the assigned const, or an error.
868    /// Assumes the variable is not already assigned.
869    fn assign_const(&mut self, var: ConstVar, id: ConstValueId) -> InferenceResult<ConstValueId> {
870        if var.inference_id != self.inference_id {
871            return Err(self.set_error(InferenceError::ConstKindMismatch {
872                const0: ConstValue::Var(var, TypeId::missing(self.db, skip_diagnostic()))
873                    .intern(self.db),
874                const1: id,
875            }));
876        }
877
878        self.const_assignment.insert(var.id, id);
879        Ok(id)
880    }
881
882    /// Computes the solution set for an impl variable with a recursive query.
883    fn impl_var_solution_set(
884        &mut self,
885        var: LocalImplVarId,
886    ) -> InferenceResult<SolutionSet<ImplId>> {
887        let impl_var = self.impl_var(var).clone();
888        // Update the concrete trait of the impl var.
889        let concrete_trait_id = self.rewrite(impl_var.concrete_trait_id).no_err();
890        self.impl_vars[impl_var.id.0].concrete_trait_id = concrete_trait_id;
891        let impl_var_trait_item_mappings =
892            self.impl_vars_trait_item_mappings.get(&var).cloned().unwrap_or_default();
893        let solution_set = self.trait_solution_set(
894            concrete_trait_id,
895            impl_var_trait_item_mappings,
896            impl_var.lookup_context,
897        )?;
898        Ok(match solution_set {
899            SolutionSet::None => SolutionSet::None,
900            SolutionSet::Unique((canonical_impl, canonicalizer)) => {
901                SolutionSet::Unique(canonical_impl.embed(self, &canonicalizer))
902            }
903            SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
904        })
905    }
906
907    /// Computes the solution set for a trait with a recursive query.
908    pub fn trait_solution_set(
909        &mut self,
910        concrete_trait_id: ConcreteTraitId,
911        impl_var_trait_item_mappings: ImplVarTraitItemMappings,
912        mut lookup_context: ImplLookupContext,
913    ) -> InferenceResult<SolutionSet<(CanonicalImpl, CanonicalMapping)>> {
914        let impl_var_trait_item_mappings = self.rewrite(impl_var_trait_item_mappings).no_err();
915        // TODO(spapini): This is done twice. Consider doing it only here.
916        let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
917        enrich_lookup_context(self.db, concrete_trait_id, &mut lookup_context);
918
919        // Don't try to resolve impls if the first generic param is a variable.
920        let generic_args = concrete_trait_id.generic_args(self.db);
921        match generic_args.first() {
922            Some(GenericArgumentId::Type(ty)) => {
923                if let TypeLongId::Var(_) = ty.lookup_intern(self.db) {
924                    // Don't try to infer such impls.
925                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
926                }
927            }
928            Some(GenericArgumentId::Impl(imp)) => {
929                // Don't try to infer such impls.
930                if let ImplLongId::ImplVar(_) = imp.lookup_intern(self.db) {
931                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
932                }
933            }
934            Some(GenericArgumentId::Constant(const_value)) => {
935                if let ConstValue::Var(_, _) = const_value.lookup_intern(self.db) {
936                    // Don't try to infer such impls.
937                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
938                }
939            }
940            _ => {}
941        };
942
943        let (canonical_trait, canonicalizer) = CanonicalTrait::canonicalize(
944            self.db,
945            self.inference_id,
946            concrete_trait_id,
947            impl_var_trait_item_mappings,
948        );
949        let solution_set = match self.db.canonic_trait_solutions(canonical_trait, lookup_context) {
950            Ok(solution_set) => solution_set,
951            Err(err) => return Err(self.set_error(err)),
952        };
953        match solution_set {
954            SolutionSet::None => Ok(SolutionSet::None),
955            SolutionSet::Unique(canonical_impl) => {
956                Ok(SolutionSet::Unique((canonical_impl, canonicalizer)))
957            }
958            SolutionSet::Ambiguous(ambiguity) => Ok(SolutionSet::Ambiguous(ambiguity)),
959        }
960    }
961
    /// Validates `canonical_impl` against the negative-impl generic parameters of its definition.
    ///
    /// Only `ImplLongId::Concrete` and `ImplLongId::GeneratedImpl` impls have their negative-impl
    /// parameters checked. Every other impl kind yields `SolutionSet::Unique(canonical_impl)`.
    ///
    /// Each negative-impl parameter is checked in order:
    /// - If one of its type arguments (after rewriting) is neither a closure nor fully concrete,
    ///   returns `SolutionSet::Ambiguous(Ambiguity::NegativeImplWithUnresolvedGenericArgs{..})`.
    /// - If its concrete trait has any solution (unique or ambiguous), the candidate is ruled
    ///   out and `SolutionSet::None` is returned.
    ///
    /// If all negative-impl parameters pass, returns `SolutionSet::Unique(canonical_impl)`.
    /// Db query failures are recorded as `InferenceError::Reported`. Errors from the nested
    /// `trait_solution_set` call are propagated.
    fn validate_neg_impls(
        &mut self,
        lookup_context: &ImplLookupContext,
        canonical_impl: CanonicalImpl,
    ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
        /// Validates that none of the given negative-impl traits has a solution.
        /// Returns `Unique(canonical_impl)` on success, `Ambiguous` if a negative trait has an
        /// unresolved type argument, and `None` if a negative trait has a solution.
        fn validate_no_solution_set(
            inference: &mut Inference<'_>,
            canonical_impl: CanonicalImpl,
            lookup_context: &ImplLookupContext,
            negative_impls_concrete_traits: impl Iterator<Item = Maybe<ConcreteTraitId>>,
        ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
            for concrete_trait_id in negative_impls_concrete_traits {
                let concrete_trait_id = concrete_trait_id.map_err(|diag_added| {
                    inference.set_error(InferenceError::Reported(diag_added))
                })?;
                for garg in concrete_trait_id.generic_args(inference.db) {
                    let GenericArgumentId::Type(ty) = garg else {
                        continue;
                    };
                    let ty = inference.rewrite(ty).no_err();
                    // If the negative impl has a generic argument that is not fully
                    // concrete, we can't tell if we should rule out the candidate impl.
                    // For example, given -TypeEqual<S, T>, we can't tell if S and
                    // T are going to be assigned the same concrete type.
                    // We return `SolutionSet::Ambiguous` here to indicate that more
                    // information is needed.
                    // A closure type is a single type even if it is not fully concrete, so it
                    // can be used without causing ambiguity.
                    if !matches!(ty.lookup_intern(inference.db), TypeLongId::Closure(_))
                        && !ty.is_fully_concrete(inference.db)
                    {
                        // TODO(ilya): Try to detect the ambiguity earlier in the
                        // inference process.
                        return Ok(SolutionSet::Ambiguous(
                            Ambiguity::NegativeImplWithUnresolvedGenericArgs {
                                impl_id: canonical_impl.0,
                                ty,
                            },
                        ));
                    }
                }

                if !matches!(
                    inference.trait_solution_set(
                        concrete_trait_id,
                        ImplVarTraitItemMappings::default(),
                        lookup_context.clone()
                    )?,
                    SolutionSet::None
                ) {
                    // The negative impl's trait has an impl (unique or ambiguous), so the
                    // candidate impl is ruled out.
                    return Ok(SolutionSet::None);
                }
            }

            Ok(SolutionSet::Unique(canonical_impl))
        }
        match canonical_impl.0.lookup_intern(self.db) {
            ImplLongId::Concrete(concrete_impl) => {
                // Negative-impl params of the impl definition, substituted with this concrete
                // impl's generic arguments.
                let mut rewriter = SubstitutionRewriter {
                    db: self.db,
                    substitution: &concrete_impl.substitution(self.db).map_err(|diag_added| {
                        self.set_error(InferenceError::Reported(diag_added))
                    })?,
                };
                let generic_params = self
                    .db
                    .impl_def_generic_params(concrete_impl.impl_def_id(self.db))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                let concrete_traits = generic_params
                    .iter()
                    .filter_map(|generic_param| {
                        try_extract_matches!(generic_param, GenericParam::NegImpl)
                    })
                    .map(|generic_param| {
                        rewriter
                            .rewrite(generic_param.clone())
                            .and_then(|generic_param| generic_param.concrete_trait)
                    });
                validate_no_solution_set(self, canonical_impl, lookup_context, concrete_traits)
            }
            // Generated impls carry their own generic params; their negative-impl traits are used
            // as is.
            ImplLongId::GeneratedImpl(generated_impl) => validate_no_solution_set(
                self,
                canonical_impl,
                lookup_context,
                generated_impl
                    .lookup_intern(self.db)
                    .generic_params
                    .iter()
                    .filter_map(|generic_param| {
                        try_extract_matches!(generic_param, GenericParam::NegImpl)
                    })
                    .map(|generic_param| generic_param.concrete_trait),
            ),
            // No negative-impl params are checked for these impl kinds.
            ImplLongId::GenericParameter(_)
            | ImplLongId::ImplVar(_)
            | ImplLongId::ImplImpl(_)
            | ImplLongId::TraitImpl(_) => Ok(SolutionSet::Unique(canonical_impl)),
        }
    }
1066
1067    // Error handling methods
1068    // ======================
1069
1070    /// Sets an error in the inference state.
1071    /// Does nothing if an error is already set.
1072    /// Returns an `ErrorSet` that can be used in reporting the error.
1073    pub fn set_error(&mut self, err: InferenceError) -> ErrorSet {
1074        if self.error_status.is_err() {
1075            return ErrorSet;
1076        }
1077        self.error_status = if let InferenceError::Reported(diag_added) = err {
1078            self.consumed_error = Some(diag_added);
1079            Err(InferenceErrorStatus::Consumed)
1080        } else {
1081            self.error = Some(err);
1082            Err(InferenceErrorStatus::Pending)
1083        };
1084        ErrorSet
1085    }
1086
1087    /// Returns whether an error is set (either pending or consumed).
1088    pub fn is_error_set(&self) -> InferenceResult<()> {
1089        if self.error_status.is_err() { Err(ErrorSet) } else { Ok(()) }
1090    }
1091
1092    /// Consumes the error but doesn't report it. If there is no error, or the error is consumed,
1093    /// returns None. This should be used with caution. Always prefer to use
1094    /// (1) `report_on_pending_error` if possible, or (2) `consume_reported_error` which is safer.
1095    ///
1096    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1097    pub fn consume_error_without_reporting(&mut self, err_set: ErrorSet) -> Option<InferenceError> {
1098        self.consume_error_inner(err_set, skip_diagnostic())
1099    }
1100
1101    /// Consumes the error that is already reported. If there is no error, or the error is consumed,
1102    /// does nothing. This should be used with caution. Always prefer to use
1103    /// `report_on_pending_error` if possible.
1104    ///
1105    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1106    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1107    pub fn consume_reported_error(&mut self, err_set: ErrorSet, diag_added: DiagnosticAdded) {
1108        self.consume_error_inner(err_set, diag_added);
1109    }
1110
1111    /// Consumes the error and returns it, but doesn't report it. If there is no error, or the error
1112    /// is already consumed, returns None. This should be used with caution. Always prefer to use
1113    /// `report_on_pending_error` if possible.
1114    ///
1115    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1116    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1117    fn consume_error_inner(
1118        &mut self,
1119        _err_set: ErrorSet,
1120        diag_added: DiagnosticAdded,
1121    ) -> Option<InferenceError> {
1122        if self.error_status != Err(InferenceErrorStatus::Pending) {
1123            return None;
1124            // panic!("consume_error when there is no pending error");
1125        }
1126        self.error_status = Err(InferenceErrorStatus::Consumed);
1127        self.consumed_error = Some(diag_added);
1128        mem::take(&mut self.error)
1129    }
1130
1131    /// Consumes the pending error, if any, and reports it.
1132    /// Should only be called when an error is set, otherwise it panics.
1133    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1134    /// If an error was set but it's already consumed, it doesn't report it again but returns the
1135    /// stored `DiagnosticAdded`.
1136    pub fn report_on_pending_error(
1137        &mut self,
1138        _err_set: ErrorSet,
1139        diagnostics: &mut SemanticDiagnostics,
1140        stable_ptr: SyntaxStablePtrId,
1141    ) -> DiagnosticAdded {
1142        let Err(state_error) = self.error_status else {
1143            panic!("report_on_pending_error should be called only on error");
1144        };
1145        match state_error {
1146            InferenceErrorStatus::Consumed => self
1147                .consumed_error
1148                .expect("consumed_error is not set although error_status is Err(Consumed)"),
1149            InferenceErrorStatus::Pending => {
1150                let diag_added = match mem::take(&mut self.error)
1151                    .expect("error is not set although error_status is Err(Pending)")
1152                {
1153                    InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
1154                        // If we have other diagnostics, there is no need to TypeNotInferred.
1155
1156                        // Note that `diagnostics` is not empty, so it is safe to return
1157                        // 'DiagnosticAdded' here.
1158                        skip_diagnostic()
1159                    }
1160                    diag => diag.report(diagnostics, stable_ptr),
1161                };
1162
1163                self.error_status = Err(InferenceErrorStatus::Consumed);
1164                self.consumed_error = Some(diag_added);
1165                diag_added
1166            }
1167        }
1168    }
1169
1170    /// If the current status is of a pending error, reports an alternative diagnostic, by calling
1171    /// `report`, and consumes the error. Otherwise, does nothing.
1172    pub fn report_modified_if_pending(
1173        &mut self,
1174        err_set: ErrorSet,
1175        report: impl FnOnce() -> DiagnosticAdded,
1176    ) {
1177        if self.error_status == Err(InferenceErrorStatus::Pending) {
1178            self.consume_reported_error(err_set, report());
1179        }
1180    }
1181}
1182
1183impl<'a> HasDb<&'a dyn SemanticGroup> for Inference<'a> {
1184    fn get_db(&self) -> &'a dyn SemanticGroup {
1185        self.db
1186    }
1187}
// Macro-generated `SemanticRewriter` impls for `Inference`. The types listed after `@exclude`
// (presumably skipped by the macro) have hand-written impls below: TypeId, ImplId, TypeLongId,
// ConstValue and ImplLongId.
add_basic_rewrites!(<'a>, Inference<'a>, NoError, @exclude TypeLongId TypeId ImplLongId ImplId ConstValue);
add_expr_rewrites!(<'a>, Inference<'a>, NoError, @exclude);
add_rewrite!(<'a>, Inference<'a>, NoError, Ambiguity);
1191impl SemanticRewriter<TypeId, NoError> for Inference<'_> {
1192    fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
1193        if value.is_var_free(self.db) {
1194            return Ok(RewriteResult::NoChange);
1195        }
1196        value.default_rewrite(self)
1197    }
1198}
1199impl SemanticRewriter<ImplId, NoError> for Inference<'_> {
1200    fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
1201        if value.is_var_free(self.db) {
1202            return Ok(RewriteResult::NoChange);
1203        }
1204        value.default_rewrite(self)
1205    }
1206}
/// Rewrites a long type id using the inference state: assigned type variables are replaced by
/// their values, and impl types are resolved where their impl allows it. Any other type falls
/// through to the default structural rewrite.
impl SemanticRewriter<TypeLongId, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
        match value {
            TypeLongId::Var(var) => {
                // An assigned variable is replaced by its value, rewritten recursively. When that
                // value changed, the rewritten form is written back into `type_assignment` so
                // later lookups start from it.
                // NOTE(review): the lookup is keyed only by the local `var.id`;
                // `var.inference_id` is not compared with `self.inference_id` here.
                if let Some(type_id) = self.type_assignment.get(&var.id) {
                    let mut long_type_id = type_id.lookup_intern(self.db);
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_type_id)? {
                        *self.type_assignment.get_mut(&var.id).unwrap() =
                            long_type_id.clone().intern(self.db);
                    }
                    *value = long_type_id;
                    return Ok(RewriteResult::Modified);
                }
                // Unassigned variables fall through to the default rewrite below.
            }
            TypeLongId::ImplType(impl_type_id) => {
                // A type recorded in `impl_type_bounds` for this impl type takes precedence; it
                // is rewritten further and reported as `Modified`.
                if let Some(type_id) = self.impl_type_bounds.get(impl_type_id) {
                    *value = type_id.lookup_intern(self.db);
                    self.internal_rewrite(value)?;
                    return Ok(RewriteResult::Modified);
                }
                // Otherwise rewrite the impl type id itself, then resolve it according to the
                // kind of its (rewritten) impl.
                let impl_type_id_rewrite_result = self.internal_rewrite(impl_type_id)?;
                let impl_id = impl_type_id.impl_id();
                let trait_ty = impl_type_id.ty();
                return Ok(match impl_id.lookup_intern(self.db) {
                    // Not resolvable here; keep the result of rewriting the impl type id.
                    ImplLongId::GenericParameter(_) | ImplLongId::TraitImpl(_) => {
                        impl_type_id_rewrite_result
                    }
                    ImplLongId::ImplImpl(impl_impl) => {
                        // The grand parent impl must be var free since we are rewriting the parent,
                        // and the parent is not var.
                        assert!(impl_impl.impl_id().is_var_free(self.db));
                        impl_type_id_rewrite_result
                    }
                    // Use the concrete impl's type for this item when the db query succeeds.
                    ImplLongId::Concrete(_) => {
                        if let Ok(ty) = self.db.impl_type_concrete_implized(ImplTypeId::new(
                            impl_id, trait_ty, self.db,
                        )) {
                            *value = self.rewrite(ty).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_type_id_rewrite_result
                        }
                    }
                    // Delegate to `rewritten_impl_type` (defined elsewhere) for an impl variable
                    // that the rewrite above did not replace.
                    ImplLongId::ImplVar(var) => {
                        *value = self.rewritten_impl_type(var, trait_ty).lookup_intern(self.db);
                        return Ok(RewriteResult::Modified);
                    }
                    // Take the type directly from the generated impl's items; the `unwrap` panics
                    // if the generated impl has no entry for this trait type.
                    ImplLongId::GeneratedImpl(generated) => {
                        *value = self
                            .rewrite(
                                *generated
                                    .lookup_intern(self.db)
                                    .impl_items
                                    .0
                                    .get(&impl_type_id.ty())
                                    .unwrap(),
                            )
                            .no_err()
                            .lookup_intern(self.db);
                        RewriteResult::Modified
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
/// Rewrites a const value using the inference state: assigned const variables are replaced by
/// their values, and impl constants are resolved where their impl allows it. Any other value
/// falls through to the default structural rewrite.
impl SemanticRewriter<ConstValue, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
        match value {
            ConstValue::Var(var, _) => {
                // An assigned variable is replaced by its value, rewritten recursively, with the
                // rewritten form written back into `const_assignment`. An unassigned variable is
                // returned unchanged.
                // NOTE(review): unlike unassigned type variables, this does not run the default
                // rewrite, so the variable's type field is not rewritten here.
                return Ok(if let Some(const_value_id) = self.const_assignment.get(&var.id) {
                    let mut const_value = const_value_id.lookup_intern(self.db);
                    if let RewriteResult::Modified = self.internal_rewrite(&mut const_value)? {
                        *self.const_assignment.get_mut(&var.id).unwrap() =
                            const_value.clone().intern(self.db);
                    }
                    *value = const_value;
                    RewriteResult::Modified
                } else {
                    RewriteResult::NoChange
                });
            }
            ConstValue::ImplConstant(impl_constant_id) => {
                // Rewrite the impl constant id first, then resolve it according to the kind of
                // its (rewritten) impl.
                let impl_constant_id_rewrite_result = self.internal_rewrite(impl_constant_id)?;
                let impl_id = impl_constant_id.impl_id();
                let trait_constant = impl_constant_id.trait_constant_id();
                return Ok(match impl_id.lookup_intern(self.db) {
                    // Not resolved here; keep the result of rewriting the impl constant id.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::TraitImpl(_)
                    | ImplLongId::GeneratedImpl(_) => impl_constant_id_rewrite_result,
                    ImplLongId::ImplImpl(impl_impl) => {
                        // The grand parent impl must be var free since we are rewriting the parent,
                        // and the parent is not var.
                        assert!(impl_impl.impl_id().is_var_free(self.db));
                        impl_constant_id_rewrite_result
                    }
                    // Use the concrete impl's value for this constant when the query succeeds.
                    ImplLongId::Concrete(_) => {
                        if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
                            ImplConstantId::new(impl_id, trait_constant, self.db),
                        ) {
                            *value = self.rewrite(constant).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_constant_id_rewrite_result
                        }
                    }
                    // Delegate to `rewritten_impl_constant` (defined elsewhere) for an impl
                    // variable that the rewrite above did not replace.
                    ImplLongId::ImplVar(var) => {
                        *value = self
                            .rewritten_impl_constant(var, trait_constant)
                            .lookup_intern(self.db);
                        return Ok(RewriteResult::Modified);
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
/// Rewrites a long impl id using the inference state: assigned impl variables are replaced by
/// their values, and impl-impls are resolved where their parent impl allows it. Remaining
/// var-free impls are returned unchanged; others get the default structural rewrite.
impl SemanticRewriter<ImplLongId, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
        match value {
            ImplLongId::ImplVar(var) => {
                let long_id = var.lookup_intern(self.db);
                // Resolve the variable through its assignment, if any. The rewritten assignment
                // is written back into `impl_assignment` when it changed.
                // NOTE(review): only the local id is used for the lookup; the variable's
                // `inference_id` is not compared with `self.inference_id` here.
                let impl_var_id = long_id.id;
                if let Some(impl_id) = self.impl_assignment(impl_var_id) {
                    let mut long_impl_id = impl_id.lookup_intern(self.db);
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_impl_id)? {
                        *self.impl_assignment.get_mut(&impl_var_id).unwrap() =
                            long_impl_id.clone().intern(self.db);
                    }
                    *value = long_impl_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            ImplLongId::ImplImpl(impl_impl_id) => {
                // Rewrite the impl-impl id first, then resolve it according to the kind of its
                // (rewritten) parent impl.
                let impl_impl_id_rewrite_result = self.internal_rewrite(impl_impl_id)?;
                let impl_id = impl_impl_id.impl_id();
                return Ok(match impl_id.lookup_intern(self.db) {
                    // Not resolved here; keep the result of rewriting the impl-impl id.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::TraitImpl(_)
                    | ImplLongId::GeneratedImpl(_) => impl_impl_id_rewrite_result,
                    ImplLongId::ImplImpl(impl_impl) => {
                        // The grand parent impl must be var free since we are rewriting the parent,
                        // and the parent is not var.
                        assert!(impl_impl.impl_id().is_var_free(self.db));
                        impl_impl_id_rewrite_result
                    }
                    // Use the concrete impl's inner impl when the db query succeeds.
                    ImplLongId::Concrete(_) => {
                        if let Ok(ty) = self.db.impl_impl_concrete_implized(*impl_impl_id) {
                            *value = self.rewrite(ty).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                    // Delegate to `rewritten_impl_impl` (defined elsewhere) for a parent impl
                    // variable, if the concrete trait impl can be computed.
                    ImplLongId::ImplVar(var) => {
                        if let Ok(concrete_trait_impl) =
                            impl_impl_id.concrete_trait_impl_id(self.db)
                        {
                            *value = self
                                .rewritten_impl_impl(var, concrete_trait_impl)
                                .lookup_intern(self.db);
                            return Ok(RewriteResult::Modified);
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                });
            }

            _ => {}
        }
        // Var-free impls cannot change, so the default traversal is skipped for them.
        if value.is_var_free(self.db) {
            return Ok(RewriteResult::NoChange);
        }
        value.default_rewrite(self)
    }
}
1389
/// A `SemanticRewriter` that replaces `from_inference_id` with `to_inference_id` wherever an
/// `InferenceId` is rewritten; other inference ids are left untouched.
struct InferenceIdReplacer<'a> {
    // Database handle, exposed through `HasDb::get_db`.
    db: &'a dyn SemanticGroup,
    // The inference id to be replaced.
    from_inference_id: InferenceId,
    // The inference id written in its place.
    to_inference_id: InferenceId,
}
1395impl<'a> InferenceIdReplacer<'a> {
1396    fn new(
1397        db: &'a dyn SemanticGroup,
1398        from_inference_id: InferenceId,
1399        to_inference_id: InferenceId,
1400    ) -> Self {
1401        Self { db, from_inference_id, to_inference_id }
1402    }
1403}
1404impl<'a> HasDb<&'a dyn SemanticGroup> for InferenceIdReplacer<'a> {
1405    fn get_db(&self) -> &'a dyn SemanticGroup {
1406        self.db
1407    }
1408}
// Macro-generated `SemanticRewriter` impls for `InferenceIdReplacer`. `InferenceId` is listed
// after `@exclude` (presumably skipped by the macro) and has a hand-written impl below.
add_basic_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude InferenceId);
add_expr_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude);
add_rewrite!(<'a>, InferenceIdReplacer<'a>, NoError, Ambiguity);
1412impl SemanticRewriter<InferenceId, NoError> for InferenceIdReplacer<'_> {
1413    fn internal_rewrite(&mut self, value: &mut InferenceId) -> Result<RewriteResult, NoError> {
1414        if value == &self.from_inference_id {
1415            *value = self.to_inference_id;
1416            Ok(RewriteResult::Modified)
1417        } else {
1418            Ok(RewriteResult::NoChange)
1419        }
1420    }
1421}