cairo_lang_semantic/expr/
inference.rs

1//! Bidirectional type inference.
2
3use std::collections::{BTreeMap, HashMap, VecDeque};
4use std::hash::Hash;
5use std::mem;
6use std::ops::{Deref, DerefMut};
7use std::sync::Arc;
8
9use cairo_lang_debug::DebugWithDb;
10use cairo_lang_defs::ids::{
11    ConstantId, EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId,
12    GlobalUseId, ImplAliasId, ImplDefId, ImplFunctionId, ImplImplDefId, LanguageElementId,
13    LocalVarId, LookupItemId, MemberId, NamedLanguageElementId, ParamId, StructId, TraitConstantId,
14    TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
15};
16use cairo_lang_diagnostics::{DiagnosticAdded, Maybe, skip_diagnostic};
17use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
18use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
19use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
20use cairo_lang_utils::{
21    Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches,
22};
23
24use self::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, NoError};
25use self::solver::{Ambiguity, SolutionSet, enrich_lookup_context};
26use crate::db::SemanticGroup;
27use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
28use crate::expr::inference::canonic::ResultNoErrEx;
29use crate::expr::inference::conform::InferenceConform;
30use crate::expr::objects::*;
31use crate::expr::pattern::*;
32use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
33use crate::items::functions::{
34    ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
35    GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
36    ImplGenericFunctionWithBodyId,
37};
38use crate::items::generics::{GenericParamConst, GenericParamImpl, GenericParamType};
39use crate::items::imp::{
40    GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
41    ImplLookupContext, UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
42};
43use crate::items::trt::{
44    ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId, ConcreteTraitTypeId,
45    ConcreteTraitTypeLongId,
46};
47use crate::substitution::{HasDb, RewriteResult, SemanticRewriter};
48use crate::types::{
49    ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
50    ImplTypeById, ImplTypeId,
51};
52use crate::{
53    ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
54    ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
55    FunctionId, FunctionLongId, GenericArgumentId, GenericParam, LocalVariable, MatchArmSelector,
56    Member, Parameter, SemanticObject, Signature, TypeId, TypeLongId, ValueSelectorArm,
57    add_basic_rewrites, add_expr_rewrites, add_rewrite, semantic_object_for_id,
58};
59
60pub mod canonic;
61pub mod conform;
62pub mod infers;
63pub mod solver;
64
/// A type variable, created when a generic type argument is not passed, and thus is not known
/// yet and needs to be inferred.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct TypeVar {
    /// The inference context this variable belongs to.
    pub inference_id: InferenceId,
    /// Index of this variable in its inference context's `type_vars` list.
    pub id: LocalTypeVarId,
}
72
/// A const variable, created when a generic const argument is not passed, and thus is not known
/// yet and needs to be inferred.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ConstVar {
    /// The inference context this variable belongs to.
    pub inference_id: InferenceId,
    /// Index of this variable in its inference context's `const_vars` list.
    pub id: LocalConstVarId,
}
80
/// An id for an inference context. Each inference variable is associated with an inference id.
///
/// The variants appear to name the item (or item position) whose inference context this is;
/// every [TypeVar], [ConstVar] and [ImplVar] records the id of the context that created it.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
#[debug_db(dyn SemanticGroup + 'static)]
pub enum InferenceId {
    LookupItemDeclaration(LookupItemId),
    LookupItemGenerics(LookupItemId),
    LookupItemDefinition(LookupItemId),
    ImplDefTrait(ImplDefId),
    ImplAliasImplDef(ImplAliasId),
    GenericParam(GenericParamId),
    GenericImplParamTrait(GenericParamId),
    GlobalUseStar(GlobalUseId),
    Canonical,
    /// For resolutions whose results will not be used anywhere in the semantic model.
    NoContext,
}
97
/// An impl variable: an impl of `concrete_trait_id` that is not known yet and needs to be
/// inferred (e.g. for a generic impl argument that was not passed).
#[derive(Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
#[debug_db(dyn SemanticGroup + 'static)]
pub struct ImplVar {
    /// The inference context this variable belongs to.
    pub inference_id: InferenceId,
    /// Index of this variable in its inference context's `impl_vars` list.
    /// Excluded from semantic rewriting (`#[dont_rewrite]`).
    #[dont_rewrite]
    pub id: LocalImplVarId,
    /// The concrete trait that the impl assigned to this variable must conform to.
    pub concrete_trait_id: ConcreteTraitId,
    /// Context in which candidate impls are looked up. On creation it is extended with the
    /// module of the trait. Excluded from semantic rewriting (`#[dont_rewrite]`).
    #[dont_rewrite]
    pub lookup_context: ImplLookupContext,
}
110impl ImplVar {
111    pub fn intern(&self, db: &dyn SemanticGroup) -> ImplVarId {
112        self.clone().intern(db)
113    }
114}
115
/// Index of a type variable within its inference context's `type_vars` list.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
pub struct LocalTypeVarId(pub usize);
/// Index of an impl variable within its inference context's `impl_vars` list.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
pub struct LocalImplVarId(pub usize);

/// Index of a const variable within its inference context's `const_vars` list.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
pub struct LocalConstVarId(pub usize);
123
124define_short_id!(ImplVarId, ImplVar, SemanticGroup, lookup_intern_impl_var, intern_impl_var);
125impl ImplVarId {
126    pub fn id(&self, db: &dyn SemanticGroup) -> LocalImplVarId {
127        self.lookup_intern(db).id
128    }
129    pub fn concrete_trait_id(&self, db: &dyn SemanticGroup) -> ConcreteTraitId {
130        self.lookup_intern(db).concrete_trait_id
131    }
132    pub fn lookup_context(&self, db: &dyn SemanticGroup) -> ImplLookupContext {
133        self.lookup_intern(db).lookup_context
134    }
135}
136semantic_object_for_id!(ImplVarId, lookup_intern_impl_var, intern_impl_var, ImplVar);
137
/// A local inference variable of any kind. Used, e.g., as the key of
/// `InferenceData::stable_ptrs` and inside [InferenceError::Cycle].
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, SemanticObject)]
pub enum InferenceVar {
    Type(LocalTypeVarId),
    Const(LocalConstVarId),
    Impl(LocalImplVarId),
}
144
// TODO(spapini): Add to diagnostics.
/// An error raised while solving an inference system. Use `format` for a user-facing message.
#[derive(Clone, Debug, Eq, Hash, PartialEq, DebugWithDb)]
#[debug_db(dyn SemanticGroup + 'static)]
pub enum InferenceError {
    /// An inference error wrapping a previously reported error.
    Reported(DiagnosticAdded),
    /// Assigning a value to the variable would make the value refer to the variable itself.
    Cycle(InferenceVar),
    /// Two types were required to match but do not.
    TypeKindMismatch {
        ty0: TypeId,
        ty1: TypeId,
    },
    /// Two const values were required to match but do not.
    ConstKindMismatch {
        const0: ConstValueId,
        const1: ConstValueId,
    },
    /// Two impls were required to match but do not.
    ImplKindMismatch {
        impl0: ImplId,
        impl1: ImplId,
    },
    /// Two generic arguments were required to match but do not.
    GenericArgMismatch {
        garg0: GenericArgumentId,
        garg1: GenericArgumentId,
    },
    /// Two traits were required to match but do not.
    TraitMismatch {
        trt0: TraitId,
        trt1: TraitId,
    },
    /// The trait type `trait_type_id` of `impl_id` was required to be both `ty0` and `ty1`.
    ImplTypeMismatch {
        impl_id: ImplId,
        trait_type_id: TraitTypeId,
        ty0: TypeId,
        ty1: TypeId,
    },
    /// Two generic functions were required to match but do not.
    GenericFunctionMismatch {
        func0: GenericFunctionId,
        func1: GenericFunctionId,
    },
    /// A const variable was left without an assignment.
    ConstNotInferred,
    // TODO(spapini): These are only used for external interface. Separate them along with the
    // finalize() function to a wrapper.
    /// No impl of the given concrete trait was found for an impl variable.
    NoImplsFound(ConcreteTraitId),
    /// An impl variable could not be resolved to a single solution.
    Ambiguity(Ambiguity),
    /// A type variable was left without an assignment.
    TypeNotInferred(TypeId),
}
189impl InferenceError {
190    pub fn format(&self, db: &(dyn SemanticGroup + 'static)) -> String {
191        match self {
192            InferenceError::Reported(_) => "Inference error occurred.".into(),
193            InferenceError::Cycle(_var) => "Inference cycle detected".into(),
194            InferenceError::TypeKindMismatch { ty0, ty1 } => {
195                format!("Type mismatch: `{:?}` and `{:?}`.", ty0.debug(db), ty1.debug(db))
196            }
197            InferenceError::ConstKindMismatch { const0, const1 } => {
198                format!("Const mismatch: `{:?}` and `{:?}`.", const0.debug(db), const1.debug(db))
199            }
200            InferenceError::ImplKindMismatch { impl0, impl1 } => {
201                format!("Impl mismatch: `{:?}` and `{:?}`.", impl0.debug(db), impl1.debug(db))
202            }
203            InferenceError::GenericArgMismatch { garg0, garg1 } => {
204                format!(
205                    "Generic arg mismatch: `{:?}` and `{:?}`.",
206                    garg0.debug(db),
207                    garg1.debug(db)
208                )
209            }
210            InferenceError::TraitMismatch { trt0, trt1 } => {
211                format!("Trait mismatch: `{:?}` and `{:?}`.", trt0.debug(db), trt1.debug(db))
212            }
213            InferenceError::ConstNotInferred => "Failed to infer constant.".into(),
214            InferenceError::NoImplsFound(concrete_trait_id) => {
215                let info = db.core_info();
216                let trait_id = concrete_trait_id.trait_id(db);
217                if trait_id == info.numeric_literal_trt {
218                    let generic_type = extract_matches!(
219                        concrete_trait_id.generic_args(db)[0],
220                        GenericArgumentId::Type
221                    );
222                    return format!(
223                        "Mismatched types. The type `{:?}` cannot be created from a numeric \
224                         literal.",
225                        generic_type.debug(db)
226                    );
227                } else if trait_id == info.string_literal_trt {
228                    let generic_type = extract_matches!(
229                        concrete_trait_id.generic_args(db)[0],
230                        GenericArgumentId::Type
231                    );
232                    return format!(
233                        "Mismatched types. The type `{:?}` cannot be created from a string \
234                         literal.",
235                        generic_type.debug(db)
236                    );
237                }
238                format!(
239                    "Trait has no implementation in context: {:?}.",
240                    concrete_trait_id.debug(db)
241                )
242            }
243            InferenceError::Ambiguity(ambiguity) => ambiguity.format(db),
244            InferenceError::TypeNotInferred(ty) => {
245                format!("Type annotations needed. Failed to infer {:?}.", ty.debug(db))
246            }
247            InferenceError::GenericFunctionMismatch { func0, func1 } => {
248                format!("Function mismatch: `{}` and `{}`.", func0.format(db), func1.format(db))
249            }
250            InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 } => {
251                format!(
252                    "`{}::{}` type mismatch: `{:?}` and `{:?}`.",
253                    impl_id.format(db),
254                    trait_type_id.name(db),
255                    ty0.debug(db),
256                    ty1.debug(db)
257                )
258            }
259        }
260    }
261}
262
263impl InferenceError {
264    pub fn report(
265        &self,
266        diagnostics: &mut SemanticDiagnostics,
267        stable_ptr: SyntaxStablePtrId,
268    ) -> DiagnosticAdded {
269        match self {
270            InferenceError::Reported(diagnostic_added) => *diagnostic_added,
271            _ => diagnostics
272                .report(stable_ptr, SemanticDiagnosticKind::InternalInferenceError(self.clone())),
273        }
274    }
275}
276
/// This struct is used to ensure that when an inference error occurs, it is properly set in the
/// `Inference` object, and then properly consumed.
///
/// It must not be constructed directly. Instead, it is returned by [Inference::set_error].
/// (One exception in this file: `finalize_without_reporting` constructs it directly when an
/// error is already recorded in `error_status`.)
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ErrorSet;

/// Result of an inference operation. On `Err`, the actual error is held by the inference state,
/// not by the returned [ErrorSet].
pub type InferenceResult<T> = Result<T, ErrorSet>;
285
/// The kind of error recorded in `InferenceData::error_status`.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum InferenceErrorStatus {
    /// An error is set (held in `InferenceData::error`) and has not been consumed yet.
    Pending,
    /// The error has been consumed; its diagnostic is held in `InferenceData::consumed_error`.
    Consumed,
}
291
/// A mapping of an impl var's trait items (types, constants and impls) to concrete items.
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, SemanticObject)]
pub struct ImplVarTraitItemMappings {
    /// The trait types of the impl var.
    types: OrderedHashMap<TraitTypeId, TypeId>,
    /// The trait constants of the impl var.
    constants: OrderedHashMap<TraitConstantId, ConstValueId>,
    /// The trait impls of the impl var.
    impls: OrderedHashMap<TraitImplId, ImplId>,
}
302
/// State of inference: the variables, their assignments, the impl-solving queues and the error
/// status of a single inference context.
#[derive(Debug, DebugWithDb, PartialEq, Eq)]
#[debug_db(dyn SemanticGroup + 'static)]
pub struct InferenceData {
    /// The inference context this state belongs to.
    pub inference_id: InferenceId,
    /// Current inferred assignment for type variables.
    pub type_assignment: OrderedHashMap<LocalTypeVarId, TypeId>,
    /// Current inferred assignment for const variables.
    pub const_assignment: OrderedHashMap<LocalConstVarId, ConstValueId>,
    /// Current inferred assignment for impl variables.
    pub impl_assignment: OrderedHashMap<LocalImplVarId, ImplId>,
    /// Maps each unsolved impl variable to a map from its trait items to the items known for
    /// them. Once the impl variable is solved, these known items are conformed to the items of
    /// the solution.
    pub impl_vars_trait_item_mappings: HashMap<LocalImplVarId, ImplVarTraitItemMappings>,
    /// Type variables, indexed by [LocalTypeVarId].
    pub type_vars: Vec<TypeVar>,
    /// Const variables, indexed by [LocalConstVarId].
    pub const_vars: Vec<ConstVar>,
    /// Impl variables, indexed by [LocalImplVarId].
    pub impl_vars: Vec<ImplVar>,
    /// Source location of each variable that was created with one.
    pub stable_ptrs: HashMap<InferenceVar, SyntaxStablePtrId>,
    /// Impl variables that are pending to be solved.
    pending: VecDeque<LocalImplVarId>,
    /// Impl variables that have been refuted - no solutions exist.
    refuted: Vec<LocalImplVarId>,
    /// Impl variables that have been solved.
    solved: Vec<LocalImplVarId>,
    /// Impl variables that are currently ambiguous. May be solved later.
    ambiguous: Vec<(LocalImplVarId, Ambiguity)>,
    /// Mapping from impl types to the types they are bound to. Bounds whose type is still a type
    /// variable are not stored here (see `Inference::set_impl_type_bounds`).
    pub impl_type_bounds: Arc<BTreeMap<ImplTypeById, TypeId>>,

    // Error handling members.
    /// The current error status.
    pub error_status: Result<(), InferenceErrorStatus>,
    /// `Some` only when `error_status` is `Err(Pending)`.
    error: Option<InferenceError>,
    /// `Some` only when `error_status` is `Err(Consumed)`.
    consumed_error: Option<DiagnosticAdded>,
}
impl InferenceData {
    /// Creates an empty inference state for the given context: no variables, no assignments,
    /// empty solver queues and no error.
    pub fn new(inference_id: InferenceId) -> Self {
        Self {
            inference_id,
            type_assignment: OrderedHashMap::default(),
            impl_assignment: OrderedHashMap::default(),
            const_assignment: OrderedHashMap::default(),
            impl_vars_trait_item_mappings: HashMap::new(),
            type_vars: Vec::new(),
            impl_vars: Vec::new(),
            const_vars: Vec::new(),
            stable_ptrs: HashMap::new(),
            pending: VecDeque::new(),
            refuted: Vec::new(),
            solved: Vec::new(),
            ambiguous: Vec::new(),
            impl_type_bounds: Default::default(),
            error_status: Ok(()),
            error: None,
            consumed_error: None,
        }
    }
    /// Returns an [Inference] handle that operates on this state using `db`.
    pub fn inference<'db, 'b: 'db>(&'db mut self, db: &'b dyn SemanticGroup) -> Inference<'db> {
        Inference::new(db, self)
    }
    /// Returns a copy of this state that belongs to the inference context `inference_id`.
    ///
    /// Every stored value that may mention inference variables (assignments, trait item
    /// mappings, the variable lists and the solver queues) is passed through an
    /// `InferenceIdReplacer` built from the old and the new inference ids. Stable pointers,
    /// impl type bounds and the error state are copied as-is.
    pub fn clone_with_inference_id(
        &self,
        db: &dyn SemanticGroup,
        inference_id: InferenceId,
    ) -> InferenceData {
        let mut inference_id_replacer =
            InferenceIdReplacer::new(db, self.inference_id, inference_id);
        Self {
            inference_id,
            type_assignment: self
                .type_assignment
                .iter()
                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
                .collect(),
            const_assignment: self
                .const_assignment
                .iter()
                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
                .collect(),
            impl_assignment: self
                .impl_assignment
                .iter()
                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
                .collect(),
            impl_vars_trait_item_mappings: self
                .impl_vars_trait_item_mappings
                .iter()
                .map(|(k, mappings)| {
                    (
                        *k,
                        ImplVarTraitItemMappings {
                            types: mappings
                                .types
                                .iter()
                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
                                .collect(),
                            constants: mappings
                                .constants
                                .iter()
                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
                                .collect(),
                            impls: mappings
                                .impls
                                .iter()
                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
                                .collect(),
                        },
                    )
                })
                .collect(),
            type_vars: inference_id_replacer.rewrite(self.type_vars.clone()).no_err(),
            const_vars: inference_id_replacer.rewrite(self.const_vars.clone()).no_err(),
            impl_vars: inference_id_replacer.rewrite(self.impl_vars.clone()).no_err(),
            stable_ptrs: self.stable_ptrs.clone(),
            pending: inference_id_replacer.rewrite(self.pending.clone()).no_err(),
            refuted: inference_id_replacer.rewrite(self.refuted.clone()).no_err(),
            solved: inference_id_replacer.rewrite(self.solved.clone()).no_err(),
            ambiguous: inference_id_replacer.rewrite(self.ambiguous.clone()).no_err(),
            // We do not need to rewrite the impl type bounds, as they all should be var free.
            impl_type_bounds: self.impl_type_bounds.clone(),

            error_status: self.error_status,
            error: self.error.clone(),
            consumed_error: self.consumed_error,
        }
    }
    /// Returns an exact field-by-field copy of this state: same inference id, no rewriting.
    pub fn temporary_clone(&self) -> InferenceData {
        Self {
            inference_id: self.inference_id,
            type_assignment: self.type_assignment.clone(),
            const_assignment: self.const_assignment.clone(),
            impl_assignment: self.impl_assignment.clone(),
            impl_vars_trait_item_mappings: self.impl_vars_trait_item_mappings.clone(),
            type_vars: self.type_vars.clone(),
            const_vars: self.const_vars.clone(),
            impl_vars: self.impl_vars.clone(),
            stable_ptrs: self.stable_ptrs.clone(),
            pending: self.pending.clone(),
            refuted: self.refuted.clone(),
            solved: self.solved.clone(),
            ambiguous: self.ambiguous.clone(),
            impl_type_bounds: self.impl_type_bounds.clone(),
            error_status: self.error_status,
            error: self.error.clone(),
            consumed_error: self.consumed_error,
        }
    }
}
457
/// State of inference. A system of inference constraints.
///
/// A handle that pairs the database with a mutable borrow of an [InferenceData]; it derefs to
/// the data, so its fields are accessible directly.
pub struct Inference<'db> {
    db: &'db dyn SemanticGroup,
    pub data: &'db mut InferenceData,
}
463
464impl Deref for Inference<'_> {
465    type Target = InferenceData;
466
467    fn deref(&self) -> &Self::Target {
468        self.data
469    }
470}
471impl DerefMut for Inference<'_> {
472    fn deref_mut(&mut self) -> &mut Self::Target {
473        self.data
474    }
475}
476
477impl std::fmt::Debug for Inference<'_> {
478    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
479        let x = self.data.debug(self.db.elongate());
480        write!(f, "{x:?}")
481    }
482}
483
484impl<'db> Inference<'db> {
485    fn new(db: &'db dyn SemanticGroup, data: &'db mut InferenceData) -> Self {
486        Self { db, data }
487    }
488
489    /// Getter for an [ImplVar].
490    fn impl_var(&self, var_id: LocalImplVarId) -> &ImplVar {
491        &self.impl_vars[var_id.0]
492    }
493
494    /// Getter for an impl var assignment.
495    pub fn impl_assignment(&self, var_id: LocalImplVarId) -> Option<ImplId> {
496        self.impl_assignment.get(&var_id).copied()
497    }
498
499    /// Getter for a type var assignment.
500    fn type_assignment(&self, var_id: LocalTypeVarId) -> Option<TypeId> {
501        self.type_assignment.get(&var_id).copied()
502    }
503
504    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
505    /// Returns a wrapping TypeId.
506    pub fn new_type_var(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeId {
507        let var = self.new_type_var_raw(stable_ptr);
508
509        TypeLongId::Var(var).intern(self.db)
510    }
511
512    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
513    /// Returns the variable id.
514    pub fn new_type_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeVar {
515        let var =
516            TypeVar { inference_id: self.inference_id, id: LocalTypeVarId(self.type_vars.len()) };
517        if let Some(stable_ptr) = stable_ptr {
518            self.stable_ptrs.insert(InferenceVar::Type(var.id), stable_ptr);
519        }
520        self.type_vars.push(var);
521        var
522    }
523
524    /// Sets the infrence's impl type bounds to the given map, and rewrittes the types so all the
525    /// types are var free.
526    pub fn set_impl_type_bounds(&mut self, impl_type_bounds: OrderedHashMap<ImplTypeId, TypeId>) {
527        let impl_type_bounds_finalized = impl_type_bounds
528            .iter()
529            .filter_map(|(impl_type, ty)| {
530                let rewritten_type = self.rewrite(ty.lookup_intern(self.db)).no_err();
531                if !matches!(rewritten_type, TypeLongId::Var(_)) {
532                    return Some(((*impl_type).into(), rewritten_type.intern(self.db)));
533                }
534                // conformed the var type to the original impl type to remove it from the pending
535                // list.
536                self.conform_ty(*ty, TypeLongId::ImplType(*impl_type).intern(self.db)).ok();
537                None
538            })
539            .collect();
540
541        self.data.impl_type_bounds = Arc::new(impl_type_bounds_finalized);
542    }
543
544    /// Allocates a new [ConstVar] for an unknown consts that needs to be inferred.
545    /// Returns a wrapping [ConstValueId].
546    pub fn new_const_var(
547        &mut self,
548        stable_ptr: Option<SyntaxStablePtrId>,
549        ty: TypeId,
550    ) -> ConstValueId {
551        let var = self.new_const_var_raw(stable_ptr);
552        ConstValue::Var(var, ty).intern(self.db)
553    }
554
555    /// Allocates a new [ConstVar] for an unknown type that needs to be inferred.
556    /// Returns the variable id.
557    pub fn new_const_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> ConstVar {
558        let var = ConstVar {
559            inference_id: self.inference_id,
560            id: LocalConstVarId(self.const_vars.len()),
561        };
562        if let Some(stable_ptr) = stable_ptr {
563            self.stable_ptrs.insert(InferenceVar::Const(var.id), stable_ptr);
564        }
565        self.const_vars.push(var);
566        var
567    }
568
569    /// Allocates a new [ImplVar] for an unknown type that needs to be inferred.
570    /// Returns a wrapping ImplId.
571    pub fn new_impl_var(
572        &mut self,
573        concrete_trait_id: ConcreteTraitId,
574        stable_ptr: Option<SyntaxStablePtrId>,
575        lookup_context: ImplLookupContext,
576    ) -> ImplId {
577        let var = self.new_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
578        ImplLongId::ImplVar(self.impl_var(var).intern(self.db)).intern(self.db)
579    }
580
581    /// Allocates a new [ImplVar] for an unknown type that needs to be inferred.
582    /// Returns the variable id.
583    fn new_impl_var_raw(
584        &mut self,
585        lookup_context: ImplLookupContext,
586        concrete_trait_id: ConcreteTraitId,
587        stable_ptr: Option<SyntaxStablePtrId>,
588    ) -> LocalImplVarId {
589        let mut lookup_context = lookup_context;
590        lookup_context.insert_module(concrete_trait_id.trait_id(self.db).module_file_id(self.db).0);
591
592        let id = LocalImplVarId(self.impl_vars.len());
593        if let Some(stable_ptr) = stable_ptr {
594            self.stable_ptrs.insert(InferenceVar::Impl(id), stable_ptr);
595        }
596        let var =
597            ImplVar { inference_id: self.inference_id, id, concrete_trait_id, lookup_context };
598        self.impl_vars.push(var);
599        self.pending.push_back(id);
600        id
601    }
602
603    /// Solves the inference system. After a successful solve, there are no more pending impl
604    /// inferences.
605    /// Returns whether the inference was successful. If not, the error may be found by
606    /// `.error_state()`.
607    pub fn solve(&mut self) -> InferenceResult<()> {
608        self.solve_ex().map_err(|(err_set, _)| err_set)
609    }
610
611    /// Same as `solve`, but returns the error stable pointer if an error occurred.
612    fn solve_ex(&mut self) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
613        let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
614        self.pending.extend(ambiguous.map(|(var, _)| var));
615        while let Some(var) = self.pending.pop_front() {
616            // First inference error stops inference.
617            self.solve_single_pending(var).map_err(|err_set| {
618                (err_set, self.stable_ptrs.get(&InferenceVar::Impl(var)).copied())
619            })?;
620        }
621        Ok(())
622    }
623
624    fn solve_single_pending(&mut self, var: LocalImplVarId) -> InferenceResult<()> {
625        if self.impl_assignment.contains_key(&var) {
626            return Ok(());
627        }
628        let solution = match self.impl_var_solution_set(var)? {
629            SolutionSet::None => {
630                self.refuted.push(var);
631                return Ok(());
632            }
633            SolutionSet::Ambiguous(ambiguity) => {
634                self.ambiguous.push((var, ambiguity));
635                return Ok(());
636            }
637            SolutionSet::Unique(solution) => solution,
638        };
639
640        // Solution found. Assign it.
641        self.assign_local_impl(var, solution)?;
642
643        // Something changed.
644        self.solved.push(var);
645        let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
646        self.pending.extend(ambiguous.map(|(var, _)| var));
647
648        Ok(())
649    }
650
651    /// Returns the solution set status for the inference:
652    /// Whether there is a unique solution, multiple solutions, no solutions or an error.
653    pub fn solution_set(&mut self) -> InferenceResult<SolutionSet<()>> {
654        self.solve()?;
655        if !self.refuted.is_empty() {
656            return Ok(SolutionSet::None);
657        }
658        if let Some((_, ambiguity)) = self.ambiguous.first() {
659            return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
660        }
661        assert!(self.pending.is_empty(), "solution() called on an unsolved solver");
662        Ok(SolutionSet::Unique(()))
663    }
664
665    /// Finalizes the inference by inferring uninferred numeric literals as felt252.
666    /// Returns an error and does not report it.
667    pub fn finalize_without_reporting(
668        &mut self,
669    ) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
670        if self.error_status.is_err() {
671            // TODO(yuval): consider adding error location to the set error.
672            return Err((ErrorSet, None));
673        }
674        let info = self.db.core_info();
675        let numeric_trait_id = info.numeric_literal_trt;
676        let felt_ty = info.felt252;
677
678        // Conform all uninferred numeric literals to felt252.
679        loop {
680            let mut changed = false;
681            self.solve_ex()?;
682            for (var, _) in self.ambiguous.clone() {
683                let impl_var = self.impl_var(var).clone();
684                if impl_var.concrete_trait_id.trait_id(self.db) != numeric_trait_id {
685                    continue;
686                }
687                // Uninferred numeric trait. Resolve as felt252.
688                let ty = extract_matches!(
689                    impl_var.concrete_trait_id.generic_args(self.db)[0],
690                    GenericArgumentId::Type
691                );
692                if self.rewrite(ty).no_err() == felt_ty {
693                    continue;
694                }
695                self.conform_ty(ty, felt_ty).map_err(|err_set| {
696                    (err_set, self.stable_ptrs.get(&InferenceVar::Impl(impl_var.id)).copied())
697                })?;
698                changed = true;
699                break;
700            }
701            if !changed {
702                break;
703            }
704        }
705        assert!(
706            self.pending.is_empty(),
707            "pending should all be solved by this point. Guaranteed by solve()."
708        );
709
710        let Some((var, err)) = self.first_undetermined_variable() else {
711            return Ok(());
712        };
713        Err((self.set_error(err), self.stable_ptrs.get(&var).copied()))
714    }
715
716    /// Finalizes the inference and report diagnostics if there are any errors.
717    /// All the remaining type vars are mapped to the `missing` type, to prevent additional
718    /// diagnostics.
719    pub fn finalize(
720        &mut self,
721        diagnostics: &mut SemanticDiagnostics,
722        stable_ptr: SyntaxStablePtrId,
723    ) {
724        if let Err((err_set, err_stable_ptr)) = self.finalize_without_reporting() {
725            let diag = self.report_on_pending_error(
726                err_set,
727                diagnostics,
728                err_stable_ptr.unwrap_or(stable_ptr),
729            );
730
731            let ty_missing = TypeId::missing(self.db, diag);
732            for var in &self.data.type_vars {
733                self.data.type_assignment.entry(var.id).or_insert(ty_missing);
734            }
735        }
736    }
737
738    /// Retrieves the first variable that is still not inferred, or None, if everything is
739    /// inferred.
740    /// Does not set the error but return it, which is ok as this is a private helper function.
741    fn first_undetermined_variable(&mut self) -> Option<(InferenceVar, InferenceError)> {
742        if let Some(var) = self.refuted.first().copied() {
743            let impl_var = self.impl_var(var).clone();
744            let concrete_trait_id = impl_var.concrete_trait_id;
745            let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
746            return Some((
747                InferenceVar::Impl(var),
748                InferenceError::NoImplsFound(concrete_trait_id),
749            ));
750        }
751        let mut fallback_ret = None;
752        if let Some((var, ambiguity)) = self.ambiguous.first() {
753            // Note: do not rewrite `ambiguity`, since it is expressed in canonical variables.
754            let ret =
755                Some((InferenceVar::Impl(*var), InferenceError::Ambiguity(ambiguity.clone())));
756            if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
757                return ret;
758            } else {
759                fallback_ret = ret;
760            }
761        }
762        for (id, var) in self.type_vars.iter().enumerate() {
763            if self.type_assignment(LocalTypeVarId(id)).is_none() {
764                let ty = TypeLongId::Var(*var).intern(self.db);
765                return Some((InferenceVar::Type(var.id), InferenceError::TypeNotInferred(ty)));
766            }
767        }
768        for (id, var) in self.const_vars.iter().enumerate() {
769            if !self.const_assignment.contains_key(&LocalConstVarId(id)) {
770                let infernence_var = InferenceVar::Const(var.id);
771                return Some((infernence_var, InferenceError::ConstNotInferred));
772            }
773        }
774        fallback_ret
775    }
776
777    /// Assigns a value to a local impl variable id. See assign_impl().
778    fn assign_local_impl(
779        &mut self,
780        var: LocalImplVarId,
781        impl_id: ImplId,
782    ) -> InferenceResult<ImplId> {
783        let concrete_trait = impl_id
784            .concrete_trait(self.db)
785            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
786        self.conform_traits(self.impl_var(var).concrete_trait_id, concrete_trait)?;
787        if let Some(other_impl) = self.impl_assignment(var) {
788            return self.conform_impl(impl_id, other_impl);
789        }
790        if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
791        {
792            return Err(self.set_error(InferenceError::Cycle(InferenceVar::Impl(var))));
793        }
794        self.impl_assignment.insert(var, impl_id);
795        if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
796            for (trait_type_id, ty) in mappings.types {
797                let impl_ty = self
798                    .db
799                    .impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_type_id, self.db))
800                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
801                if let Err(err_set) = self.conform_ty(ty, impl_ty) {
802                    // Override the error with ImplTypeMismatch.
803                    let ty0 = self.rewrite(ty).no_err();
804                    let ty1 = self.rewrite(impl_ty).no_err();
805
806                    self.error =
807                        Some(InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 });
808                    return Err(err_set);
809                }
810            }
811            for (trait_constant, constant_id) in mappings.constants {
812                let concrete_impl_constant = self
813                    .db
814                    .impl_constant_concrete_implized_value(ImplConstantId::new(
815                        impl_id,
816                        trait_constant,
817                        self.db,
818                    ))
819                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
820                self.conform_const(constant_id, concrete_impl_constant)?;
821            }
822            for (trait_impl, inner_impl_id) in mappings.impls {
823                let concrete_impl_impl = self
824                    .db
825                    .impl_impl_concrete_implized(ImplImplId::new(impl_id, trait_impl, self.db))
826                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
827                self.conform_impl(inner_impl_id, concrete_impl_impl)?;
828            }
829        }
830        Ok(impl_id)
831    }
832
833    /// Tries to assigns value to an [ImplVarId]. Return the assigned impl, or an error.
834    fn assign_impl(&mut self, var_id: ImplVarId, impl_id: ImplId) -> InferenceResult<ImplId> {
835        let var = var_id.lookup_intern(self.db);
836        if var.inference_id != self.inference_id {
837            return Err(self.set_error(InferenceError::ImplKindMismatch {
838                impl0: ImplLongId::ImplVar(var_id).intern(self.db),
839                impl1: impl_id,
840            }));
841        }
842        self.assign_local_impl(var.id, impl_id)
843    }
844
845    /// Assigns a value to a [TypeVar]. Return the assigned type, or an error.
846    /// Assumes the variable is not already assigned.
847    fn assign_ty(&mut self, var: TypeVar, ty: TypeId) -> InferenceResult<TypeId> {
848        if var.inference_id != self.inference_id {
849            return Err(self.set_error(InferenceError::TypeKindMismatch {
850                ty0: TypeLongId::Var(var).intern(self.db),
851                ty1: ty,
852            }));
853        }
854        assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
855        let inference_var = InferenceVar::Type(var.id);
856        if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
857            return Err(self.set_error(InferenceError::Cycle(inference_var)));
858        }
859        // If assigning var to var - making sure assigning to the lower id for proper canonization.
860        if let TypeLongId::Var(other) = ty.lookup_intern(self.db) {
861            if other.inference_id == self.inference_id && other.id.0 > var.id.0 {
862                let var_ty = TypeLongId::Var(var).intern(self.db);
863                self.type_assignment.insert(other.id, var_ty);
864                return Ok(var_ty);
865            }
866        }
867        self.type_assignment.insert(var.id, ty);
868        Ok(ty)
869    }
870
871    /// Assigns a value to a [ConstVar]. Return the assigned const, or an error.
872    /// Assumes the variable is not already assigned.
873    fn assign_const(&mut self, var: ConstVar, id: ConstValueId) -> InferenceResult<ConstValueId> {
874        if var.inference_id != self.inference_id {
875            return Err(self.set_error(InferenceError::ConstKindMismatch {
876                const0: ConstValue::Var(var, TypeId::missing(self.db, skip_diagnostic()))
877                    .intern(self.db),
878                const1: id,
879            }));
880        }
881
882        self.const_assignment.insert(var.id, id);
883        Ok(id)
884    }
885
886    /// Computes the solution set for an impl variable with a recursive query.
887    fn impl_var_solution_set(
888        &mut self,
889        var: LocalImplVarId,
890    ) -> InferenceResult<SolutionSet<ImplId>> {
891        let impl_var = self.impl_var(var).clone();
892        // Update the concrete trait of the impl var.
893        let concrete_trait_id = self.rewrite(impl_var.concrete_trait_id).no_err();
894        self.impl_vars[impl_var.id.0].concrete_trait_id = concrete_trait_id;
895        let impl_var_trait_item_mappings =
896            self.impl_vars_trait_item_mappings.get(&var).cloned().unwrap_or_default();
897        let solution_set = self.trait_solution_set(
898            concrete_trait_id,
899            impl_var_trait_item_mappings,
900            impl_var.lookup_context,
901        )?;
902        Ok(match solution_set {
903            SolutionSet::None => SolutionSet::None,
904            SolutionSet::Unique((canonical_impl, canonicalizer)) => {
905                SolutionSet::Unique(canonical_impl.embed(self, &canonicalizer))
906            }
907            SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
908        })
909    }
910
911    /// Computes the solution set for a trait with a recursive query.
912    pub fn trait_solution_set(
913        &mut self,
914        concrete_trait_id: ConcreteTraitId,
915        impl_var_trait_item_mappings: ImplVarTraitItemMappings,
916        mut lookup_context: ImplLookupContext,
917    ) -> InferenceResult<SolutionSet<(CanonicalImpl, CanonicalMapping)>> {
918        let impl_var_trait_item_mappings = self.rewrite(impl_var_trait_item_mappings).no_err();
919        // TODO(spapini): This is done twice. Consider doing it only here.
920        let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
921        enrich_lookup_context(self.db, concrete_trait_id, &mut lookup_context);
922
923        // Don't try to resolve impls if the first generic param is a variable.
924        let generic_args = concrete_trait_id.generic_args(self.db);
925        match generic_args.first() {
926            Some(GenericArgumentId::Type(ty)) => {
927                if let TypeLongId::Var(_) = ty.lookup_intern(self.db) {
928                    // Don't try to infer such impls.
929                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
930                }
931            }
932            Some(GenericArgumentId::Impl(imp)) => {
933                // Don't try to infer such impls.
934                if let ImplLongId::ImplVar(_) = imp.lookup_intern(self.db) {
935                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
936                }
937            }
938            Some(GenericArgumentId::Constant(const_value)) => {
939                if let ConstValue::Var(_, _) = const_value.lookup_intern(self.db) {
940                    // Don't try to infer such impls.
941                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
942                }
943            }
944            _ => {}
945        };
946        let (canonical_trait, canonicalizer) = CanonicalTrait::canonicalize(
947            self.db,
948            self.inference_id,
949            concrete_trait_id,
950            impl_var_trait_item_mappings,
951        );
952        // impl_type_bounds order is deterimend by the generic params of the function and therefore
953        // is consistent.
954        let solution_set = match self.db.canonic_trait_solutions(
955            canonical_trait,
956            lookup_context,
957            (*self.data.impl_type_bounds).clone(),
958        ) {
959            Ok(solution_set) => solution_set,
960            Err(err) => return Err(self.set_error(err)),
961        };
962        match solution_set {
963            SolutionSet::None => Ok(SolutionSet::None),
964            SolutionSet::Unique(canonical_impl) => {
965                Ok(SolutionSet::Unique((canonical_impl, canonicalizer)))
966            }
967            SolutionSet::Ambiguous(ambiguity) => Ok(SolutionSet::Ambiguous(ambiguity)),
968        }
969    }
970
    /// Validates the given impl against the negative impls among its generic params.
    ///
    /// Returns:
    /// - `SolutionSet::Unique(canonical_impl)` if none of the negative impl traits has a solution.
    ///   Impls that are neither concrete nor generated (generic parameters, impl vars, impl impls
    ///   and self impls) are returned this way without any check.
    /// - `SolutionSet::Ambiguous(..)` if a negative impl trait has a type argument that is not yet
    ///   fully concrete (other than a closure type), so the check cannot be decided yet.
    /// - `SolutionSet::None` if some negative impl trait does have a solution (unique or
    ///   ambiguous), which rules the candidate impl out.
    ///
    /// Fails (with the error set) if the impl's substitution, its generic params, or a negative
    /// impl's concrete trait cannot be computed, or if solving a negative impl trait fails.
    fn validate_neg_impls(
        &mut self,
        lookup_context: &ImplLookupContext,
        canonical_impl: CanonicalImpl,
    ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
        /// Checks the given negative impl concrete traits in order. Returns the first `Ambiguous`
        /// or `None` outcome, or `Unique(canonical_impl)` if all of them pass.
        fn validate_no_solution_set(
            inference: &mut Inference<'_>,
            canonical_impl: CanonicalImpl,
            lookup_context: &ImplLookupContext,
            negative_impls_concrete_traits: impl Iterator<Item = Maybe<ConcreteTraitId>>,
        ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
            for concrete_trait_id in negative_impls_concrete_traits {
                let concrete_trait_id = concrete_trait_id.map_err(|diag_added| {
                    inference.set_error(InferenceError::Reported(diag_added))
                })?;
                for garg in concrete_trait_id.generic_args(inference.db) {
                    let GenericArgumentId::Type(ty) = garg else {
                        continue;
                    };
                    let ty = inference.rewrite(ty).no_err();
                    // If the negative impl has a generic argument that is not fully
                    // concrete we can't tell if we should rule out the candidate impl.
                    // For example if we have -TypeEqual<S, T> we can't tell if S and
                    // T are going to be assigned the same concrete type.
                    // We return `SolutionSet::Ambiguous` here to indicate that more
                    // information is needed.
                    // A closure can only have one type, even if it's not fully concrete, so it can
                    // be used without causing ambiguity.
                    if !matches!(ty.lookup_intern(inference.db), TypeLongId::Closure(_))
                        && !ty.is_fully_concrete(inference.db)
                    {
                        // TODO(ilya): Try to detect the ambiguity earlier in the
                        // inference process.
                        return Ok(SolutionSet::Ambiguous(
                            Ambiguity::NegativeImplWithUnresolvedGenericArgs {
                                impl_id: canonical_impl.0,
                                ty,
                            },
                        ));
                    }
                }

                if !matches!(
                    inference.trait_solution_set(
                        concrete_trait_id,
                        ImplVarTraitItemMappings::default(),
                        lookup_context.clone()
                    )?,
                    SolutionSet::None
                ) {
                    // The negative impl's trait has an impl (unique or ambiguous), so the
                    // candidate impl must be skipped.
                    return Ok(SolutionSet::None);
                }
            }

            Ok(SolutionSet::Unique(canonical_impl))
        }
        match canonical_impl.0.lookup_intern(self.db) {
            ImplLongId::Concrete(concrete_impl) => {
                let substitution = concrete_impl
                    .substitution(self.db)
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                let generic_params = self
                    .db
                    .impl_def_generic_params(concrete_impl.impl_def_id(self.db))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                // Substitute the impl's generic arguments into each negative impl param to get
                // its concrete trait. This iterator is lazy: it is consumed inside
                // `validate_no_solution_set`.
                let concrete_traits = generic_params
                    .iter()
                    .filter_map(|generic_param| {
                        try_extract_matches!(generic_param, GenericParam::NegImpl)
                    })
                    .map(|generic_param| {
                        substitution
                            .substitute(self.db, generic_param.clone())
                            .and_then(|generic_param| generic_param.concrete_trait)
                    });
                validate_no_solution_set(self, canonical_impl, lookup_context, concrete_traits)
            }
            // Generated impls carry their own generic params; their negative impl params are used
            // as-is, without substitution.
            ImplLongId::GeneratedImpl(generated_impl) => validate_no_solution_set(
                self,
                canonical_impl,
                lookup_context,
                generated_impl
                    .lookup_intern(self.db)
                    .generic_params
                    .iter()
                    .filter_map(|generic_param| {
                        try_extract_matches!(generic_param, GenericParam::NegImpl)
                    })
                    .map(|generic_param| generic_param.concrete_trait),
            ),
            // Other impl kinds are accepted without a negative impl check.
            ImplLongId::GenericParameter(_)
            | ImplLongId::ImplVar(_)
            | ImplLongId::ImplImpl(_)
            | ImplLongId::SelfImpl(_) => Ok(SolutionSet::Unique(canonical_impl)),
        }
    }
1072
1073    // Error handling methods
1074    // ======================
1075
1076    /// Sets an error in the inference state.
1077    /// Does nothing if an error is already set.
1078    /// Returns an `ErrorSet` that can be used in reporting the error.
1079    pub fn set_error(&mut self, err: InferenceError) -> ErrorSet {
1080        if self.error_status.is_err() {
1081            return ErrorSet;
1082        }
1083        self.error_status = if let InferenceError::Reported(diag_added) = err {
1084            self.consumed_error = Some(diag_added);
1085            Err(InferenceErrorStatus::Consumed)
1086        } else {
1087            self.error = Some(err);
1088            Err(InferenceErrorStatus::Pending)
1089        };
1090        ErrorSet
1091    }
1092
1093    /// Returns whether an error is set (either pending or consumed).
1094    pub fn is_error_set(&self) -> InferenceResult<()> {
1095        if self.error_status.is_err() { Err(ErrorSet) } else { Ok(()) }
1096    }
1097
1098    /// Consumes the error but doesn't report it. If there is no error, or the error is consumed,
1099    /// returns None. This should be used with caution. Always prefer to use
1100    /// (1) `report_on_pending_error` if possible, or (2) `consume_reported_error` which is safer.
1101    ///
1102    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1103    pub fn consume_error_without_reporting(&mut self, err_set: ErrorSet) -> Option<InferenceError> {
1104        self.consume_error_inner(err_set, skip_diagnostic())
1105    }
1106
1107    /// Consumes the error that is already reported. If there is no error, or the error is consumed,
1108    /// does nothing. This should be used with caution. Always prefer to use
1109    /// `report_on_pending_error` if possible.
1110    ///
1111    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1112    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1113    pub fn consume_reported_error(&mut self, err_set: ErrorSet, diag_added: DiagnosticAdded) {
1114        self.consume_error_inner(err_set, diag_added);
1115    }
1116
1117    /// Consumes the error and returns it, but doesn't report it. If there is no error, or the error
1118    /// is already consumed, returns None. This should be used with caution. Always prefer to use
1119    /// `report_on_pending_error` if possible.
1120    ///
1121    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1122    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1123    fn consume_error_inner(
1124        &mut self,
1125        _err_set: ErrorSet,
1126        diag_added: DiagnosticAdded,
1127    ) -> Option<InferenceError> {
1128        if self.error_status != Err(InferenceErrorStatus::Pending) {
1129            return None;
1130            // panic!("consume_error when there is no pending error");
1131        }
1132        self.error_status = Err(InferenceErrorStatus::Consumed);
1133        self.consumed_error = Some(diag_added);
1134        self.error.take()
1135    }
1136
1137    /// Consumes the pending error, if any, and reports it.
1138    /// Should only be called when an error is set, otherwise it panics.
1139    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1140    /// If an error was set but it's already consumed, it doesn't report it again but returns the
1141    /// stored `DiagnosticAdded`.
1142    pub fn report_on_pending_error(
1143        &mut self,
1144        _err_set: ErrorSet,
1145        diagnostics: &mut SemanticDiagnostics,
1146        stable_ptr: SyntaxStablePtrId,
1147    ) -> DiagnosticAdded {
1148        let Err(state_error) = self.error_status else {
1149            panic!("report_on_pending_error should be called only on error");
1150        };
1151        match state_error {
1152            InferenceErrorStatus::Consumed => self
1153                .consumed_error
1154                .expect("consumed_error is not set although error_status is Err(Consumed)"),
1155            InferenceErrorStatus::Pending => {
1156                let diag_added = match mem::take(&mut self.error)
1157                    .expect("error is not set although error_status is Err(Pending)")
1158                {
1159                    InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
1160                        // If we have other diagnostics, there is no need to TypeNotInferred.
1161
1162                        // Note that `diagnostics` is not empty, so it is safe to return
1163                        // 'DiagnosticAdded' here.
1164                        skip_diagnostic()
1165                    }
1166                    diag => diag.report(diagnostics, stable_ptr),
1167                };
1168
1169                self.error_status = Err(InferenceErrorStatus::Consumed);
1170                self.consumed_error = Some(diag_added);
1171                diag_added
1172            }
1173        }
1174    }
1175
1176    /// If the current status is of a pending error, reports an alternative diagnostic, by calling
1177    /// `report`, and consumes the error. Otherwise, does nothing.
1178    pub fn report_modified_if_pending(
1179        &mut self,
1180        err_set: ErrorSet,
1181        report: impl FnOnce() -> DiagnosticAdded,
1182    ) {
1183        if self.error_status == Err(InferenceErrorStatus::Pending) {
1184            self.consume_reported_error(err_set, report());
1185        }
1186    }
1187}
1188
/// Exposes the inference's semantic database through the `HasDb` trait.
impl<'a> HasDb<&'a dyn SemanticGroup> for Inference<'a> {
    fn get_db(&self) -> &'a dyn SemanticGroup {
        self.db
    }
}
// Macro-generated `SemanticRewriter` impls for `Inference`. `@exclude` presumably opts the listed
// types out of the generated impls; each of them has a hand-written `SemanticRewriter` impl below.
add_basic_rewrites!(<'a>, Inference<'a>, NoError, @exclude TypeLongId TypeId ImplLongId ImplId ConstValue);
add_expr_rewrites!(<'a>, Inference<'a>, NoError, @exclude);
add_rewrite!(<'a>, Inference<'a>, NoError, Ambiguity);
1197impl SemanticRewriter<TypeId, NoError> for Inference<'_> {
1198    fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
1199        if value.is_var_free(self.db) {
1200            return Ok(RewriteResult::NoChange);
1201        }
1202        value.default_rewrite(self)
1203    }
1204}
1205impl SemanticRewriter<ImplId, NoError> for Inference<'_> {
1206    fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
1207        if value.is_var_free(self.db) {
1208            return Ok(RewriteResult::NoChange);
1209        }
1210        value.default_rewrite(self)
1211    }
1212}
1213impl SemanticRewriter<TypeLongId, NoError> for Inference<'_> {
1214    fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
1215        match value {
1216            TypeLongId::Var(var) => {
1217                if let Some(type_id) = self.type_assignment.get(&var.id) {
1218                    let mut long_type_id = type_id.lookup_intern(self.db);
1219                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_type_id)? {
1220                        *self.type_assignment.get_mut(&var.id).unwrap() =
1221                            long_type_id.clone().intern(self.db);
1222                    }
1223                    *value = long_type_id;
1224                    return Ok(RewriteResult::Modified);
1225                }
1226            }
1227            TypeLongId::ImplType(impl_type_id) => {
1228                if let Some(type_id) = self.impl_type_bounds.get(&((*impl_type_id).into())) {
1229                    *value = type_id.lookup_intern(self.db);
1230                    self.internal_rewrite(value)?;
1231                    return Ok(RewriteResult::Modified);
1232                }
1233                let impl_type_id_rewrite_result = self.internal_rewrite(impl_type_id)?;
1234                let impl_id = impl_type_id.impl_id();
1235                let trait_ty = impl_type_id.ty();
1236                return Ok(match impl_id.lookup_intern(self.db) {
1237                    ImplLongId::GenericParameter(_)
1238                    | ImplLongId::SelfImpl(_)
1239                    | ImplLongId::ImplImpl(_) => impl_type_id_rewrite_result,
1240                    ImplLongId::Concrete(_) => {
1241                        if let Ok(ty) = self.db.impl_type_concrete_implized(ImplTypeId::new(
1242                            impl_id, trait_ty, self.db,
1243                        )) {
1244                            *value = self.rewrite(ty).no_err().lookup_intern(self.db);
1245                            RewriteResult::Modified
1246                        } else {
1247                            impl_type_id_rewrite_result
1248                        }
1249                    }
1250                    ImplLongId::ImplVar(var) => {
1251                        *value = self.rewritten_impl_type(var, trait_ty).lookup_intern(self.db);
1252                        return Ok(RewriteResult::Modified);
1253                    }
1254                    ImplLongId::GeneratedImpl(generated) => {
1255                        *value = self
1256                            .rewrite(
1257                                *generated
1258                                    .lookup_intern(self.db)
1259                                    .impl_items
1260                                    .0
1261                                    .get(&impl_type_id.ty())
1262                                    .unwrap(),
1263                            )
1264                            .no_err()
1265                            .lookup_intern(self.db);
1266                        RewriteResult::Modified
1267                    }
1268                });
1269            }
1270            _ => {}
1271        }
1272        value.default_rewrite(self)
1273    }
1274}
1275impl SemanticRewriter<ConstValue, NoError> for Inference<'_> {
1276    fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
1277        match value {
1278            ConstValue::Var(var, _) => {
1279                return Ok(if let Some(const_value_id) = self.const_assignment.get(&var.id) {
1280                    let mut const_value = const_value_id.lookup_intern(self.db);
1281                    if let RewriteResult::Modified = self.internal_rewrite(&mut const_value)? {
1282                        *self.const_assignment.get_mut(&var.id).unwrap() =
1283                            const_value.clone().intern(self.db);
1284                    }
1285                    *value = const_value;
1286                    RewriteResult::Modified
1287                } else {
1288                    RewriteResult::NoChange
1289                });
1290            }
1291            ConstValue::ImplConstant(impl_constant_id) => {
1292                let impl_constant_id_rewrite_result = self.internal_rewrite(impl_constant_id)?;
1293                let impl_id = impl_constant_id.impl_id();
1294                let trait_constant = impl_constant_id.trait_constant_id();
1295                return Ok(match impl_id.lookup_intern(self.db) {
1296                    ImplLongId::GenericParameter(_)
1297                    | ImplLongId::SelfImpl(_)
1298                    | ImplLongId::GeneratedImpl(_)
1299                    | ImplLongId::ImplImpl(_) => impl_constant_id_rewrite_result,
1300                    ImplLongId::Concrete(_) => {
1301                        if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
1302                            ImplConstantId::new(impl_id, trait_constant, self.db),
1303                        ) {
1304                            *value = self.rewrite(constant).no_err().lookup_intern(self.db);
1305                            RewriteResult::Modified
1306                        } else {
1307                            impl_constant_id_rewrite_result
1308                        }
1309                    }
1310                    ImplLongId::ImplVar(var) => {
1311                        *value = self
1312                            .rewritten_impl_constant(var, trait_constant)
1313                            .lookup_intern(self.db);
1314                        return Ok(RewriteResult::Modified);
1315                    }
1316                });
1317            }
1318            _ => {}
1319        }
1320        value.default_rewrite(self)
1321    }
1322}
1323impl SemanticRewriter<ImplLongId, NoError> for Inference<'_> {
1324    fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
1325        match value {
1326            ImplLongId::ImplVar(var) => {
1327                let long_id = var.lookup_intern(self.db);
1328                // Relax the candidates.
1329                let impl_var_id = long_id.id;
1330                if let Some(impl_id) = self.impl_assignment(impl_var_id) {
1331                    let mut long_impl_id = impl_id.lookup_intern(self.db);
1332                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_impl_id)? {
1333                        *self.impl_assignment.get_mut(&impl_var_id).unwrap() =
1334                            long_impl_id.clone().intern(self.db);
1335                    }
1336                    *value = long_impl_id;
1337                    return Ok(RewriteResult::Modified);
1338                }
1339            }
1340            ImplLongId::ImplImpl(impl_impl_id) => {
1341                let impl_impl_id_rewrite_result = self.internal_rewrite(impl_impl_id)?;
1342                let impl_id = impl_impl_id.impl_id();
1343                return Ok(match impl_id.lookup_intern(self.db) {
1344                    ImplLongId::GenericParameter(_)
1345                    | ImplLongId::SelfImpl(_)
1346                    | ImplLongId::GeneratedImpl(_)
1347                    | ImplLongId::ImplImpl(_) => impl_impl_id_rewrite_result,
1348                    ImplLongId::Concrete(_) => {
1349                        if let Ok(imp) = self.db.impl_impl_concrete_implized(*impl_impl_id) {
1350                            *value = self.rewrite(imp).no_err().lookup_intern(self.db);
1351                            RewriteResult::Modified
1352                        } else {
1353                            impl_impl_id_rewrite_result
1354                        }
1355                    }
1356                    ImplLongId::ImplVar(var) => {
1357                        if let Ok(concrete_trait_impl) =
1358                            impl_impl_id.concrete_trait_impl_id(self.db)
1359                        {
1360                            *value = self
1361                                .rewritten_impl_impl(var, concrete_trait_impl)
1362                                .lookup_intern(self.db);
1363                            return Ok(RewriteResult::Modified);
1364                        } else {
1365                            impl_impl_id_rewrite_result
1366                        }
1367                    }
1368                });
1369            }
1370
1371            _ => {}
1372        }
1373        if value.is_var_free(self.db) {
1374            return Ok(RewriteResult::NoChange);
1375        }
1376        value.default_rewrite(self)
1377    }
1378}
1379
/// A semantic rewriter that replaces every occurrence of one [InferenceId] with another.
struct InferenceIdReplacer<'a> {
    db: &'a dyn SemanticGroup,
    /// The inference id to replace.
    from_inference_id: InferenceId,
    /// The inference id written in place of `from_inference_id`.
    to_inference_id: InferenceId,
}
impl<'a> InferenceIdReplacer<'a> {
    /// Creates a replacer that rewrites `from_inference_id` into `to_inference_id`.
    fn new(
        db: &'a dyn SemanticGroup,
        from_inference_id: InferenceId,
        to_inference_id: InferenceId,
    ) -> Self {
        Self { db, from_inference_id, to_inference_id }
    }
}
/// Exposes the replacer's semantic database through the `HasDb` trait.
impl<'a> HasDb<&'a dyn SemanticGroup> for InferenceIdReplacer<'a> {
    fn get_db(&self) -> &'a dyn SemanticGroup {
        self.db
    }
}
// Macro-generated `SemanticRewriter` impls for `InferenceIdReplacer`. `InferenceId` is listed
// after `@exclude`; it has the hand-written `SemanticRewriter` impl below that swaps the ids.
add_basic_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude InferenceId);
add_expr_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude);
add_rewrite!(<'a>, InferenceIdReplacer<'a>, NoError, Ambiguity);
1402impl SemanticRewriter<InferenceId, NoError> for InferenceIdReplacer<'_> {
1403    fn internal_rewrite(&mut self, value: &mut InferenceId) -> Result<RewriteResult, NoError> {
1404        if value == &self.from_inference_id {
1405            *value = self.to_inference_id;
1406            Ok(RewriteResult::Modified)
1407        } else {
1408            Ok(RewriteResult::NoChange)
1409        }
1410    }
1411}