cairo_lang_semantic/expr/inference.rs

1//! Bidirectional type inference.
2
3use std::collections::{BTreeMap, HashMap, VecDeque};
4use std::hash::Hash;
5use std::mem;
6use std::ops::{Deref, DerefMut};
7use std::sync::Arc;
8
9use cairo_lang_debug::DebugWithDb;
10use cairo_lang_defs::ids::{
11    ConstantId, EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId,
12    GlobalUseId, ImplAliasId, ImplDefId, ImplFunctionId, ImplImplDefId, LanguageElementId,
13    LocalVarId, LookupItemId, MemberId, NamedLanguageElementId, ParamId, StructId, TraitConstantId,
14    TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
15};
16use cairo_lang_diagnostics::{DiagnosticAdded, Maybe, skip_diagnostic};
17use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
18use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
19use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
20use cairo_lang_utils::{
21    Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches,
22};
23
24use self::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, NoError};
25use self::solver::{Ambiguity, SolutionSet, enrich_lookup_context};
26use crate::db::SemanticGroup;
27use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
28use crate::expr::inference::canonic::ResultNoErrEx;
29use crate::expr::inference::conform::InferenceConform;
30use crate::expr::objects::*;
31use crate::expr::pattern::*;
32use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
33use crate::items::functions::{
34    ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
35    GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
36    ImplGenericFunctionWithBodyId,
37};
38use crate::items::generics::{GenericParamConst, GenericParamImpl, GenericParamType};
39use crate::items::imp::{
40    GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
41    ImplLookupContext, UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
42};
43use crate::items::trt::{
44    ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId, ConcreteTraitTypeId,
45    ConcreteTraitTypeLongId,
46};
47use crate::substitution::{HasDb, RewriteResult, SemanticRewriter};
48use crate::types::{
49    ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
50    ImplTypeById, ImplTypeId,
51};
52use crate::{
53    ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
54    ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
55    FunctionId, FunctionLongId, GenericArgumentId, GenericParam, LocalVariable, MatchArmSelector,
56    Member, Parameter, SemanticObject, Signature, TypeId, TypeLongId, ValueSelectorArm,
57    add_basic_rewrites, add_expr_rewrites, add_rewrite, semantic_object_for_id,
58};
59
60pub mod canonic;
61pub mod conform;
62pub mod infers;
63pub mod solver;
64
65/// A type variable, created when a generic type argument is not passed, and thus is not known
66/// yet and needs to be inferred.
67#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
68pub struct TypeVar {
69    pub inference_id: InferenceId,
70    pub id: LocalTypeVarId,
71}
72
73/// A const variable, created when a generic const argument is not passed, and thus is not known
74/// yet and needs to be inferred.
75#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
76pub struct ConstVar {
77    pub inference_id: InferenceId,
78    pub id: LocalConstVarId,
79}
80
/// An id for an inference context. Each inference variable is associated with an inference id.
///
/// The id is stored in every [TypeVar], [ConstVar] and [ImplVar] a context allocates;
/// `InferenceData::clone_with_inference_id` rewrites it when moving data to another context.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
#[debug_db(dyn SemanticGroup + 'static)]
pub enum InferenceId {
    LookupItemDeclaration(LookupItemId),
    LookupItemGenerics(LookupItemId),
    LookupItemDefinition(LookupItemId),
    ImplDefTrait(ImplDefId),
    ImplAliasImplDef(ImplAliasId),
    GenericParam(GenericParamId),
    GenericImplParamTrait(GenericParamId),
    GlobalUseStar(GlobalUseId),
    Canonical,
    /// For resolving that will not be used anywhere in the semantic model.
    NoContext,
}
97
/// An impl variable, standing for a not-yet-known implementation of `concrete_trait_id` that
/// needs to be inferred (e.g. when a generic impl argument is not passed).
#[derive(Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
#[debug_db(dyn SemanticGroup + 'static)]
pub struct ImplVar {
    /// The inference context that allocated this variable.
    pub inference_id: InferenceId,
    /// Index of this variable within its context's `InferenceData::impl_vars`.
    #[dont_rewrite]
    pub id: LocalImplVarId,
    /// The concrete trait the inferred impl has to conform to.
    pub concrete_trait_id: ConcreteTraitId,
    /// Context used when searching for candidate impls (presumably; confirm in the solver).
    /// `Inference::new_impl_var_raw` always adds the trait's own module to it.
    #[dont_rewrite]
    pub lookup_context: ImplLookupContext,
}
110impl ImplVar {
111    pub fn intern(&self, db: &dyn SemanticGroup) -> ImplVarId {
112        self.clone().intern(db)
113    }
114}
115
/// Index of a type variable within `InferenceData::type_vars`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
pub struct LocalTypeVarId(pub usize);
/// Index of an impl variable within `InferenceData::impl_vars`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
pub struct LocalImplVarId(pub usize);
120
/// Index of a const variable within `InferenceData::const_vars`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
pub struct LocalConstVarId(pub usize);
123
// Declares `ImplVarId`, the interned short id of an `ImplVar` (the trailing macro arguments
// presumably name the `SemanticGroup` lookup/intern queries; confirm in `define_short_id!`).
define_short_id!(ImplVarId, ImplVar, SemanticGroup, lookup_intern_impl_var, intern_impl_var);
125impl ImplVarId {
126    pub fn id(&self, db: &dyn SemanticGroup) -> LocalImplVarId {
127        self.lookup_intern(db).id
128    }
129    pub fn concrete_trait_id(&self, db: &dyn SemanticGroup) -> ConcreteTraitId {
130        self.lookup_intern(db).concrete_trait_id
131    }
132    pub fn lookup_context(&self, db: &dyn SemanticGroup) -> ImplLookupContext {
133        self.lookup_intern(db).lookup_context
134    }
135}
// Makes `ImplVarId` a semantic object by delegating to its interned `ImplVar` (presumably via
// lookup, rewrite, re-intern; confirm in `semantic_object_for_id!`).
semantic_object_for_id!(ImplVarId, lookup_intern_impl_var, intern_impl_var, ImplVar);
137
/// A local handle to an inference variable of any kind; used, for example, as the key of
/// `InferenceData::stable_ptrs`.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, SemanticObject)]
pub enum InferenceVar {
    /// A type variable.
    Type(LocalTypeVarId),
    /// A const variable.
    Const(LocalConstVarId),
    /// An impl variable.
    Impl(LocalImplVarId),
}
144
// TODO(spapini): Add to diagnostics.
/// An error produced while inferring. `InferenceError::format` renders it as a message.
#[derive(Clone, Debug, Eq, Hash, PartialEq, DebugWithDb)]
#[debug_db(dyn SemanticGroup + 'static)]
pub enum InferenceError {
    /// An inference error wrapping a previously reported error.
    Reported(DiagnosticAdded),
    /// An inference cycle was detected; carries the variable involved.
    Cycle(InferenceVar),
    /// Two types were expected to match but do not.
    TypeKindMismatch {
        ty0: TypeId,
        ty1: TypeId,
    },
    /// Two const values were expected to match but do not.
    ConstKindMismatch {
        const0: ConstValueId,
        const1: ConstValueId,
    },
    /// Two impls were expected to match but do not.
    ImplKindMismatch {
        impl0: ImplId,
        impl1: ImplId,
    },
    /// Two generic arguments were expected to match but do not.
    GenericArgMismatch {
        garg0: GenericArgumentId,
        garg1: GenericArgumentId,
    },
    /// Two traits were expected to match but do not.
    TraitMismatch {
        trt0: TraitId,
        trt1: TraitId,
    },
    /// The type of `impl_id`'s `trait_type_id` item was expected to be both `ty0` and `ty1`.
    ImplTypeMismatch {
        impl_id: ImplId,
        trait_type_id: TraitTypeId,
        ty0: TypeId,
        ty1: TypeId,
    },
    /// Two generic functions were expected to match but do not.
    GenericFunctionMismatch {
        func0: GenericFunctionId,
        func1: GenericFunctionId,
    },
    /// Const generic inference is not yet supported.
    ConstInferenceNotSupported,

    // TODO(spapini): These are only used for external interface. Separate them along with the
    // finalize() function to a wrapper.
    /// No implementation of the concrete trait exists in context (a refuted impl variable).
    NoImplsFound(ConcreteTraitId),
    /// An impl variable remained ambiguous.
    Ambiguity(Ambiguity),
    /// A type variable was left uninferred.
    TypeNotInferred(TypeId),
}
190impl InferenceError {
191    pub fn format(&self, db: &(dyn SemanticGroup + 'static)) -> String {
192        match self {
193            InferenceError::Reported(_) => "Inference error occurred.".into(),
194            InferenceError::Cycle(_var) => "Inference cycle detected".into(),
195            InferenceError::TypeKindMismatch { ty0, ty1 } => {
196                format!("Type mismatch: `{:?}` and `{:?}`.", ty0.debug(db), ty1.debug(db))
197            }
198            InferenceError::ConstKindMismatch { const0, const1 } => {
199                format!("Const mismatch: `{:?}` and `{:?}`.", const0.debug(db), const1.debug(db))
200            }
201            InferenceError::ImplKindMismatch { impl0, impl1 } => {
202                format!("Impl mismatch: `{:?}` and `{:?}`.", impl0.debug(db), impl1.debug(db))
203            }
204            InferenceError::GenericArgMismatch { garg0, garg1 } => {
205                format!(
206                    "Generic arg mismatch: `{:?}` and `{:?}`.",
207                    garg0.debug(db),
208                    garg1.debug(db)
209                )
210            }
211            InferenceError::TraitMismatch { trt0, trt1 } => {
212                format!("Trait mismatch: `{:?}` and `{:?}`.", trt0.debug(db), trt1.debug(db))
213            }
214            InferenceError::ConstInferenceNotSupported => {
215                "Const generic inference not yet supported.".into()
216            }
217            InferenceError::NoImplsFound(concrete_trait_id) => {
218                let info = db.core_info();
219                let trait_id = concrete_trait_id.trait_id(db);
220                if trait_id == info.numeric_literal_trt {
221                    let generic_type = extract_matches!(
222                        concrete_trait_id.generic_args(db)[0],
223                        GenericArgumentId::Type
224                    );
225                    return format!(
226                        "Mismatched types. The type `{:?}` cannot be created from a numeric \
227                         literal.",
228                        generic_type.debug(db)
229                    );
230                } else if trait_id == info.string_literal_trt {
231                    let generic_type = extract_matches!(
232                        concrete_trait_id.generic_args(db)[0],
233                        GenericArgumentId::Type
234                    );
235                    return format!(
236                        "Mismatched types. The type `{:?}` cannot be created from a string \
237                         literal.",
238                        generic_type.debug(db)
239                    );
240                }
241                format!(
242                    "Trait has no implementation in context: {:?}.",
243                    concrete_trait_id.debug(db)
244                )
245            }
246            InferenceError::Ambiguity(ambiguity) => ambiguity.format(db),
247            InferenceError::TypeNotInferred(ty) => {
248                format!("Type annotations needed. Failed to infer {:?}.", ty.debug(db))
249            }
250            InferenceError::GenericFunctionMismatch { func0, func1 } => {
251                format!("Function mismatch: `{}` and `{}`.", func0.format(db), func1.format(db))
252            }
253            InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 } => {
254                format!(
255                    "`{}::{}` type mismatch: `{:?}` and `{:?}`.",
256                    impl_id.format(db),
257                    trait_type_id.name(db),
258                    ty0.debug(db),
259                    ty1.debug(db)
260                )
261            }
262        }
263    }
264}
265
266impl InferenceError {
267    pub fn report(
268        &self,
269        diagnostics: &mut SemanticDiagnostics,
270        stable_ptr: SyntaxStablePtrId,
271    ) -> DiagnosticAdded {
272        match self {
273            InferenceError::Reported(diagnostic_added) => *diagnostic_added,
274            _ => diagnostics
275                .report(stable_ptr, SemanticDiagnosticKind::InternalInferenceError(self.clone())),
276        }
277    }
278}
279
/// This struct is used to ensure that when an inference error occurs, it is properly set in the
/// `Inference` object, and then properly consumed.
///
/// It must not be constructed directly. Instead, it is returned by [Inference::set_error].
/// It carries no data itself; the error is stored on the inference data.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ErrorSet;
286
/// Result of an inference operation. An `Err` means the error has been set on the inference.
pub type InferenceResult<T> = Result<T, ErrorSet>;
288
/// Status of the error held by an inference, used as the `Err` of `InferenceData::error_status`.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum InferenceErrorStatus {
    /// An error is set and not yet consumed; it is held in `InferenceData::error`.
    Pending,
    /// The error was consumed; its diagnostic is held in `InferenceData::consumed_error`.
    Consumed,
}
294
/// A mapping of an impl var's trait items to concrete items.
#[derive(Debug, Default, PartialEq, Eq, Clone, SemanticObject)]
pub struct ImplVarTraitItemMappings {
    /// The trait types of the impl var.
    types: OrderedHashMap<TraitTypeId, TypeId>,
    /// The trait constants of the impl var.
    constants: OrderedHashMap<TraitConstantId, ConstValueId>,
    /// The trait impls of the impl var.
    impls: OrderedHashMap<TraitImplId, ImplId>,
}
305impl Hash for ImplVarTraitItemMappings {
306    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
307        self.types.iter().for_each(|(trait_type_id, type_id)| {
308            trait_type_id.hash(state);
309            type_id.hash(state);
310        });
311        self.constants.iter().for_each(|(trait_const_id, const_id)| {
312            trait_const_id.hash(state);
313            const_id.hash(state);
314        });
315        self.impls.iter().for_each(|(trait_impl_id, impl_id)| {
316            trait_impl_id.hash(state);
317            impl_id.hash(state);
318        });
319    }
320}
321
/// State of inference.
#[derive(Debug, DebugWithDb, PartialEq, Eq)]
#[debug_db(dyn SemanticGroup + 'static)]
pub struct InferenceData {
    /// The inference context this data belongs to.
    pub inference_id: InferenceId,
    /// Current inferred assignment for type variables.
    pub type_assignment: OrderedHashMap<LocalTypeVarId, TypeId>,
    /// Current inferred assignment for const variables.
    pub const_assignment: OrderedHashMap<LocalConstVarId, ConstValueId>,
    /// Current inferred assignment for impl variables.
    pub impl_assignment: OrderedHashMap<LocalImplVarId, ImplId>,
    /// Maps each unsolved impl variable to the mappings from its trait items to corresponding
    /// item variables. Once the impl variable is solved, each fully known item is conformed to
    /// its variable.
    pub impl_vars_trait_item_mappings: HashMap<LocalImplVarId, ImplVarTraitItemMappings>,
    /// Type variables, indexed by [LocalTypeVarId].
    pub type_vars: Vec<TypeVar>,
    /// Const variables, indexed by [LocalConstVarId].
    pub const_vars: Vec<ConstVar>,
    /// Impl variables, indexed by [LocalImplVarId].
    pub impl_vars: Vec<ImplVar>,
    /// Mapping from variables to stable pointers, if exist.
    pub stable_ptrs: HashMap<InferenceVar, SyntaxStablePtrId>,
    /// Impl variables that are pending to be solved.
    pending: VecDeque<LocalImplVarId>,
    /// Impl variables that have been refuted - no solutions exist.
    refuted: Vec<LocalImplVarId>,
    /// Impl variables that have been solved.
    solved: Vec<LocalImplVarId>,
    /// Impl variables that are currently ambiguous. May be solved later.
    ambiguous: Vec<(LocalImplVarId, Ambiguity)>,
    /// Mapping from impl types to the types they are bound to.
    /// `Inference::set_impl_type_bounds` stores only var-free types here.
    pub impl_type_bounds: Arc<BTreeMap<ImplTypeById, TypeId>>,

    // Error handling members.
    /// The current error status.
    pub error_status: Result<(), InferenceErrorStatus>,
    /// `Some` only when `error_status` is `Err(Pending)`.
    error: Option<InferenceError>,
    /// `Some` only when `error_status` is `Err(Consumed)`.
    consumed_error: Option<DiagnosticAdded>,
}
363impl InferenceData {
364    pub fn new(inference_id: InferenceId) -> Self {
365        Self {
366            inference_id,
367            type_assignment: OrderedHashMap::default(),
368            impl_assignment: OrderedHashMap::default(),
369            const_assignment: OrderedHashMap::default(),
370            impl_vars_trait_item_mappings: HashMap::new(),
371            type_vars: Vec::new(),
372            impl_vars: Vec::new(),
373            const_vars: Vec::new(),
374            stable_ptrs: HashMap::new(),
375            pending: VecDeque::new(),
376            refuted: Vec::new(),
377            solved: Vec::new(),
378            ambiguous: Vec::new(),
379            impl_type_bounds: Default::default(),
380            error_status: Ok(()),
381            error: None,
382            consumed_error: None,
383        }
384    }
385    pub fn inference<'db, 'b: 'db>(&'db mut self, db: &'b dyn SemanticGroup) -> Inference<'db> {
386        Inference::new(db, self)
387    }
388    pub fn clone_with_inference_id(
389        &self,
390        db: &dyn SemanticGroup,
391        inference_id: InferenceId,
392    ) -> InferenceData {
393        let mut inference_id_replacer =
394            InferenceIdReplacer::new(db, self.inference_id, inference_id);
395        Self {
396            inference_id,
397            type_assignment: self
398                .type_assignment
399                .iter()
400                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
401                .collect(),
402            const_assignment: self
403                .const_assignment
404                .iter()
405                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
406                .collect(),
407            impl_assignment: self
408                .impl_assignment
409                .iter()
410                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
411                .collect(),
412            impl_vars_trait_item_mappings: self
413                .impl_vars_trait_item_mappings
414                .iter()
415                .map(|(k, mappings)| {
416                    (
417                        *k,
418                        ImplVarTraitItemMappings {
419                            types: mappings
420                                .types
421                                .iter()
422                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
423                                .collect(),
424                            constants: mappings
425                                .constants
426                                .iter()
427                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
428                                .collect(),
429                            impls: mappings
430                                .impls
431                                .iter()
432                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
433                                .collect(),
434                        },
435                    )
436                })
437                .collect(),
438            type_vars: inference_id_replacer.rewrite(self.type_vars.clone()).no_err(),
439            const_vars: inference_id_replacer.rewrite(self.const_vars.clone()).no_err(),
440            impl_vars: inference_id_replacer.rewrite(self.impl_vars.clone()).no_err(),
441            stable_ptrs: self.stable_ptrs.clone(),
442            pending: inference_id_replacer.rewrite(self.pending.clone()).no_err(),
443            refuted: inference_id_replacer.rewrite(self.refuted.clone()).no_err(),
444            solved: inference_id_replacer.rewrite(self.solved.clone()).no_err(),
445            ambiguous: inference_id_replacer.rewrite(self.ambiguous.clone()).no_err(),
446            // we do not need to rewrite the impl type bounds, as they all should be var free.
447            impl_type_bounds: self.impl_type_bounds.clone(),
448
449            error_status: self.error_status,
450            error: self.error.clone(),
451            consumed_error: self.consumed_error,
452        }
453    }
454    pub fn temporary_clone(&self) -> InferenceData {
455        Self {
456            inference_id: self.inference_id,
457            type_assignment: self.type_assignment.clone(),
458            const_assignment: self.const_assignment.clone(),
459            impl_assignment: self.impl_assignment.clone(),
460            impl_vars_trait_item_mappings: self.impl_vars_trait_item_mappings.clone(),
461            type_vars: self.type_vars.clone(),
462            const_vars: self.const_vars.clone(),
463            impl_vars: self.impl_vars.clone(),
464            stable_ptrs: self.stable_ptrs.clone(),
465            pending: self.pending.clone(),
466            refuted: self.refuted.clone(),
467            solved: self.solved.clone(),
468            ambiguous: self.ambiguous.clone(),
469            impl_type_bounds: self.impl_type_bounds.clone(),
470            error_status: self.error_status,
471            error: self.error.clone(),
472            consumed_error: self.consumed_error,
473        }
474    }
475}
476
/// State of inference. A system of inference constraints.
///
/// Pairs the database with a mutable borrow of the underlying [InferenceData]. Field access is
/// forwarded to that data through `Deref`/`DerefMut`.
pub struct Inference<'db> {
    db: &'db dyn SemanticGroup,
    pub data: &'db mut InferenceData,
}
482
483impl Deref for Inference<'_> {
484    type Target = InferenceData;
485
486    fn deref(&self) -> &Self::Target {
487        self.data
488    }
489}
490impl DerefMut for Inference<'_> {
491    fn deref_mut(&mut self) -> &mut Self::Target {
492        self.data
493    }
494}
495
496impl std::fmt::Debug for Inference<'_> {
497    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
498        let x = self.data.debug(self.db.elongate());
499        write!(f, "{x:?}")
500    }
501}
502
503impl<'db> Inference<'db> {
504    fn new(db: &'db dyn SemanticGroup, data: &'db mut InferenceData) -> Self {
505        Self { db, data }
506    }
507
508    /// Getter for an [ImplVar].
509    fn impl_var(&self, var_id: LocalImplVarId) -> &ImplVar {
510        &self.impl_vars[var_id.0]
511    }
512
513    /// Getter for an impl var assignment.
514    pub fn impl_assignment(&self, var_id: LocalImplVarId) -> Option<ImplId> {
515        self.impl_assignment.get(&var_id).copied()
516    }
517
518    /// Getter for a type var assignment.
519    fn type_assignment(&self, var_id: LocalTypeVarId) -> Option<TypeId> {
520        self.type_assignment.get(&var_id).copied()
521    }
522
523    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
524    /// Returns a wrapping TypeId.
525    pub fn new_type_var(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeId {
526        let var = self.new_type_var_raw(stable_ptr);
527
528        TypeLongId::Var(var).intern(self.db)
529    }
530
531    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
532    /// Returns the variable id.
533    pub fn new_type_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeVar {
534        let var =
535            TypeVar { inference_id: self.inference_id, id: LocalTypeVarId(self.type_vars.len()) };
536        if let Some(stable_ptr) = stable_ptr {
537            self.stable_ptrs.insert(InferenceVar::Type(var.id), stable_ptr);
538        }
539        self.type_vars.push(var);
540        var
541    }
542
543    /// Sets the infrence's impl type bounds to the given map, and rewrittes the types so all the
544    /// types are var free.
545    pub fn set_impl_type_bounds(&mut self, impl_type_bounds: OrderedHashMap<ImplTypeId, TypeId>) {
546        let impl_type_bounds_finalized = impl_type_bounds
547            .iter()
548            .filter_map(|(impl_type, ty)| {
549                let rewritten_type = self.rewrite(ty.lookup_intern(self.db)).no_err();
550                if !matches!(rewritten_type, TypeLongId::Var(_)) {
551                    return Some(((*impl_type).into(), rewritten_type.intern(self.db)));
552                }
553                // conformed the var type to the original impl type to remove it from the pending
554                // list.
555                self.conform_ty(*ty, TypeLongId::ImplType(*impl_type).intern(self.db)).ok();
556                None
557            })
558            .collect();
559
560        self.data.impl_type_bounds = Arc::new(impl_type_bounds_finalized);
561    }
562
563    /// Allocates a new [ConstVar] for an unknown consts that needs to be inferred.
564    /// Returns a wrapping [ConstValueId].
565    pub fn new_const_var(
566        &mut self,
567        stable_ptr: Option<SyntaxStablePtrId>,
568        ty: TypeId,
569    ) -> ConstValueId {
570        let var = self.new_const_var_raw(stable_ptr);
571        ConstValue::Var(var, ty).intern(self.db)
572    }
573
574    /// Allocates a new [ConstVar] for an unknown type that needs to be inferred.
575    /// Returns the variable id.
576    pub fn new_const_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> ConstVar {
577        let var = ConstVar {
578            inference_id: self.inference_id,
579            id: LocalConstVarId(self.const_vars.len()),
580        };
581        if let Some(stable_ptr) = stable_ptr {
582            self.stable_ptrs.insert(InferenceVar::Const(var.id), stable_ptr);
583        }
584        self.const_vars.push(var);
585        var
586    }
587
588    /// Allocates a new [ImplVar] for an unknown type that needs to be inferred.
589    /// Returns a wrapping ImplId.
590    pub fn new_impl_var(
591        &mut self,
592        concrete_trait_id: ConcreteTraitId,
593        stable_ptr: Option<SyntaxStablePtrId>,
594        lookup_context: ImplLookupContext,
595    ) -> ImplId {
596        let var = self.new_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
597        ImplLongId::ImplVar(self.impl_var(var).intern(self.db)).intern(self.db)
598    }
599
600    /// Allocates a new [ImplVar] for an unknown type that needs to be inferred.
601    /// Returns the variable id.
602    fn new_impl_var_raw(
603        &mut self,
604        lookup_context: ImplLookupContext,
605        concrete_trait_id: ConcreteTraitId,
606        stable_ptr: Option<SyntaxStablePtrId>,
607    ) -> LocalImplVarId {
608        let mut lookup_context = lookup_context;
609        lookup_context.insert_module(concrete_trait_id.trait_id(self.db).module_file_id(self.db).0);
610
611        let id = LocalImplVarId(self.impl_vars.len());
612        if let Some(stable_ptr) = stable_ptr {
613            self.stable_ptrs.insert(InferenceVar::Impl(id), stable_ptr);
614        }
615        let var =
616            ImplVar { inference_id: self.inference_id, id, concrete_trait_id, lookup_context };
617        self.impl_vars.push(var);
618        self.pending.push_back(id);
619        id
620    }
621
622    /// Solves the inference system. After a successful solve, there are no more pending impl
623    /// inferences.
624    /// Returns whether the inference was successful. If not, the error may be found by
625    /// `.error_state()`.
626    pub fn solve(&mut self) -> InferenceResult<()> {
627        self.solve_ex().map_err(|(err_set, _)| err_set)
628    }
629
630    /// Same as `solve`, but returns the error stable pointer if an error occurred.
631    fn solve_ex(&mut self) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
632        let mut ambiguous = std::mem::take(&mut self.ambiguous);
633        self.pending.extend(ambiguous.drain(..).map(|(var, _)| var));
634        while let Some(var) = self.pending.pop_front() {
635            // First inference error stops inference.
636            self.solve_single_pending(var).map_err(|err_set| {
637                (err_set, self.stable_ptrs.get(&InferenceVar::Impl(var)).copied())
638            })?;
639        }
640        Ok(())
641    }
642
643    fn solve_single_pending(&mut self, var: LocalImplVarId) -> InferenceResult<()> {
644        if self.impl_assignment.contains_key(&var) {
645            return Ok(());
646        }
647        let solution = match self.impl_var_solution_set(var)? {
648            SolutionSet::None => {
649                self.refuted.push(var);
650                return Ok(());
651            }
652            SolutionSet::Ambiguous(ambiguity) => {
653                self.ambiguous.push((var, ambiguity));
654                return Ok(());
655            }
656            SolutionSet::Unique(solution) => solution,
657        };
658
659        // Solution found. Assign it.
660        self.assign_local_impl(var, solution)?;
661
662        // Something changed.
663        self.solved.push(var);
664        let mut ambiguous = std::mem::take(&mut self.ambiguous);
665        self.pending.extend(ambiguous.drain(..).map(|(var, _)| var));
666
667        Ok(())
668    }
669
670    /// Returns the solution set status for the inference:
671    /// Whether there is a unique solution, multiple solutions, no solutions or an error.
672    pub fn solution_set(&mut self) -> InferenceResult<SolutionSet<()>> {
673        self.solve()?;
674        if !self.refuted.is_empty() {
675            return Ok(SolutionSet::None);
676        }
677        if let Some((_, ambiguity)) = self.ambiguous.first() {
678            return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
679        }
680        assert!(self.pending.is_empty(), "solution() called on an unsolved solver");
681        Ok(SolutionSet::Unique(()))
682    }
683
684    /// Finalizes the inference by inferring uninferred numeric literals as felt252.
685    /// Returns an error and does not report it.
686    pub fn finalize_without_reporting(
687        &mut self,
688    ) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
689        if self.error_status.is_err() {
690            // TODO(yuval): consider adding error location to the set error.
691            return Err((ErrorSet, None));
692        }
693        let info = self.db.core_info();
694        let numeric_trait_id = info.numeric_literal_trt;
695        let felt_ty = info.felt252;
696
697        // Conform all uninferred numeric literals to felt252.
698        loop {
699            let mut changed = false;
700            self.solve_ex()?;
701            for (var, _) in self.ambiguous.clone() {
702                let impl_var = self.impl_var(var).clone();
703                if impl_var.concrete_trait_id.trait_id(self.db) != numeric_trait_id {
704                    continue;
705                }
706                // Uninferred numeric trait. Resolve as felt252.
707                let ty = extract_matches!(
708                    impl_var.concrete_trait_id.generic_args(self.db)[0],
709                    GenericArgumentId::Type
710                );
711                if self.rewrite(ty).no_err() == felt_ty {
712                    continue;
713                }
714                self.conform_ty(ty, felt_ty).map_err(|err_set| {
715                    (err_set, self.stable_ptrs.get(&InferenceVar::Impl(impl_var.id)).copied())
716                })?;
717                changed = true;
718                break;
719            }
720            if !changed {
721                break;
722            }
723        }
724        assert!(
725            self.pending.is_empty(),
726            "pending should all be solved by this point. Guaranteed by solve()."
727        );
728
729        let Some((var, err)) = self.first_undetermined_variable() else {
730            return Ok(());
731        };
732        Err((self.set_error(err), self.stable_ptrs.get(&var).copied()))
733    }
734
735    /// Finalizes the inference and report diagnostics if there are any errors.
736    /// All the remaining type vars are mapped to the `missing` type, to prevent additional
737    /// diagnostics.
738    pub fn finalize(
739        &mut self,
740        diagnostics: &mut SemanticDiagnostics,
741        stable_ptr: SyntaxStablePtrId,
742    ) {
743        if let Err((err_set, err_stable_ptr)) = self.finalize_without_reporting() {
744            let diag = self.report_on_pending_error(
745                err_set,
746                diagnostics,
747                err_stable_ptr.unwrap_or(stable_ptr),
748            );
749
750            let ty_missing = TypeId::missing(self.db, diag);
751            for var in &self.data.type_vars {
752                self.data.type_assignment.entry(var.id).or_insert(ty_missing);
753            }
754        }
755    }
756
757    /// Retrieves the first variable that is still not inferred, or None, if everything is
758    /// inferred.
759    /// Does not set the error but return it, which is ok as this is a private helper function.
760    fn first_undetermined_variable(&mut self) -> Option<(InferenceVar, InferenceError)> {
761        if let Some(var) = self.refuted.first().copied() {
762            let impl_var = self.impl_var(var).clone();
763            let concrete_trait_id = impl_var.concrete_trait_id;
764            let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
765            return Some((
766                InferenceVar::Impl(var),
767                InferenceError::NoImplsFound(concrete_trait_id),
768            ));
769        }
770        let mut fallback_ret = None;
771        if let Some((var, ambiguity)) = self.ambiguous.first() {
772            // Note: do not rewrite `ambiguity`, since it is expressed in canonical variables.
773            let ret =
774                Some((InferenceVar::Impl(*var), InferenceError::Ambiguity(ambiguity.clone())));
775            if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
776                return ret;
777            } else {
778                fallback_ret = ret;
779            }
780        }
781        for (id, var) in self.type_vars.iter().enumerate() {
782            if self.type_assignment(LocalTypeVarId(id)).is_none() {
783                let ty = TypeLongId::Var(*var).intern(self.db);
784                return Some((InferenceVar::Type(var.id), InferenceError::TypeNotInferred(ty)));
785            }
786        }
787        fallback_ret
788    }
789
790    /// Assigns a value to a local impl variable id. See assign_impl().
791    fn assign_local_impl(
792        &mut self,
793        var: LocalImplVarId,
794        impl_id: ImplId,
795    ) -> InferenceResult<ImplId> {
796        let concrete_trait = impl_id
797            .concrete_trait(self.db)
798            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
799        self.conform_traits(self.impl_var(var).concrete_trait_id, concrete_trait)?;
800        if let Some(other_impl) = self.impl_assignment(var) {
801            return self.conform_impl(impl_id, other_impl);
802        }
803        if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
804        {
805            return Err(self.set_error(InferenceError::Cycle(InferenceVar::Impl(var))));
806        }
807        self.impl_assignment.insert(var, impl_id);
808        if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
809            for (trait_type_id, ty) in mappings.types {
810                let impl_ty = self
811                    .db
812                    .impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_type_id, self.db))
813                    .map_err(|_| ErrorSet)?;
814                if let Err(err_set) = self.conform_ty(ty, impl_ty) {
815                    // Override the error with ImplTypeMismatch.
816                    let ty0 = self.rewrite(ty).no_err();
817                    let ty1 = self.rewrite(impl_ty).no_err();
818
819                    self.error =
820                        Some(InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 });
821                    return Err(err_set);
822                }
823            }
824            for (trait_constant, constant_id) in mappings.constants {
825                self.conform_const(
826                    constant_id,
827                    self.db
828                        .impl_constant_concrete_implized_value(ImplConstantId::new(
829                            impl_id,
830                            trait_constant,
831                            self.db,
832                        ))
833                        .map_err(|_| ErrorSet)?,
834                )?;
835            }
836            for (trait_impl, inner_impl_id) in mappings.impls {
837                self.conform_impl(
838                    inner_impl_id,
839                    self.db
840                        .impl_impl_concrete_implized(ImplImplId::new(impl_id, trait_impl, self.db))
841                        .map_err(|_| ErrorSet)?,
842                )?;
843            }
844        }
845        Ok(impl_id)
846    }
847
848    /// Tries to assigns value to an [ImplVarId]. Return the assigned impl, or an error.
849    fn assign_impl(&mut self, var_id: ImplVarId, impl_id: ImplId) -> InferenceResult<ImplId> {
850        let var = var_id.lookup_intern(self.db);
851        if var.inference_id != self.inference_id {
852            return Err(self.set_error(InferenceError::ImplKindMismatch {
853                impl0: ImplLongId::ImplVar(var_id).intern(self.db),
854                impl1: impl_id,
855            }));
856        }
857        self.assign_local_impl(var.id, impl_id)
858    }
859
860    /// Assigns a value to a [TypeVar]. Return the assigned type, or an error.
861    /// Assumes the variable is not already assigned.
862    fn assign_ty(&mut self, var: TypeVar, ty: TypeId) -> InferenceResult<TypeId> {
863        if var.inference_id != self.inference_id {
864            return Err(self.set_error(InferenceError::TypeKindMismatch {
865                ty0: TypeLongId::Var(var).intern(self.db),
866                ty1: ty,
867            }));
868        }
869        assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
870        let inference_var = InferenceVar::Type(var.id);
871        if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
872            return Err(self.set_error(InferenceError::Cycle(inference_var)));
873        }
874        // If assigning var to var - making sure assigning to the lower id for proper canonization.
875        if let TypeLongId::Var(other) = ty.lookup_intern(self.db) {
876            if other.inference_id == self.inference_id && other.id.0 > var.id.0 {
877                let var_ty = TypeLongId::Var(var).intern(self.db);
878                self.type_assignment.insert(other.id, var_ty);
879                return Ok(var_ty);
880            }
881        }
882        self.type_assignment.insert(var.id, ty);
883        Ok(ty)
884    }
885
886    /// Assigns a value to a [ConstVar]. Return the assigned const, or an error.
887    /// Assumes the variable is not already assigned.
888    fn assign_const(&mut self, var: ConstVar, id: ConstValueId) -> InferenceResult<ConstValueId> {
889        if var.inference_id != self.inference_id {
890            return Err(self.set_error(InferenceError::ConstKindMismatch {
891                const0: ConstValue::Var(var, TypeId::missing(self.db, skip_diagnostic()))
892                    .intern(self.db),
893                const1: id,
894            }));
895        }
896
897        self.const_assignment.insert(var.id, id);
898        Ok(id)
899    }
900
901    /// Computes the solution set for an impl variable with a recursive query.
902    fn impl_var_solution_set(
903        &mut self,
904        var: LocalImplVarId,
905    ) -> InferenceResult<SolutionSet<ImplId>> {
906        let impl_var = self.impl_var(var).clone();
907        // Update the concrete trait of the impl var.
908        let concrete_trait_id = self.rewrite(impl_var.concrete_trait_id).no_err();
909        self.impl_vars[impl_var.id.0].concrete_trait_id = concrete_trait_id;
910        let impl_var_trait_item_mappings =
911            self.impl_vars_trait_item_mappings.get(&var).cloned().unwrap_or_default();
912        let solution_set = self.trait_solution_set(
913            concrete_trait_id,
914            impl_var_trait_item_mappings,
915            impl_var.lookup_context,
916        )?;
917        Ok(match solution_set {
918            SolutionSet::None => SolutionSet::None,
919            SolutionSet::Unique((canonical_impl, canonicalizer)) => {
920                SolutionSet::Unique(canonical_impl.embed(self, &canonicalizer))
921            }
922            SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
923        })
924    }
925
926    /// Computes the solution set for a trait with a recursive query.
927    pub fn trait_solution_set(
928        &mut self,
929        concrete_trait_id: ConcreteTraitId,
930        impl_var_trait_item_mappings: ImplVarTraitItemMappings,
931        mut lookup_context: ImplLookupContext,
932    ) -> InferenceResult<SolutionSet<(CanonicalImpl, CanonicalMapping)>> {
933        let impl_var_trait_item_mappings = self.rewrite(impl_var_trait_item_mappings).no_err();
934        // TODO(spapini): This is done twice. Consider doing it only here.
935        let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
936        enrich_lookup_context(self.db, concrete_trait_id, &mut lookup_context);
937
938        // Don't try to resolve impls if the first generic param is a variable.
939        let generic_args = concrete_trait_id.generic_args(self.db);
940        match generic_args.first() {
941            Some(GenericArgumentId::Type(ty)) => {
942                if let TypeLongId::Var(_) = ty.lookup_intern(self.db) {
943                    // Don't try to infer such impls.
944                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
945                }
946            }
947            Some(GenericArgumentId::Impl(imp)) => {
948                // Don't try to infer such impls.
949                if let ImplLongId::ImplVar(_) = imp.lookup_intern(self.db) {
950                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
951                }
952            }
953            Some(GenericArgumentId::Constant(const_value)) => {
954                if let ConstValue::Var(_, _) = const_value.lookup_intern(self.db) {
955                    // Don't try to infer such impls.
956                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
957                }
958            }
959            _ => {}
960        };
961        let (canonical_trait, canonicalizer) = CanonicalTrait::canonicalize(
962            self.db,
963            self.inference_id,
964            concrete_trait_id,
965            impl_var_trait_item_mappings,
966        );
967        // impl_type_bounds order is deterimend by the generic params of the function and therefore
968        // is consistent.
969        let solution_set = match self.db.canonic_trait_solutions(
970            canonical_trait,
971            lookup_context,
972            (*self.data.impl_type_bounds).clone(),
973        ) {
974            Ok(solution_set) => solution_set,
975            Err(err) => return Err(self.set_error(err)),
976        };
977        match solution_set {
978            SolutionSet::None => Ok(SolutionSet::None),
979            SolutionSet::Unique(canonical_impl) => {
980                Ok(SolutionSet::Unique((canonical_impl, canonicalizer)))
981            }
982            SolutionSet::Ambiguous(ambiguity) => Ok(SolutionSet::Ambiguous(ambiguity)),
983        }
984    }
985
    /// Validates the candidate impl `canonical_impl` against the negative-impl generic params
    /// (e.g. `-TypeEqual<S, T>`) it declares.
    ///
    /// Returns:
    /// - `SolutionSet::Unique(canonical_impl)` if no negative-impl trait has a solution, or if
    ///   the impl is neither a concrete nor a generated impl (other impl kinds are not checked).
    /// - `SolutionSet::None` if some negative-impl trait has a solution set other than `None`.
    ///   In that case the candidate must be skipped.
    /// - `SolutionSet::Ambiguous(...)` if some negative-impl trait has a type argument that is
    ///   not fully concrete (closure types excepted), so the check cannot be decided yet.
    ///
    /// Errors from computing the impl's substitution, its generic params, or the negative-impl
    /// traits are recorded via `set_error`.
    fn validate_neg_impls(
        &mut self,
        lookup_context: &ImplLookupContext,
        canonical_impl: CanonicalImpl,
    ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
        /// Validates that no solution set is found for the negative impls.
        /// Returns `Unique(canonical_impl)` on success, `None` if a negative-impl trait has a
        /// solution, and `Ambiguous` if one of its type arguments is not yet fully concrete.
        fn validate_no_solution_set(
            inference: &mut Inference<'_>,
            canonical_impl: CanonicalImpl,
            lookup_context: &ImplLookupContext,
            negative_impls_concrete_traits: impl Iterator<Item = Maybe<ConcreteTraitId>>,
        ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
            for concrete_trait_id in negative_impls_concrete_traits {
                // A failure to compute the trait was already reported as a diagnostic.
                let concrete_trait_id = concrete_trait_id.map_err(|diag_added| {
                    inference.set_error(InferenceError::Reported(diag_added))
                })?;
                for garg in concrete_trait_id.generic_args(inference.db) {
                    let GenericArgumentId::Type(ty) = garg else {
                        continue;
                    };
                    let ty = inference.rewrite(ty).no_err();
                    // If the negative impl has a generic argument that is not fully
                    // concrete we can't tell if we should rule out the candidate impl.
                    // For example if we have -TypeEqual<S, T> we can't tell if S and
                    // T are going to be assigned the same concrete type.
                    // We return `SolutionSet::Ambiguous` here to indicate that more
                    // information is needed.
                    // Closure can only have one type, even if it's not fully concrete, so can use
                    // it and not get ambiguity.
                    if !matches!(ty.lookup_intern(inference.db), TypeLongId::Closure(_))
                        && !ty.is_fully_concrete(inference.db)
                    {
                        // TODO(ilya): Try to detect the ambiguity earlier in the
                        // inference process.
                        return Ok(SolutionSet::Ambiguous(
                            Ambiguity::NegativeImplWithUnresolvedGenericArgs {
                                impl_id: canonical_impl.0,
                                ty,
                            },
                        ));
                    }
                }

                // Both a unique and an ambiguous solution for the negative-impl trait rule out
                // the candidate; only `SolutionSet::None` lets it through.
                if !matches!(
                    inference.trait_solution_set(
                        concrete_trait_id,
                        ImplVarTraitItemMappings::default(),
                        lookup_context.clone()
                    )?,
                    SolutionSet::None
                ) {
                    // If a negative impl has an impl, then we should skip it.
                    return Ok(SolutionSet::None);
                }
            }

            Ok(SolutionSet::Unique(canonical_impl))
        }
        // Collect the negative-impl traits of the candidate, depending on its kind.
        match canonical_impl.0.lookup_intern(self.db) {
            ImplLongId::Concrete(concrete_impl) => {
                // The negative-impl params are taken from the impl def and substituted with the
                // concrete impl's generic arguments.
                let substitution = concrete_impl
                    .substitution(self.db)
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                let generic_params = self
                    .db
                    .impl_def_generic_params(concrete_impl.impl_def_id(self.db))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                let concrete_traits = generic_params
                    .iter()
                    .filter_map(|generic_param| {
                        try_extract_matches!(generic_param, GenericParam::NegImpl)
                    })
                    .map(|generic_param| {
                        substitution
                            .substitute(self.db, generic_param.clone())
                            .and_then(|generic_param| generic_param.concrete_trait)
                    });
                validate_no_solution_set(self, canonical_impl, lookup_context, concrete_traits)
            }
            // Generated impls carry their own generic params; they are used without substitution.
            ImplLongId::GeneratedImpl(generated_impl) => validate_no_solution_set(
                self,
                canonical_impl,
                lookup_context,
                generated_impl
                    .lookup_intern(self.db)
                    .generic_params
                    .iter()
                    .filter_map(|generic_param| {
                        try_extract_matches!(generic_param, GenericParam::NegImpl)
                    })
                    .map(|generic_param| generic_param.concrete_trait),
            ),
            // Other impl kinds are accepted without a negative-impl check.
            ImplLongId::GenericParameter(_)
            | ImplLongId::ImplVar(_)
            | ImplLongId::ImplImpl(_)
            | ImplLongId::SelfImpl(_) => Ok(SolutionSet::Unique(canonical_impl)),
        }
    }
1087
1088    // Error handling methods
1089    // ======================
1090
1091    /// Sets an error in the inference state.
1092    /// Does nothing if an error is already set.
1093    /// Returns an `ErrorSet` that can be used in reporting the error.
1094    pub fn set_error(&mut self, err: InferenceError) -> ErrorSet {
1095        if self.error_status.is_err() {
1096            return ErrorSet;
1097        }
1098        self.error_status = if let InferenceError::Reported(diag_added) = err {
1099            self.consumed_error = Some(diag_added);
1100            Err(InferenceErrorStatus::Consumed)
1101        } else {
1102            self.error = Some(err);
1103            Err(InferenceErrorStatus::Pending)
1104        };
1105        ErrorSet
1106    }
1107
1108    /// Returns whether an error is set (either pending or consumed).
1109    pub fn is_error_set(&self) -> InferenceResult<()> {
1110        if self.error_status.is_err() { Err(ErrorSet) } else { Ok(()) }
1111    }
1112
1113    /// Consumes the error but doesn't report it. If there is no error, or the error is consumed,
1114    /// returns None. This should be used with caution. Always prefer to use
1115    /// (1) `report_on_pending_error` if possible, or (2) `consume_reported_error` which is safer.
1116    ///
1117    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1118    pub fn consume_error_without_reporting(&mut self, err_set: ErrorSet) -> Option<InferenceError> {
1119        self.consume_error_inner(err_set, skip_diagnostic())
1120    }
1121
1122    /// Consumes the error that is already reported. If there is no error, or the error is consumed,
1123    /// does nothing. This should be used with caution. Always prefer to use
1124    /// `report_on_pending_error` if possible.
1125    ///
1126    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1127    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1128    pub fn consume_reported_error(&mut self, err_set: ErrorSet, diag_added: DiagnosticAdded) {
1129        self.consume_error_inner(err_set, diag_added);
1130    }
1131
1132    /// Consumes the error and returns it, but doesn't report it. If there is no error, or the error
1133    /// is already consumed, returns None. This should be used with caution. Always prefer to use
1134    /// `report_on_pending_error` if possible.
1135    ///
1136    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1137    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1138    fn consume_error_inner(
1139        &mut self,
1140        _err_set: ErrorSet,
1141        diag_added: DiagnosticAdded,
1142    ) -> Option<InferenceError> {
1143        if self.error_status != Err(InferenceErrorStatus::Pending) {
1144            return None;
1145            // panic!("consume_error when there is no pending error");
1146        }
1147        self.error_status = Err(InferenceErrorStatus::Consumed);
1148        self.consumed_error = Some(diag_added);
1149        mem::take(&mut self.error)
1150    }
1151
1152    /// Consumes the pending error, if any, and reports it.
1153    /// Should only be called when an error is set, otherwise it panics.
1154    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1155    /// If an error was set but it's already consumed, it doesn't report it again but returns the
1156    /// stored `DiagnosticAdded`.
1157    pub fn report_on_pending_error(
1158        &mut self,
1159        _err_set: ErrorSet,
1160        diagnostics: &mut SemanticDiagnostics,
1161        stable_ptr: SyntaxStablePtrId,
1162    ) -> DiagnosticAdded {
1163        let Err(state_error) = self.error_status else {
1164            panic!("report_on_pending_error should be called only on error");
1165        };
1166        match state_error {
1167            InferenceErrorStatus::Consumed => self
1168                .consumed_error
1169                .expect("consumed_error is not set although error_status is Err(Consumed)"),
1170            InferenceErrorStatus::Pending => {
1171                let diag_added = match mem::take(&mut self.error)
1172                    .expect("error is not set although error_status is Err(Pending)")
1173                {
1174                    InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
1175                        // If we have other diagnostics, there is no need to TypeNotInferred.
1176
1177                        // Note that `diagnostics` is not empty, so it is safe to return
1178                        // 'DiagnosticAdded' here.
1179                        skip_diagnostic()
1180                    }
1181                    diag => diag.report(diagnostics, stable_ptr),
1182                };
1183
1184                self.error_status = Err(InferenceErrorStatus::Consumed);
1185                self.consumed_error = Some(diag_added);
1186                diag_added
1187            }
1188        }
1189    }
1190
1191    /// If the current status is of a pending error, reports an alternative diagnostic, by calling
1192    /// `report`, and consumes the error. Otherwise, does nothing.
1193    pub fn report_modified_if_pending(
1194        &mut self,
1195        err_set: ErrorSet,
1196        report: impl FnOnce() -> DiagnosticAdded,
1197    ) {
1198        if self.error_status == Err(InferenceErrorStatus::Pending) {
1199            self.consume_reported_error(err_set, report());
1200        }
1201    }
1202}
1203
1204impl<'a> HasDb<&'a dyn SemanticGroup> for Inference<'a> {
1205    fn get_db(&self) -> &'a dyn SemanticGroup {
1206        self.db
1207    }
1208}
// Macro-generated `SemanticRewriter` impls for `Inference`. The types listed after `@exclude`
// (TypeLongId, TypeId, ImplLongId, ImplId, ConstValue) get hand-written impls below, which
// resolve inference variables. `Ambiguity` gets its rewrite through `add_rewrite!`.
add_basic_rewrites!(<'a>, Inference<'a>, NoError, @exclude TypeLongId TypeId ImplLongId ImplId ConstValue);
add_expr_rewrites!(<'a>, Inference<'a>, NoError, @exclude);
add_rewrite!(<'a>, Inference<'a>, NoError, Ambiguity);
1212impl SemanticRewriter<TypeId, NoError> for Inference<'_> {
1213    fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
1214        if value.is_var_free(self.db) {
1215            return Ok(RewriteResult::NoChange);
1216        }
1217        value.default_rewrite(self)
1218    }
1219}
1220impl SemanticRewriter<ImplId, NoError> for Inference<'_> {
1221    fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
1222        if value.is_var_free(self.db) {
1223            return Ok(RewriteResult::NoChange);
1224        }
1225        value.default_rewrite(self)
1226    }
1227}
/// Rewrites a `TypeLongId` according to the current inference state:
/// - An assigned type variable is replaced by its assignment, which is rewritten first. If that
///   rewrite modified it, the stored assignment is updated to the rewritten type.
/// - An impl type with an entry in `impl_type_bounds` is replaced by that bound (rewritten).
/// - Any other impl type has its `ImplTypeId` rewritten and is then resolved according to the
///   kind of its impl.
/// Everything else, including unassigned type variables, goes through the default rewrite.
impl SemanticRewriter<TypeLongId, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
        match value {
            TypeLongId::Var(var) => {
                // NOTE(review): the lookup uses only the local id; `var.inference_id` is not
                // checked here.
                if let Some(type_id) = self.type_assignment.get(&var.id) {
                    let mut long_type_id = type_id.lookup_intern(self.db);
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_type_id)? {
                        // Store the rewritten form back into the assignment.
                        *self.type_assignment.get_mut(&var.id).unwrap() =
                            long_type_id.clone().intern(self.db);
                    }
                    *value = long_type_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            TypeLongId::ImplType(impl_type_id) => {
                // A known bound for this impl type takes precedence.
                if let Some(type_id) = self.impl_type_bounds.get(&((*impl_type_id).into())) {
                    *value = type_id.lookup_intern(self.db);
                    self.internal_rewrite(value)?;
                    return Ok(RewriteResult::Modified);
                }
                // Rewrite the impl type id first, then inspect its (rewritten) impl.
                let impl_type_id_rewrite_result = self.internal_rewrite(impl_type_id)?;
                let impl_id = impl_type_id.impl_id();
                let trait_ty = impl_type_id.ty();
                return Ok(match impl_id.lookup_intern(self.db) {
                    // No further resolution is attempted for these impl kinds.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_type_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        // Use the concrete impl's type when it can be computed.
                        if let Ok(ty) = self.db.impl_type_concrete_implized(ImplTypeId::new(
                            impl_id, trait_ty, self.db,
                        )) {
                            *value = self.rewrite(ty).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_type_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        *value = self.rewritten_impl_type(var, trait_ty).lookup_intern(self.db);
                        return Ok(RewriteResult::Modified);
                    }
                    ImplLongId::GeneratedImpl(generated) => {
                        // Take the type from the generated impl's items. The `unwrap` panics if
                        // the generated impl does not define this trait type.
                        *value = self
                            .rewrite(
                                *generated
                                    .lookup_intern(self.db)
                                    .impl_items
                                    .0
                                    .get(&impl_type_id.ty())
                                    .unwrap(),
                            )
                            .no_err()
                            .lookup_intern(self.db);
                        RewriteResult::Modified
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
/// Rewrites a `ConstValue` according to the current inference state:
/// - An assigned const variable is replaced by its assignment, which is rewritten first. If that
///   rewrite modified it, the stored assignment is updated. An unassigned const variable yields
///   `NoChange` without going through the default rewrite.
/// - An impl constant has its `ImplConstantId` rewritten and is then resolved according to the
///   kind of its impl.
/// Everything else goes through the default rewrite.
impl SemanticRewriter<ConstValue, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
        match value {
            // NOTE(review): for an unassigned variable, the variable's type (second field) is not
            // rewritten, unlike the default path; confirm this is intended.
            ConstValue::Var(var, _) => {
                return Ok(if let Some(const_value_id) = self.const_assignment.get(&var.id) {
                    let mut const_value = const_value_id.lookup_intern(self.db);
                    if let RewriteResult::Modified = self.internal_rewrite(&mut const_value)? {
                        // Store the rewritten form back into the assignment.
                        *self.const_assignment.get_mut(&var.id).unwrap() =
                            const_value.clone().intern(self.db);
                    }
                    *value = const_value;
                    RewriteResult::Modified
                } else {
                    RewriteResult::NoChange
                });
            }
            ConstValue::ImplConstant(impl_constant_id) => {
                // Rewrite the impl constant id first, then inspect its (rewritten) impl.
                let impl_constant_id_rewrite_result = self.internal_rewrite(impl_constant_id)?;
                let impl_id = impl_constant_id.impl_id();
                let trait_constant = impl_constant_id.trait_constant_id();
                return Ok(match impl_id.lookup_intern(self.db) {
                    // No further resolution is attempted for these impl kinds.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::GeneratedImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_constant_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        // Use the concrete impl's constant value when it can be computed.
                        if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
                            ImplConstantId::new(impl_id, trait_constant, self.db),
                        ) {
                            *value = self.rewrite(constant).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_constant_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        *value = self
                            .rewritten_impl_constant(var, trait_constant)
                            .lookup_intern(self.db);
                        return Ok(RewriteResult::Modified);
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
/// Rewrites an `ImplLongId` according to the current inference state:
/// - An assigned impl variable is replaced by its assignment, which is rewritten first. If that
///   rewrite modified it, the stored assignment is updated.
/// - An impl-impl has its `ImplImplId` rewritten and is then resolved according to the kind of
///   its impl.
/// Otherwise, var-free impls are left untouched and the rest go through the default rewrite.
impl SemanticRewriter<ImplLongId, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
        match value {
            ImplLongId::ImplVar(var) => {
                let long_id = var.lookup_intern(self.db);
                // Relax the candidates.
                // NOTE(review): the lookup uses only the local id; `long_id.inference_id` is not
                // checked here.
                let impl_var_id = long_id.id;
                if let Some(impl_id) = self.impl_assignment(impl_var_id) {
                    let mut long_impl_id = impl_id.lookup_intern(self.db);
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_impl_id)? {
                        // Store the rewritten form back into the assignment.
                        *self.impl_assignment.get_mut(&impl_var_id).unwrap() =
                            long_impl_id.clone().intern(self.db);
                    }
                    *value = long_impl_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            ImplLongId::ImplImpl(impl_impl_id) => {
                // Rewrite the impl-impl id first, then inspect its (rewritten) impl.
                let impl_impl_id_rewrite_result = self.internal_rewrite(impl_impl_id)?;
                let impl_id = impl_impl_id.impl_id();
                return Ok(match impl_id.lookup_intern(self.db) {
                    // No further resolution is attempted for these impl kinds.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::GeneratedImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_impl_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        // Use the concrete impl's inner impl when it can be computed.
                        if let Ok(imp) = self.db.impl_impl_concrete_implized(*impl_impl_id) {
                            *value = self.rewrite(imp).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        if let Ok(concrete_trait_impl) =
                            impl_impl_id.concrete_trait_impl_id(self.db)
                        {
                            *value = self
                                .rewritten_impl_impl(var, concrete_trait_impl)
                                .lookup_intern(self.db);
                            return Ok(RewriteResult::Modified);
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                });
            }

            _ => {}
        }
        // Nothing to substitute in an impl without inference variables.
        if value.is_var_free(self.db) {
            return Ok(RewriteResult::NoChange);
        }
        value.default_rewrite(self)
    }
}
1394
/// A semantic rewriter that replaces `from_inference_id` with `to_inference_id` wherever it
/// rewrites an `InferenceId`.
struct InferenceIdReplacer<'a> {
    db: &'a dyn SemanticGroup,
    /// The inference id to replace.
    from_inference_id: InferenceId,
    /// The inference id to put in its place.
    to_inference_id: InferenceId,
}
1400impl<'a> InferenceIdReplacer<'a> {
1401    fn new(
1402        db: &'a dyn SemanticGroup,
1403        from_inference_id: InferenceId,
1404        to_inference_id: InferenceId,
1405    ) -> Self {
1406        Self { db, from_inference_id, to_inference_id }
1407    }
1408}
1409impl<'a> HasDb<&'a dyn SemanticGroup> for InferenceIdReplacer<'a> {
1410    fn get_db(&self) -> &'a dyn SemanticGroup {
1411        self.db
1412    }
1413}
// Macro-generated `SemanticRewriter` impls for `InferenceIdReplacer`. `InferenceId` is excluded
// because its impl is hand-written below; that impl performs the actual id replacement.
add_basic_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude InferenceId);
add_expr_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude);
add_rewrite!(<'a>, InferenceIdReplacer<'a>, NoError, Ambiguity);
1417impl SemanticRewriter<InferenceId, NoError> for InferenceIdReplacer<'_> {
1418    fn internal_rewrite(&mut self, value: &mut InferenceId) -> Result<RewriteResult, NoError> {
1419        if value == &self.from_inference_id {
1420            *value = self.to_inference_id;
1421            Ok(RewriteResult::Modified)
1422        } else {
1423            Ok(RewriteResult::NoChange)
1424        }
1425    }
1426}