Skip to main content

cairo_lang_semantic/expr/
inference.rs

1//! Bidirectional type inference.
2
3#![expect(clippy::disallowed_types)]
4
5use std::collections::{BTreeMap, HashMap};
6use std::hash::Hash;
7use std::ops::{Deref, DerefMut};
8use std::sync::Arc;
9
10use cairo_lang_debug::DebugWithDb;
11use cairo_lang_defs::ids::{
12    ConstantId, EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericKind,
13    GenericParamId, GlobalUseId, ImplAliasId, ImplDefId, ImplFunctionId, ImplImplDefId, LocalVarId,
14    LookupItemId, MacroCallId, MemberId, NamedLanguageElementId, ParamId, StructId,
15    TraitConstantId, TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
16};
17use cairo_lang_diagnostics::{DiagnosticAdded, skip_diagnostic};
18use cairo_lang_proc_macros::{DebugWithDb, HeapSize, SemanticObject};
19use cairo_lang_syntax::node::TypedStablePtr;
20use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
21use cairo_lang_utils::deque::Deque;
22use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
23use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
24use cairo_lang_utils::{Intern, define_short_id, extract_matches};
25use salsa::Database;
26
27use self::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, NoError};
28use self::solver::{Ambiguity, SolutionSet, enrich_lookup_context};
29use crate::corelib::{CorelibSemantic, is_valid_literal_type};
30use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
31use crate::expr::inference::canonic::ResultNoErrEx;
32use crate::expr::inference::conform::InferenceConform;
33use crate::expr::inference::solver::SemanticSolver;
34use crate::expr::objects::*;
35use crate::expr::pattern::*;
36use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
37use crate::items::functions::{
38    ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
39    GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
40    ImplGenericFunctionWithBodyId,
41};
42use crate::items::generics::{
43    GenericParamConst, GenericParamImpl, GenericParamSemantic, GenericParamType,
44};
45use crate::items::imp::{
46    GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
47    ImplLookupContext, ImplLookupContextId, ImplSemantic, NegativeImplId, NegativeImplLongId,
48    UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
49};
50use crate::items::trt::{
51    ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId, ConcreteTraitTypeId,
52    ConcreteTraitTypeLongId,
53};
54use crate::substitution::{GenericSubstitution, HasDb, RewriteResult, SemanticRewriter};
55use crate::types::{
56    ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
57    ImplTypeById, ImplTypeId, get_impl_at_context,
58};
59use crate::{
60    ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
61    ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
62    FunctionId, FunctionLongId, GenericArgumentId, GenericParam, LocalVariable, MatchArmSelector,
63    Member, Parameter, SemanticObject, Signature, TypeId, TypeLongId, ValueSelectorArm,
64    add_basic_rewrites, add_expr_rewrites, add_rewrite, semantic_object_for_id,
65};
66
67pub mod canonic;
68pub mod conform;
69pub mod infers;
70pub mod solver;
71
/// A type variable, created when a generic type argument is not passed, and thus is not known
/// yet and needs to be inferred.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, HeapSize, salsa::Update)]
pub struct TypeVar<'db> {
    /// The inference context that owns this variable.
    pub inference_id: InferenceId<'db>,
    /// The index of this variable in the owning context's `type_vars`.
    pub id: LocalTypeVarId,
}
79
/// A const variable, created when a generic const argument is not passed, and thus is not known
/// yet and needs to be inferred.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, HeapSize, salsa::Update)]
pub struct ConstVar<'db> {
    /// The inference context that owns this variable.
    pub inference_id: InferenceId<'db>,
    /// The index of this variable in the owning context's `const_vars`.
    pub id: LocalConstVarId,
}
87
/// An id for an inference context. Each inference variable is associated with an inference id.
///
/// Inference state can be moved to a different id with
/// `InferenceData::clone_with_inference_id`, which rewrites the ids embedded in its variables.
#[derive(
    Copy, Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject, HeapSize, salsa::Update,
)]
#[debug_db(dyn Database)]
pub enum InferenceId<'db> {
    LookupItemDeclaration(LookupItemId<'db>),
    LookupItemGenerics(LookupItemId<'db>),
    LookupItemDefinition(LookupItemId<'db>),
    ImplDefTrait(ImplDefId<'db>),
    ImplAliasImplDef(ImplAliasId<'db>),
    GenericParam(GenericParamId<'db>),
    GenericImplParamTrait(GenericParamId<'db>),
    GlobalUseStar(GlobalUseId<'db>),
    MacroCall(MacroCallId<'db>),
    /// NOTE(review): presumably the context used for canonicalized objects (see the `canonic`
    /// module) — TODO confirm.
    Canonical,
    /// For resolving that will not be used anywhere in the semantic model.
    NoContext,
}
107
/// An impl variable, created when an impl (e.g. a generic impl argument) is not known yet and
/// needs to be inferred.
#[derive(
    Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject, HeapSize, salsa::Update,
)]
#[debug_db(dyn Database)]
pub struct ImplVar<'db> {
    /// The inference context that owns this variable.
    pub inference_id: InferenceId<'db>,
    /// The index of this variable in the owning context's `impl_vars`.
    #[dont_rewrite]
    pub id: LocalImplVarId,
    /// The concrete trait that the inferred impl is an impl of.
    pub concrete_trait_id: ConcreteTraitId<'db>,
    /// The lookup context the variable was created with (presumably where candidate impls are
    /// searched — TODO confirm).
    #[dont_rewrite]
    pub lookup_context: ImplLookupContextId<'db>,
}
122impl<'db> ImplVar<'db> {
123    pub fn intern(&self, db: &'db dyn Database) -> ImplVarId<'db> {
124        self.clone().intern(db)
125    }
126}
127
/// A negative impl variable, created for a negative impl that is not known yet and needs to be
/// inferred.
#[derive(
    Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject, HeapSize, salsa::Update,
)]
#[debug_db(dyn Database)]
pub struct NegativeImplVar<'db> {
    /// The inference context that owns this variable.
    pub inference_id: InferenceId<'db>,
    /// The index of this variable in the owning context's `negative_impl_vars`.
    #[dont_rewrite]
    pub id: LocalNegativeImplVarId,
    /// The concrete trait this negative impl refers to.
    pub concrete_trait_id: ConcreteTraitId<'db>,
    /// The lookup context the variable was created with.
    #[dont_rewrite]
    pub lookup_context: ImplLookupContextId<'db>,
}
141
/// The index of a type variable in its context's [InferenceData::type_vars].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
pub struct LocalTypeVarId(pub usize);
/// The index of an impl variable in its context's [InferenceData::impl_vars].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
pub struct LocalImplVarId(pub usize);

/// The index of a negative impl variable in its context's [InferenceData::negative_impl_vars].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
pub struct LocalNegativeImplVarId(pub usize);

/// The index of a const variable in its context's [InferenceData::const_vars].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
pub struct LocalConstVarId(pub usize);
152
153define_short_id!(ImplVarId, ImplVar<'db>);
154impl<'db> ImplVarId<'db> {
155    pub fn id(&self, db: &dyn Database) -> LocalImplVarId {
156        self.long(db).id
157    }
158    pub fn concrete_trait_id(&self, db: &'db dyn Database) -> ConcreteTraitId<'db> {
159        self.long(db).concrete_trait_id
160    }
161    pub fn lookup_context(&self, db: &'db dyn Database) -> ImplLookupContextId<'db> {
162        self.long(db).lookup_context
163    }
164}
165semantic_object_for_id!(ImplVarId, ImplVar<'a>);
166
167define_short_id!(NegativeImplVarId, NegativeImplVar<'db>);
168impl<'db> NegativeImplVarId<'db> {
169    pub fn id(&self, db: &dyn Database) -> LocalNegativeImplVarId {
170        self.long(db).id
171    }
172    pub fn concrete_trait_id(&self, db: &'db dyn Database) -> ConcreteTraitId<'db> {
173        self.long(db).concrete_trait_id
174    }
175    pub fn lookup_context(&self, db: &'db dyn Database) -> ImplLookupContextId<'db> {
176        self.long(db).lookup_context
177    }
178}
179semantic_object_for_id!(NegativeImplVarId, NegativeImplVar<'a>);
180
/// A reference to an inference variable of any kind, by its local id inside an inference
/// context. Used, e.g., as the key of [InferenceData::stable_ptrs] and in
/// [InferenceError::Cycle].
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, SemanticObject, salsa::Update)]
pub enum InferenceVar {
    Type(LocalTypeVarId),
    Const(LocalConstVarId),
    Impl(LocalImplVarId),
    NegativeImpl(LocalNegativeImplVarId),
}
188
// TODO(spapini): Add to diagnostics.
/// An error produced during inference. `InferenceError::format` renders it as a message, and
/// `InferenceError::report` emits it as a diagnostic.
#[derive(Clone, Debug, Eq, Hash, PartialEq, DebugWithDb, salsa::Update)]
#[debug_db(dyn Database)]
pub enum InferenceError<'db> {
    /// An inference error wrapping a previously reported error.
    Reported(DiagnosticAdded),
    /// An inference cycle was detected involving the given variable.
    Cycle(InferenceVar),
    /// Two types that had to agree do not.
    TypeKindMismatch {
        ty0: TypeId<'db>,
        ty1: TypeId<'db>,
    },
    /// Two const values that had to agree do not.
    ConstKindMismatch {
        const0: ConstValueId<'db>,
        const1: ConstValueId<'db>,
    },
    /// Two impls that had to agree do not.
    ImplKindMismatch {
        impl0: ImplId<'db>,
        impl1: ImplId<'db>,
    },
    /// Two negative impls that had to agree do not.
    NegativeImplKindMismatch {
        impl0: NegativeImplId<'db>,
        impl1: NegativeImplId<'db>,
    },
    /// Two generic arguments that had to agree do not.
    GenericArgMismatch {
        garg0: GenericArgumentId<'db>,
        garg1: GenericArgumentId<'db>,
    },
    /// Two traits that had to agree do not.
    TraitMismatch {
        trt0: TraitId<'db>,
        trt1: TraitId<'db>,
    },
    /// The type bound to `trait_type_id` in `impl_id` disagrees between `ty0` and `ty1`.
    ImplTypeMismatch {
        impl_id: ImplId<'db>,
        trait_type_id: TraitTypeId<'db>,
        ty0: TypeId<'db>,
        ty1: TypeId<'db>,
    },
    /// Two generic functions that had to agree do not.
    GenericFunctionMismatch {
        func0: GenericFunctionId<'db>,
        func1: GenericFunctionId<'db>,
    },
    /// A constant could not be inferred.
    ConstNotInferred,
    // TODO(spapini): These are only used for external interface. Separate them along with the
    // finalize() function to a wrapper.
    /// No impl of the concrete trait was found in the lookup context.
    NoImplsFound(ConcreteTraitId<'db>),
    /// An impl of the concrete trait was found where a negative impl was required.
    NoNegativeImplsFound(ConcreteTraitId<'db>),
    /// Inference could not settle on a single solution.
    Ambiguity(Ambiguity<'db>),
    /// A type could not be inferred.
    TypeNotInferred(TypeId<'db>),
    /// A numeric literal was bound to a type that does not accept numeric literals.
    InvalidLiteralType(TypeId<'db>),
}
240impl<'db> InferenceError<'db> {
241    pub fn format(&self, db: &dyn Database) -> String {
242        match self {
243            InferenceError::Reported(_) => "Inference error occurred.".into(),
244            InferenceError::Cycle(_var) => "Inference cycle detected".into(),
245            InferenceError::TypeKindMismatch { ty0, ty1 } => {
246                format!("Type mismatch: `{:?}` and `{:?}`.", ty0.debug(db), ty1.debug(db))
247            }
248            InferenceError::ConstKindMismatch { const0, const1 } => {
249                format!("Const mismatch: `{:?}` and `{:?}`.", const0.debug(db), const1.debug(db))
250            }
251            InferenceError::ImplKindMismatch { impl0, impl1 } => {
252                format!("Impl mismatch: `{:?}` and `{:?}`.", impl0.debug(db), impl1.debug(db))
253            }
254            InferenceError::NegativeImplKindMismatch { impl0, impl1 } => {
255                format!(
256                    "Negative impl mismatch: `{:?}` and `{:?}`.",
257                    impl0.debug(db),
258                    impl1.debug(db)
259                )
260            }
261            InferenceError::GenericArgMismatch { garg0, garg1 } => {
262                format!(
263                    "Generic arg mismatch: `{:?}` and `{:?}`.",
264                    garg0.debug(db),
265                    garg1.debug(db)
266                )
267            }
268            InferenceError::TraitMismatch { trt0, trt1 } => {
269                format!("Trait mismatch: `{:?}` and `{:?}`.", trt0.debug(db), trt1.debug(db))
270            }
271            InferenceError::ConstNotInferred => "Failed to infer constant.".into(),
272            InferenceError::NoImplsFound(concrete_trait_id) => {
273                if concrete_trait_id.trait_id(db) == db.core_info().string_literal_trt {
274                    let generic_type = extract_matches!(
275                        concrete_trait_id.generic_args(db)[0],
276                        GenericArgumentId::Type
277                    );
278                    return format!(
279                        "Mismatched types. The type `{:?}` cannot be created from a string \
280                         literal.",
281                        generic_type.debug(db)
282                    );
283                }
284                format!(
285                    "Trait has no implementation in context: {:?}.",
286                    concrete_trait_id.debug(db)
287                )
288            }
289            InferenceError::NoNegativeImplsFound(concrete_trait_id) => {
290                format!("Trait has implementation in context: {:?}.", concrete_trait_id.debug(db))
291            }
292            InferenceError::Ambiguity(ambiguity) => ambiguity.format(db),
293            InferenceError::TypeNotInferred(ty) => {
294                format!("Type annotations needed. Failed to infer {:?}.", ty.debug(db))
295            }
296            InferenceError::GenericFunctionMismatch { func0, func1 } => {
297                format!("Function mismatch: `{}` and `{}`.", func0.format(db), func1.format(db))
298            }
299            InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 } => {
300                format!(
301                    "`{}::{}` type mismatch: `{:?}` and `{:?}`.",
302                    impl_id.format(db),
303                    trait_type_id.name(db).long(db),
304                    ty0.debug(db),
305                    ty1.debug(db)
306                )
307            }
308            InferenceError::InvalidLiteralType(ty) => format!(
309                "Mismatched types. The type `{:?}` cannot be created from a numeric literal.",
310                ty.debug(db),
311            ),
312        }
313    }
314}
315
316impl<'db> InferenceError<'db> {
317    pub fn report(
318        &self,
319        diagnostics: &mut SemanticDiagnostics<'db>,
320        stable_ptr: SyntaxStablePtrId<'db>,
321    ) -> DiagnosticAdded {
322        match self {
323            InferenceError::Reported(diagnostic_added) => *diagnostic_added,
324            _ => diagnostics
325                .report(stable_ptr, SemanticDiagnosticKind::InternalInferenceError(self.clone())),
326        }
327    }
328}
329
/// This struct is used to ensure that when an inference error occurs, it is properly set in the
/// `Inference` object, and then properly consumed.
///
/// It must not be constructed directly. Instead, it is returned by [Inference::set_error].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ErrorSet;

/// The result of a fallible inference operation. `Err(ErrorSet)` only signals that an error was
/// set; the error itself is kept in the `Inference` object.
pub type InferenceResult<T> = Result<T, ErrorSet>;
338
/// The status of an inference error that has been set. Stored as the `Err` side of
/// `InferenceData::error_status`, whose `Ok(())` means no error was set.
#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
enum InferenceErrorStatus<'db> {
    /// There is a pending error.
    Pending(PendingInferenceError<'db>),
    /// There was an error but it was already consumed.
    Consumed(DiagnosticAdded),
}

/// A pending inference error.
#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
struct PendingInferenceError<'db> {
    /// The actual error.
    err: InferenceError<'db>,
    /// The optional location of the error.
    stable_ptr: Option<SyntaxStablePtrId<'db>>,
}
355
/// A mapping of an impl var's trait items to concrete items.
///
/// Kept per impl var in `InferenceData::impl_vars_trait_item_mappings`.
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, SemanticObject, salsa::Update)]
pub struct ImplVarTraitItemMappings<'db> {
    /// The trait types of the impl var.
    types: OrderedHashMap<TraitTypeId<'db>, TypeId<'db>>,
    /// The trait constants of the impl var.
    constants: OrderedHashMap<TraitConstantId<'db>, ConstValueId<'db>>,
    /// The trait impls of the impl var.
    impls: OrderedHashMap<TraitImplId<'db>, ImplId<'db>>,
}
366
367impl ImplVarTraitItemMappings<'_> {
368    /// Returns `true` if the impl var has no mappings.
369    pub fn is_empty(&self) -> bool {
370        self.types.is_empty() && self.constants.is_empty() && self.impls.is_empty()
371    }
372}
373
/// State of inference.
///
/// Holds all variables of one inference context, their current assignments and the solver's
/// bookkeeping queues. Operated on through an [Inference] view.
#[derive(Debug, DebugWithDb, PartialEq, Eq, salsa::Update)]
#[debug_db(dyn Database)]
pub struct InferenceData<'db> {
    /// The inference context all variables below belong to.
    pub inference_id: InferenceId<'db>,
    /// Current inferred assignment for type variables.
    pub type_assignment: OrderedHashMap<LocalTypeVarId, TypeId<'db>>,
    /// Current inferred assignment for const variables.
    pub const_assignment: OrderedHashMap<LocalConstVarId, ConstValueId<'db>>,
    /// Current inferred assignment for impl variables.
    pub impl_assignment: OrderedHashMap<LocalImplVarId, ImplId<'db>>,
    /// Current inferred assignment for negative impl variables.
    pub negative_impl_assignment: OrderedHashMap<LocalNegativeImplVarId, NegativeImplId<'db>>,
    /// For unsolved impl variables: the items currently associated with each of the variable's
    /// trait items. Once the variable is solved, the fully known items are conformed to these.
    pub impl_vars_trait_item_mappings: HashMap<LocalImplVarId, ImplVarTraitItemMappings<'db>>,
    /// Type variables, indexed by [LocalTypeVarId].
    pub type_vars: Vec<TypeVar<'db>>,
    /// Type variables that represent deferred concrete integer types from numeric literals
    /// (`TypeLongId::NumericLiteral`). At finalization, any of these that remain unassigned are
    /// defaulted to `felt252`.
    pub integer_literal_vars: Vec<LocalTypeVarId>,
    /// Const variables, indexed by [LocalConstVarId].
    pub const_vars: Vec<ConstVar<'db>>,
    /// Impl variables, indexed by [LocalImplVarId].
    pub impl_vars: Vec<ImplVar<'db>>,
    /// Negative impl variables, indexed by [LocalNegativeImplVarId].
    pub negative_impl_vars: Vec<NegativeImplVar<'db>>,
    /// Mapping from variables to stable pointers, if exist.
    pub stable_ptrs: HashMap<InferenceVar, SyntaxStablePtrId<'db>>,
    /// Inference variables that are pending to be solved. New impl vars are queued here.
    pending: Deque<LocalImplVarId>,
    /// Inference negative variables that are pending to be solved. New negative impl vars are
    /// queued here.
    negative_pending: Deque<LocalNegativeImplVarId>,
    /// Inference variables that have been refuted - no solutions exist.
    refuted: Vec<LocalImplVarId>,
    /// Inference negative variables that have been refuted - no solutions exist.
    negative_refuted: Vec<LocalNegativeImplVarId>,
    /// Inference variables that have been solved.
    solved: Vec<LocalImplVarId>,
    /// Inference variables that are currently ambiguous. May be solved later.
    ambiguous: Vec<(LocalImplVarId, Ambiguity<'db>)>,
    /// Negative impl inference variables that are currently ambiguous. May be solved later.
    negative_ambiguous: Vec<(LocalNegativeImplVarId, Ambiguity<'db>)>,
    /// Mapping from impl types to the types they are bound to. Only var-free types are stored
    /// here (see `Inference::set_impl_type_bounds`).
    pub impl_type_bounds: Arc<BTreeMap<ImplTypeById<'db>, TypeId<'db>>>,

    /// The current error status; `Ok(())` when no error has been set.
    error_status: Result<(), InferenceErrorStatus<'db>>,
}
424impl<'db> InferenceData<'db> {
425    pub fn new(inference_id: InferenceId<'db>) -> Self {
426        Self {
427            inference_id,
428            type_assignment: OrderedHashMap::default(),
429            impl_assignment: OrderedHashMap::default(),
430            const_assignment: OrderedHashMap::default(),
431            negative_impl_assignment: OrderedHashMap::default(),
432            impl_vars_trait_item_mappings: HashMap::new(),
433            type_vars: Vec::new(),
434            integer_literal_vars: Vec::new(),
435            impl_vars: Vec::new(),
436            const_vars: Vec::new(),
437            negative_impl_vars: Vec::new(),
438            stable_ptrs: HashMap::new(),
439            pending: Deque::new(),
440            negative_pending: Deque::new(),
441            refuted: Vec::new(),
442            negative_refuted: Vec::new(),
443            solved: Vec::new(),
444            ambiguous: Vec::new(),
445            negative_ambiguous: Vec::new(),
446            impl_type_bounds: Default::default(),
447            error_status: Ok(()),
448        }
449    }
450    pub fn inference<'r>(&'r mut self, db: &'db dyn Database) -> Inference<'db, 'r> {
451        Inference::new(db, self)
452    }
453    pub fn clone_with_inference_id(
454        &self,
455        db: &'db dyn Database,
456        inference_id: InferenceId<'db>,
457    ) -> InferenceData<'db> {
458        let mut inference_id_replacer =
459            InferenceIdReplacer::new(db, self.inference_id, inference_id);
460        Self {
461            inference_id,
462            type_assignment: self
463                .type_assignment
464                .iter()
465                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
466                .collect(),
467            const_assignment: self
468                .const_assignment
469                .iter()
470                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
471                .collect(),
472            impl_assignment: self
473                .impl_assignment
474                .iter()
475                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
476                .collect(),
477            negative_impl_assignment: self
478                .negative_impl_assignment
479                .iter()
480                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
481                .collect(),
482            impl_vars_trait_item_mappings: self
483                .impl_vars_trait_item_mappings
484                .iter()
485                .map(|(k, mappings)| {
486                    (
487                        *k,
488                        ImplVarTraitItemMappings {
489                            types: mappings
490                                .types
491                                .iter()
492                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
493                                .collect(),
494                            constants: mappings
495                                .constants
496                                .iter()
497                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
498                                .collect(),
499                            impls: mappings
500                                .impls
501                                .iter()
502                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
503                                .collect(),
504                        },
505                    )
506                })
507                .collect(),
508            type_vars: inference_id_replacer.rewrite(self.type_vars.clone()).no_err(),
509            integer_literal_vars: self.integer_literal_vars.clone(),
510            const_vars: inference_id_replacer.rewrite(self.const_vars.clone()).no_err(),
511            impl_vars: inference_id_replacer.rewrite(self.impl_vars.clone()).no_err(),
512            negative_impl_vars: inference_id_replacer
513                .rewrite(self.negative_impl_vars.clone())
514                .no_err(),
515            stable_ptrs: self.stable_ptrs.clone(),
516            pending: inference_id_replacer.rewrite(self.pending.clone()).no_err(),
517            negative_pending: inference_id_replacer.rewrite(self.negative_pending.clone()).no_err(),
518            refuted: inference_id_replacer.rewrite(self.refuted.clone()).no_err(),
519            negative_refuted: inference_id_replacer.rewrite(self.negative_refuted.clone()).no_err(),
520            solved: inference_id_replacer.rewrite(self.solved.clone()).no_err(),
521            ambiguous: inference_id_replacer.rewrite(self.ambiguous.clone()).no_err(),
522            negative_ambiguous: inference_id_replacer
523                .rewrite(self.negative_ambiguous.clone())
524                .no_err(),
525            // we do not need to rewrite the impl type bounds, as they all should be var free.
526            impl_type_bounds: self.impl_type_bounds.clone(),
527
528            error_status: self.error_status.clone(),
529        }
530    }
531    pub fn temporary_clone(&self) -> InferenceData<'db> {
532        Self {
533            inference_id: self.inference_id,
534            type_assignment: self.type_assignment.clone(),
535            const_assignment: self.const_assignment.clone(),
536            impl_assignment: self.impl_assignment.clone(),
537            negative_impl_assignment: self.negative_impl_assignment.clone(),
538            impl_vars_trait_item_mappings: self.impl_vars_trait_item_mappings.clone(),
539            type_vars: self.type_vars.clone(),
540            integer_literal_vars: self.integer_literal_vars.clone(),
541            const_vars: self.const_vars.clone(),
542            impl_vars: self.impl_vars.clone(),
543            negative_impl_vars: self.negative_impl_vars.clone(),
544            stable_ptrs: self.stable_ptrs.clone(),
545            pending: self.pending.clone(),
546            negative_pending: self.negative_pending.clone(),
547            refuted: self.refuted.clone(),
548            negative_refuted: self.negative_refuted.clone(),
549            solved: self.solved.clone(),
550            ambiguous: self.ambiguous.clone(),
551            negative_ambiguous: self.negative_ambiguous.clone(),
552            impl_type_bounds: self.impl_type_bounds.clone(),
553            error_status: self.error_status.clone(),
554        }
555    }
556}
557
/// State of inference. A system of inference constraints.
///
/// Pairs the database with a mutable borrow of the [InferenceData] it operates on, and
/// dereferences to that data.
pub struct Inference<'db, 'id> {
    /// The database, used e.g. to intern newly created ids.
    db: &'db dyn Database,
    /// The inference state this view reads and mutates.
    pub data: &'id mut InferenceData<'db>,
}
563
564impl<'db, 'id> Deref for Inference<'db, 'id> {
565    type Target = InferenceData<'db>;
566
567    fn deref(&self) -> &Self::Target {
568        self.data
569    }
570}
571impl DerefMut for Inference<'_, '_> {
572    fn deref_mut(&mut self) -> &mut Self::Target {
573        self.data
574    }
575}
576
577impl std::fmt::Debug for Inference<'_, '_> {
578    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
579        let x = self.data.debug(self.db);
580        write!(f, "{x:?}")
581    }
582}
583
584impl<'db, 'id> Inference<'db, 'id> {
585    fn new(db: &'db dyn Database, data: &'id mut InferenceData<'db>) -> Self {
586        Self { db, data }
587    }
588
589    /// Getter for an [ImplVar].
590    fn impl_var(&self, var_id: LocalImplVarId) -> &ImplVar<'db> {
591        &self.impl_vars[var_id.0]
592    }
593
594    /// Getter for an impl var assignment.
595    pub fn impl_assignment(&self, var_id: LocalImplVarId) -> Option<ImplId<'db>> {
596        self.impl_assignment.get(&var_id).copied()
597    }
598
599    /// Getter for a [NegativeImplVar].
600    fn negative_impl_var(&self, var_id: LocalNegativeImplVarId) -> &NegativeImplVar<'db> {
601        &self.negative_impl_vars[var_id.0]
602    }
603
604    /// Getter for a negative impl var assignment.
605    pub fn negative_impl_assignment(
606        &self,
607        var_id: LocalNegativeImplVarId,
608    ) -> Option<NegativeImplId<'db>> {
609        self.negative_impl_assignment.get(&var_id).copied()
610    }
611
612    /// Getter for a type var assignment.
613    fn type_assignment(&self, var_id: LocalTypeVarId) -> Option<TypeId<'db>> {
614        self.type_assignment.get(&var_id).copied()
615    }
616
617    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
618    /// Returns a wrapping TypeId.
619    pub fn new_type_var(&mut self, stable_ptr: Option<SyntaxStablePtrId<'db>>) -> TypeId<'db> {
620        let var = self.new_type_var_raw(stable_ptr);
621
622        TypeLongId::Var(var).intern(self.db)
623    }
624
625    /// Allocates a new [TypeVar] wrapped in a [`TypeLongId::NumericLiteral`] placeholder for a
626    /// numeric literal whose type is not yet known. The synthetic-impl machinery treats this
627    /// placeholder as satisfying `Drop`/`Copy`/`Destruct`/`PanicDestruct`, and finalization
628    /// defaults any still-unbound variant to `felt252`.
629    pub fn new_integer_literal_type(
630        &mut self,
631        stable_ptr: Option<SyntaxStablePtrId<'db>>,
632    ) -> TypeId<'db> {
633        let var = self.new_type_var_raw(stable_ptr);
634        self.data.integer_literal_vars.push(var.id);
635        TypeLongId::NumericLiteral(var).intern(self.db)
636    }
637
638    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
639    /// Returns the variable id.
640    pub fn new_type_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId<'db>>) -> TypeVar<'db> {
641        let var =
642            TypeVar { inference_id: self.inference_id, id: LocalTypeVarId(self.type_vars.len()) };
643        if let Some(stable_ptr) = stable_ptr {
644            self.stable_ptrs.insert(InferenceVar::Type(var.id), stable_ptr);
645        }
646        self.type_vars.push(var);
647        var
648    }
649
650    /// Sets the inference's impl type bounds to the given map, and rewrittes the types so all the
651    /// types are var free.
652    pub fn set_impl_type_bounds(
653        &mut self,
654        impl_type_bounds: OrderedHashMap<ImplTypeId<'db>, TypeId<'db>>,
655    ) {
656        let impl_type_bounds_finalized = impl_type_bounds
657            .iter()
658            .filter_map(|(impl_type, ty)| {
659                let rewritten_type = self.rewrite(ty.long(self.db).clone()).no_err();
660                if !matches!(rewritten_type, TypeLongId::Var(_) | TypeLongId::NumericLiteral(_)) {
661                    return Some(((*impl_type).into(), rewritten_type.intern(self.db)));
662                }
663                // conformed the var type to the original impl type to remove it from the pending
664                // list.
665                self.conform_ty(*ty, TypeLongId::ImplType(*impl_type).intern(self.db)).ok();
666                None
667            })
668            .collect();
669
670        self.data.impl_type_bounds = Arc::new(impl_type_bounds_finalized);
671    }
672
673    /// Allocates a new [ConstVar] for an unknown const that needs to be inferred.
674    /// Returns a wrapping [ConstValueId].
675    pub fn new_const_var(
676        &mut self,
677        stable_ptr: Option<SyntaxStablePtrId<'db>>,
678        ty: TypeId<'db>,
679    ) -> ConstValueId<'db> {
680        let var = self.new_const_var_raw(stable_ptr);
681        ConstValue::Var(var, ty).intern(self.db)
682    }
683
684    /// Allocates a new [ConstVar] for an unknown type that needs to be inferred.
685    /// Returns the variable id.
686    pub fn new_const_var_raw(
687        &mut self,
688        stable_ptr: Option<SyntaxStablePtrId<'db>>,
689    ) -> ConstVar<'db> {
690        let var = ConstVar {
691            inference_id: self.inference_id,
692            id: LocalConstVarId(self.const_vars.len()),
693        };
694        if let Some(stable_ptr) = stable_ptr {
695            self.stable_ptrs.insert(InferenceVar::Const(var.id), stable_ptr);
696        }
697        self.const_vars.push(var);
698        var
699    }
700
701    /// Allocates a new [ImplVar] for an unknown type that needs to be inferred.
702    /// Returns a wrapping ImplId.
703    pub fn new_impl_var(
704        &mut self,
705        concrete_trait_id: ConcreteTraitId<'db>,
706        stable_ptr: Option<SyntaxStablePtrId<'db>>,
707        lookup_context: ImplLookupContextId<'db>,
708    ) -> ImplId<'db> {
709        let var = self.new_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
710        ImplLongId::ImplVar(self.impl_var(var).intern(self.db)).intern(self.db)
711    }
712
713    /// Allocates a new [ImplVar] for an unknown impl that needs to be inferred.
714    /// Returns the variable id.
715    fn new_impl_var_raw(
716        &mut self,
717        lookup_context: ImplLookupContextId<'db>,
718        concrete_trait_id: ConcreteTraitId<'db>,
719        stable_ptr: Option<SyntaxStablePtrId<'db>>,
720    ) -> LocalImplVarId {
721        let id = LocalImplVarId(self.impl_vars.len());
722        if let Some(stable_ptr) = stable_ptr {
723            self.stable_ptrs.insert(InferenceVar::Impl(id), stable_ptr);
724        }
725        let var =
726            ImplVar { inference_id: self.inference_id, id, concrete_trait_id, lookup_context };
727        self.impl_vars.push(var);
728        self.pending.push_back(id);
729        id
730    }
731
732    /// Allocates a new [NegativeImplVar] for an unknown negative impl that needs to be inferred.
733    /// Returns a wrapping NegativeImplId.
734    pub fn new_negative_impl_var(
735        &mut self,
736        concrete_trait_id: ConcreteTraitId<'db>,
737        stable_ptr: Option<SyntaxStablePtrId<'db>>,
738        lookup_context: ImplLookupContextId<'db>,
739    ) -> NegativeImplId<'db> {
740        let var = self.new_negative_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
741        NegativeImplLongId::NegativeImplVar(self.negative_impl_var(var).clone().intern(self.db))
742            .intern(self.db)
743    }
744
745    /// Allocates a new [NegativeImplVar] for an unknown type that needs to be inferred.
746    /// Returns the variable id.
747    fn new_negative_impl_var_raw(
748        &mut self,
749        lookup_context: ImplLookupContextId<'db>,
750        concrete_trait_id: ConcreteTraitId<'db>,
751        stable_ptr: Option<SyntaxStablePtrId<'db>>,
752    ) -> LocalNegativeImplVarId {
753        let id = LocalNegativeImplVarId(self.negative_impl_vars.len());
754        if let Some(stable_ptr) = stable_ptr {
755            self.stable_ptrs.insert(InferenceVar::NegativeImpl(id), stable_ptr);
756        }
757        let var = NegativeImplVar {
758            inference_id: self.inference_id,
759            id,
760            concrete_trait_id,
761            lookup_context,
762        };
763        self.negative_impl_vars.push(var);
764        self.negative_pending.push_back(id);
765        id
766    }
767
768    /// Solves the inference system. After a successful solve, there are no more pending impl
769    /// inferences.
770    /// Returns whether the inference was successful. If not, the error may be found by
771    /// `.error_state()`.
772    pub fn solve(&mut self) -> InferenceResult<()> {
773        let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
774        self.pending.extend(ambiguous.map(|(var, _)| var));
775        while let Some(var) = self.pending.pop_front() {
776            // First inference error stops inference.
777            self.solve_single_pending(var).inspect_err(|_err_set| {
778                self.add_error_stable_ptr(InferenceVar::Impl(var));
779            })?;
780        }
781        while let Some(var) = self.negative_pending.pop_front() {
782            // First inference error stops inference.
783            self.solve_single_negative_pending(var).inspect_err(|_err_set| {
784                self.add_error_stable_ptr(InferenceVar::NegativeImpl(var));
785            })?;
786        }
787        Ok(())
788    }
789
790    fn solve_single_pending(&mut self, var: LocalImplVarId) -> InferenceResult<()> {
791        if self.impl_assignment.contains_key(&var) {
792            return Ok(());
793        }
794        let solution = match self.impl_var_solution_set(var)? {
795            SolutionSet::None => {
796                self.refuted.push(var);
797                return Ok(());
798            }
799            SolutionSet::Ambiguous(ambiguity) => {
800                self.ambiguous.push((var, ambiguity));
801                return Ok(());
802            }
803            SolutionSet::Unique(solution) => solution,
804        };
805
806        // Solution found. Assign it.
807        self.assign_local_impl(var, solution)?;
808
809        // Something changed.
810        self.solved.push(var);
811        let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
812        self.pending.extend(ambiguous.map(|(var, _)| var));
813
814        let negative_ambiguous = std::mem::take(&mut self.negative_ambiguous).into_iter();
815        self.negative_pending.extend(negative_ambiguous.map(|(var, _)| var));
816
817        Ok(())
818    }
819
820    /// Solves a single negative impl pending variable.
821    fn solve_single_negative_pending(
822        &mut self,
823        var: LocalNegativeImplVarId,
824    ) -> InferenceResult<()> {
825        if self.negative_impl_assignment.contains_key(&var) {
826            return Ok(());
827        }
828
829        let solution = match self.negative_impl_var_solution_set(var)? {
830            SolutionSet::None => {
831                self.negative_refuted.push(var);
832                return Ok(());
833            }
834            SolutionSet::Ambiguous(ambiguity) => {
835                self.negative_ambiguous.push((var, ambiguity));
836                return Ok(());
837            }
838            SolutionSet::Unique(solution) => solution,
839        };
840
841        // Solution found. Assign it.
842        self.assign_local_negative_impl(var, solution)?;
843
844        Ok(())
845    }
846
847    /// Returns the solution set status for the inference:
848    /// Whether there is a unique solution, multiple solutions, no solutions or an error.
849    pub fn solution_set(&mut self) -> InferenceResult<SolutionSet<'db, ()>> {
850        self.solve()?;
851        if !self.refuted.is_empty() {
852            return Ok(SolutionSet::None);
853        }
854        if !self.negative_refuted.is_empty() {
855            return Ok(SolutionSet::None);
856        }
857        if let Some((_, ambiguity)) = self.ambiguous.first() {
858            return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
859        }
860        if let Some((_, ambiguity)) = self.negative_ambiguous.first() {
861            return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
862        }
863        assert!(self.pending.is_empty(), "solution() called on an unsolved solver");
864        assert!(self.negative_pending.is_empty(), "solution() called on an unsolved solver");
865        Ok(SolutionSet::Unique(()))
866    }
867
    /// Finalizes the inference by inferring uninferred numeric literals as felt252.
    /// Returns an error and does not report it.
    ///
    /// Steps, in order:
    /// 1. Repeatedly call `solve()`, defaulting one still-unbound integer-literal type variable
    ///    to `felt252` per round, until a round defaults nothing.
    /// 2. For every bound integer-literal variable whose rewritten type is var-free, check
    ///    `is_valid_literal_type`; on failure set `InferenceError::InvalidLiteralType`.
    /// 3. Write rewritten impl assignments back into `impl_assignment`.
    /// 4. Fail with the error of the first undetermined variable, if any (see
    ///    `first_undetermined_variable`).
    ///
    /// Returns `Err(ErrorSet)` immediately if an error was already set before this call.
    pub fn finalize_without_reporting(&mut self) -> Result<(), ErrorSet> {
        // An error is already set; nothing more to finalize.
        if self.error_status.is_err() {
            return Err(ErrorSet);
        }
        let felt_ty = self.db.core_info().felt252;

        // Default any still-unbound `NumericLiteral` type variables to `felt252`. We do this one
        // var at a time and re-solve between each, mirroring the old behavior for
        // `NumericLiteral<Var>`: defaulting the first literal can constrain other literals
        // through trait resolution (e.g., `0.pow(0)` — defaulting the receiver lets `Pow<T, Exp>`
        // resolve, which then constrains the exponent to `usize`, avoiding a felt252 default for
        // the second literal).
        loop {
            self.solve()?;
            let mut changed = false;
            // Iterate over a copy so that `self` can be mutably borrowed inside the loop.
            for var_id in self.data.integer_literal_vars.clone() {
                if self.type_assignment.contains_key(&var_id) {
                    continue;
                }
                let var = self.type_vars[var_id.0];
                let placeholder = TypeLongId::NumericLiteral(var).intern(self.db);
                // Nothing to default if the placeholder already rewrites to felt252.
                if self.rewrite(placeholder).no_err() == felt_ty {
                    continue;
                }
                self.conform_ty(placeholder, felt_ty).inspect_err(|_err_set| {
                    self.add_error_stable_ptr(InferenceVar::Type(var_id));
                })?;
                // Default only a single variable per round, then re-solve (see above).
                changed = true;
                break;
            }
            if !changed {
                break;
            }
        }
        assert!(
            self.pending.is_empty(),
            "pending should all be solved by this point. Guaranteed by solve()."
        );

        // Defensive backstop: every invalid `IntegerLiteral` binding should already have been
        // caught at conform-time by `bind_integer_literal`.
        for var_id in self.data.integer_literal_vars.clone() {
            let Some(bound) = self.type_assignment.get(&var_id).copied() else {
                continue;
            };
            let bound = self.rewrite(bound).no_err();
            // Types that still contain variables cannot be judged yet; skip them.
            if !bound.is_var_free(self.db) || is_valid_literal_type(self.db, bound) {
                continue;
            }
            let err = self.set_error(InferenceError::InvalidLiteralType(bound));
            self.add_error_stable_ptr(InferenceVar::Type(var_id));
            return Err(err);
        }

        // Persist rewriter substitutions of synthetic `NumericLiteral`-over-mem-trait
        // `GeneratedImpl`s into `impl_assignment` so subsequent queries see the natural impl.
        // The assignments are copied out first because `rewrite` needs `&mut self`.
        let assignments: Vec<_> = self.impl_assignment.iter().map(|(k, v)| (*k, *v)).collect();
        for (var_id, impl_id) in assignments {
            let rewritten = self.rewrite(impl_id).no_err();
            if rewritten != impl_id {
                self.impl_assignment.insert(var_id, rewritten);
            }
        }

        // Anything still undetermined becomes the error, located at the offending variable.
        let Some((var, err)) = self.first_undetermined_variable() else {
            return Ok(());
        };
        Err(self.set_error_on_var(err, var))
    }
939
940    /// Finalizes the inference and report diagnostics if there are any errors.
941    /// All the remaining type vars are mapped to the `missing` type, to prevent additional
942    /// diagnostics.
943    pub fn finalize<'m>(
944        &'m mut self,
945        diagnostics: &mut SemanticDiagnostics<'db>,
946        stable_ptr: SyntaxStablePtrId<'db>,
947    ) {
948        if let Err(err_set) = self.finalize_without_reporting() {
949            let diag = self.report_on_pending_error(err_set, diagnostics, stable_ptr);
950
951            let ty_missing = TypeId::missing(self.db, diag);
952            for var in &self.data.type_vars {
953                self.data.type_assignment.entry(var.id).or_insert(ty_missing);
954            }
955        }
956    }
957
958    /// Retrieves the first variable that is still not inferred, or None, if everything is
959    /// inferred.
960    /// Does not set the error but return it, which is ok as this is a private helper function.
961    fn first_undetermined_variable(&mut self) -> Option<(InferenceVar, InferenceError<'db>)> {
962        if let Some(var) = self.refuted.first().copied() {
963            let impl_var = self.impl_var(var).clone();
964            let concrete_trait_id = impl_var.concrete_trait_id;
965            let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
966            return Some((
967                InferenceVar::Impl(var),
968                InferenceError::NoImplsFound(concrete_trait_id),
969            ));
970        }
971        if let Some(var) = self.negative_refuted.first().copied() {
972            let negative_impl_var = self.negative_impl_var(var).clone();
973            let concrete_trait_id = negative_impl_var.concrete_trait_id;
974            let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
975            return Some((
976                InferenceVar::NegativeImpl(var),
977                InferenceError::NoNegativeImplsFound(concrete_trait_id),
978            ));
979        }
980
981        let mut fallback_ret = None;
982        if let Some((var, ambiguity)) = self.ambiguous.first() {
983            // Note: do not rewrite `ambiguity`, since it is expressed in canonical variables.
984            let ret =
985                Some((InferenceVar::Impl(*var), InferenceError::Ambiguity(ambiguity.clone())));
986            if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
987                return ret;
988            } else {
989                fallback_ret = ret;
990            }
991        }
992        if let Some((var, ambiguity)) = self.negative_ambiguous.first() {
993            let ret = Some((
994                InferenceVar::NegativeImpl(*var),
995                InferenceError::Ambiguity(ambiguity.clone()),
996            ));
997            if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
998                return ret;
999            } else {
1000                fallback_ret = ret;
1001            }
1002        }
1003        for (id, var) in self.type_vars.iter().enumerate() {
1004            if self.type_assignment(LocalTypeVarId(id)).is_none() {
1005                let ty = TypeLongId::Var(*var).intern(self.db);
1006                return Some((InferenceVar::Type(var.id), InferenceError::TypeNotInferred(ty)));
1007            }
1008        }
1009        for (id, var) in self.const_vars.iter().enumerate() {
1010            if !self.const_assignment.contains_key(&LocalConstVarId(id)) {
1011                let infernence_var = InferenceVar::Const(var.id);
1012                return Some((infernence_var, InferenceError::ConstNotInferred));
1013            }
1014        }
1015        fallback_ret
1016    }
1017
    /// Assigns a value to a local impl variable id. See assign_impl().
    ///
    /// Steps:
    /// 1. The concrete trait of `impl_id` is conformed with the variable's concrete trait.
    /// 2. If the variable is already assigned, `impl_id` is conformed with the existing
    ///    assignment instead of overwriting it.
    /// 3. Otherwise the assignment is recorded. This fails with `Cycle` if `impl_id` contains
    ///    `var`.
    /// 4. Any trait item constraints (types, constants, impls) recorded for `var` in
    ///    `impl_vars_trait_item_mappings` are removed and conformed with the corresponding items
    ///    of `impl_id`.
    ///
    /// Returns the resulting impl, or an error (which is also set in the inference state).
    fn assign_local_impl(
        &mut self,
        var: LocalImplVarId,
        impl_id: ImplId<'db>,
    ) -> InferenceResult<ImplId<'db>> {
        let concrete_trait = impl_id
            .concrete_trait(self.db)
            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
        self.conform_traits(self.impl_var(var).concrete_trait_id, concrete_trait)?;
        // Already assigned: the new impl must agree with the existing assignment.
        if let Some(other_impl) = self.impl_assignment(var) {
            return self.conform_impl(impl_id, other_impl);
        }
        // Occurs check: refuse to assign an impl that mentions the variable itself.
        if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
        {
            let inference_var = InferenceVar::Impl(var);
            return Err(self.set_error_on_var(InferenceError::Cycle(inference_var), inference_var));
        }
        self.impl_assignment.insert(var, impl_id);
        // Consume the trait item constraints recorded for this var and check them against the
        // items of the now-known impl.
        if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
            for (trait_type_id, ty) in mappings.types {
                let impl_ty = self
                    .db
                    .impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_type_id, self.db))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                if let Err(err_set) = self.conform_ty(ty, impl_ty) {
                    // Override the error with ImplTypeMismatch.
                    // NOTE(review): this overwrites `error_status` unconditionally, even if the
                    // error set by `conform_ty` was already consumed; confirm this cannot lead
                    // to a duplicate diagnostic.
                    let ty0 = self.rewrite(ty).no_err();
                    let ty1 = self.rewrite(impl_ty).no_err();

                    let err = InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 };
                    self.error_status = Err(InferenceErrorStatus::Pending(PendingInferenceError {
                        err,
                        stable_ptr: self.stable_ptrs.get(&InferenceVar::Impl(var)).cloned(),
                    }));
                    return Err(err_set);
                }
            }
            for (trait_constant, constant_id) in mappings.constants {
                let concrete_impl_constant = self
                    .db
                    .impl_constant_concrete_implized_value(ImplConstantId::new(
                        impl_id,
                        trait_constant,
                        self.db,
                    ))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                self.conform_const(constant_id, concrete_impl_constant)?;
            }
            for (trait_impl, inner_impl_id) in mappings.impls {
                let concrete_impl_impl = self
                    .db
                    .impl_impl_concrete_implized(ImplImplId::new(impl_id, trait_impl, self.db))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                self.conform_impl(inner_impl_id, concrete_impl_impl)?;
            }
        }
        Ok(impl_id)
    }
1077
1078    /// Tries to assigns value to an [ImplVarId]. Return the assigned impl, or an error.
1079    fn assign_impl(
1080        &mut self,
1081        var_id: ImplVarId<'db>,
1082        impl_id: ImplId<'db>,
1083    ) -> InferenceResult<ImplId<'db>> {
1084        let var = var_id.long(self.db);
1085        if var.inference_id != self.inference_id {
1086            return Err(self.set_error(InferenceError::ImplKindMismatch {
1087                impl0: ImplLongId::ImplVar(var_id).intern(self.db),
1088                impl1: impl_id,
1089            }));
1090        }
1091        self.assign_local_impl(var.id, impl_id)
1092    }
1093
1094    /// Assigns a value to a local negative impl variable id. See assign_neg_impl().
1095    fn assign_local_negative_impl(
1096        &mut self,
1097        var: LocalNegativeImplVarId,
1098        neg_impl_id: NegativeImplId<'db>,
1099    ) -> InferenceResult<NegativeImplId<'db>> {
1100        let concrete_trait = neg_impl_id
1101            .concrete_trait(self.db)
1102            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1103        self.conform_traits(self.negative_impl_var(var).concrete_trait_id, concrete_trait)?;
1104        if let Some(other_impl) = self.negative_impl_assignment(var) {
1105            return self.conform_neg_impl(neg_impl_id, other_impl);
1106        }
1107        if !neg_impl_id.is_var_free(self.db)
1108            && self.negative_impl_contains_var(neg_impl_id, InferenceVar::NegativeImpl(var))
1109        {
1110            return Err(self.set_error(InferenceError::Cycle(InferenceVar::NegativeImpl(var))));
1111        }
1112        self.negative_impl_assignment.insert(var, neg_impl_id);
1113        Ok(neg_impl_id)
1114    }
1115
1116    /// Tries to assigns value to an [NegativeImplVarId]. Return the assigned negative impl, or an
1117    /// error.
1118    fn assign_neg_impl(
1119        &mut self,
1120        var_id: NegativeImplVarId<'db>,
1121        neg_impl_id: NegativeImplId<'db>,
1122    ) -> InferenceResult<NegativeImplId<'db>> {
1123        let var = var_id.long(self.db);
1124        if var.inference_id != self.inference_id {
1125            return Err(self.set_error(InferenceError::NegativeImplKindMismatch {
1126                impl0: NegativeImplLongId::NegativeImplVar(var_id).intern(self.db),
1127                impl1: neg_impl_id,
1128            }));
1129        }
1130        self.assign_local_negative_impl(var.id, neg_impl_id)
1131    }
1132
1133    /// Assigns a value to a [TypeVar]. Return the assigned type, or an error.
1134    /// Assumes the variable is not already assigned.
1135    fn assign_ty(&mut self, var: TypeVar<'db>, ty: TypeId<'db>) -> InferenceResult<TypeId<'db>> {
1136        if var.inference_id != self.inference_id {
1137            return Err(self.set_error(InferenceError::TypeKindMismatch {
1138                ty0: TypeLongId::Var(var).intern(self.db),
1139                ty1: ty,
1140            }));
1141        }
1142        assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
1143        let inference_var = InferenceVar::Type(var.id);
1144        if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
1145            return Err(self.set_error_on_var(InferenceError::Cycle(inference_var), inference_var));
1146        }
1147        // If assigning var to var - making sure assigning to the lower id for proper canonization.
1148        if let TypeLongId::Var(other) = ty.long(self.db)
1149            && other.inference_id == self.inference_id
1150            && other.id.0 > var.id.0
1151        {
1152            let var_ty = TypeLongId::Var(var).intern(self.db);
1153            self.type_assignment.insert(other.id, var_ty);
1154            return Ok(var_ty);
1155        }
1156        self.type_assignment.insert(var.id, ty);
1157        Ok(ty)
1158    }
1159
1160    /// Assigns a value to a [ConstVar]. Return the assigned const, or an error.
1161    /// Assumes the variable is not already assigned.
1162    fn assign_const(
1163        &mut self,
1164        var: ConstVar<'db>,
1165        id: ConstValueId<'db>,
1166    ) -> InferenceResult<ConstValueId<'db>> {
1167        if var.inference_id != self.inference_id {
1168            return Err(self.set_error(InferenceError::ConstKindMismatch {
1169                const0: ConstValue::Var(var, TypeId::missing(self.db, skip_diagnostic()))
1170                    .intern(self.db),
1171                const1: id,
1172            }));
1173        }
1174
1175        self.const_assignment.insert(var.id, id);
1176        Ok(id)
1177    }
1178
1179    /// Computes the solution set for an impl variable with a recursive query.
1180    fn impl_var_solution_set(
1181        &mut self,
1182        var: LocalImplVarId,
1183    ) -> InferenceResult<SolutionSet<'db, ImplId<'db>>> {
1184        let impl_var = self.impl_var(var).clone();
1185        // Update the concrete trait of the impl var.
1186        let concrete_trait_id = self.rewrite(impl_var.concrete_trait_id).no_err();
1187        self.impl_vars[impl_var.id.0].concrete_trait_id = concrete_trait_id;
1188        let impl_var_trait_item_mappings =
1189            self.impl_vars_trait_item_mappings.get(&var).cloned().unwrap_or_default();
1190        let solution_set = self.trait_solution_set(
1191            concrete_trait_id,
1192            impl_var_trait_item_mappings,
1193            impl_var.lookup_context,
1194        )?;
1195        Ok(match solution_set {
1196            SolutionSet::None => SolutionSet::None,
1197            SolutionSet::Unique((canonical_impl, canonicalizer)) => {
1198                SolutionSet::Unique(canonical_impl.embed(self, &canonicalizer))
1199            }
1200            SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
1201        })
1202    }
1203
1204    /// Computes the solution set for a negative impl variable with a recursive query.
1205    fn negative_impl_var_solution_set(
1206        &mut self,
1207        var: LocalNegativeImplVarId,
1208    ) -> InferenceResult<SolutionSet<'db, NegativeImplId<'db>>> {
1209        let negative_impl_var = self.negative_impl_var(var).clone();
1210        let concrete_trait_id = self.rewrite(negative_impl_var.concrete_trait_id).no_err();
1211
1212        let solution_set =
1213            self.validate_no_solution_set(concrete_trait_id, negative_impl_var.lookup_context)?;
1214        Ok(match solution_set {
1215            SolutionSet::Unique(concrete_trait_id) => {
1216                SolutionSet::Unique(NegativeImplLongId::Solved(concrete_trait_id).intern(self.db))
1217            }
1218            SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
1219            SolutionSet::None => SolutionSet::None,
1220        })
1221    }
1222
1223    /// Validates that no solution set is found for the negative impls.
1224    fn validate_no_solution_set(
1225        &mut self,
1226        concrete_trait_id: ConcreteTraitId<'db>,
1227        lookup_context: ImplLookupContextId<'db>,
1228    ) -> InferenceResult<SolutionSet<'db, ConcreteTraitId<'db>>> {
1229        for negative_impl in &lookup_context.long(self.db).negative_impls {
1230            let generic_param = self
1231                .db
1232                .generic_param_semantic(*negative_impl)
1233                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1234            if let GenericParam::NegImpl(neg_impl) = generic_param
1235                && Ok(concrete_trait_id) == neg_impl.concrete_trait
1236            {
1237                return Ok(SolutionSet::Unique(concrete_trait_id));
1238            }
1239        }
1240
1241        let generic_args = concrete_trait_id.generic_args(self.db);
1242        // Defer if the generic argument contains inference variables, with special handling for
1243        // closures. Closures can only have one type even if not var-free, so they
1244        // don't cause ambiguity.
1245        if generic_args.iter().any(|garg| {
1246            matches!(
1247                garg,
1248                GenericArgumentId::Type(ty)
1249                if !matches!(
1250                    ty.long(self.db),
1251                    TypeLongId::Closure(_) | TypeLongId::NumericLiteral(_)
1252                )
1253                && !ty.is_var_free(self.db)
1254            )
1255        }) {
1256            return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1257        }
1258
1259        let mut neg_impl_generic_params = OrderedHashSet::default();
1260        let mut visited_types = OrderedHashSet::default();
1261        for garg in generic_args {
1262            if garg
1263                .extract_generic_params(self.db, &mut neg_impl_generic_params, &mut visited_types)
1264                .is_err()
1265            {
1266                return Ok(SolutionSet::Ambiguous(
1267                    Ambiguity::NegativeImplWithUnsupportedExtractedArgs(*garg),
1268                ));
1269            }
1270        }
1271
1272        let solution_set = if neg_impl_generic_params.is_empty() {
1273            self.trait_solution_set(
1274                concrete_trait_id,
1275                ImplVarTraitItemMappings::default(),
1276                lookup_context,
1277            )?
1278        } else {
1279            let mut substitution: OrderedHashMap<GenericParamId<'db>, GenericArgumentId<'db>> =
1280                Default::default();
1281
1282            for param in neg_impl_generic_params {
1283                let garg = match param.kind(self.db) {
1284                    GenericKind::Type => GenericArgumentId::Type(
1285                        self.new_type_var(Some(param.stable_ptr(self.db).untyped())),
1286                    ),
1287                    GenericKind::Const | GenericKind::Impl | GenericKind::NegImpl => {
1288                        return Ok(SolutionSet::Ambiguous(
1289                            Ambiguity::NegativeImplWithUnsupportedGenericParam(param),
1290                        ));
1291                    }
1292                };
1293
1294                substitution.insert(param, garg);
1295            }
1296            let rewritten_concrete_trait_id =
1297                GenericSubstitution { param_to_arg: substitution.clone(), self_impl: None }
1298                    .substitute(self.db, concrete_trait_id)
1299                    .unwrap();
1300
1301            let solution_set = self.trait_solution_set(
1302                rewritten_concrete_trait_id,
1303                ImplVarTraitItemMappings::default(),
1304                lookup_context,
1305            )?;
1306
1307            // Assign the newly created type variables with their corresponding generic parameters.
1308            // This step prevents leaving type variables unresolved.
1309            let db = self.db;
1310            for (generic_param, garg) in substitution {
1311                let GenericArgumentId::Type(ty) = garg else {
1312                    panic!("Expected a type variable");
1313                };
1314                let (TypeLongId::Var(var) | TypeLongId::NumericLiteral(var)) = ty.long(self.db)
1315                else {
1316                    panic!("Expected a type variable");
1317                };
1318                self.type_assignment
1319                    .entry(var.id)
1320                    .or_insert_with(|| TypeLongId::GenericParameter(generic_param).intern(db));
1321            }
1322
1323            solution_set
1324        };
1325
1326        if !matches!(solution_set, SolutionSet::None) {
1327            return Ok(SolutionSet::None);
1328        }
1329
1330        Ok(SolutionSet::Unique(concrete_trait_id))
1331    }
1332
1333    /// Computes the solution set for a trait with a recursive query.
1334    pub fn trait_solution_set(
1335        &mut self,
1336        concrete_trait_id: ConcreteTraitId<'db>,
1337        impl_var_trait_item_mappings: ImplVarTraitItemMappings<'db>,
1338        lookup_context_id: ImplLookupContextId<'db>,
1339    ) -> InferenceResult<SolutionSet<'db, (CanonicalImpl<'db>, CanonicalMapping<'db>)>> {
1340        let impl_var_trait_item_mappings = self.rewrite(impl_var_trait_item_mappings).no_err();
1341        // TODO(spapini): This is done twice. Consider doing it only here.
1342        let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
1343        let mut lookup_context = lookup_context_id.long(self.db).clone();
1344        enrich_lookup_context(self.db, concrete_trait_id, &mut lookup_context);
1345
1346        // Don't try to resolve impls if the first generic param is a variable.
1347        let generic_args = concrete_trait_id.generic_args(self.db);
1348        match generic_args.first() {
1349            Some(GenericArgumentId::Type(ty)) => {
1350                if let TypeLongId::Var(_) = ty.long(self.db) {
1351                    // Don't try to infer such impls.
1352                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1353                }
1354            }
1355            Some(GenericArgumentId::Impl(imp)) => {
1356                // Don't try to infer such impls.
1357                if let ImplLongId::ImplVar(_) = imp.long(self.db) {
1358                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1359                }
1360            }
1361            Some(GenericArgumentId::Constant(const_value)) => {
1362                if let ConstValue::Var(_, _) = const_value.long(self.db) {
1363                    // Don't try to infer such impls.
1364                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1365                }
1366            }
1367            _ => {}
1368        };
1369        let (canonical_trait, canonicalizer) = CanonicalTrait::canonicalize(
1370            self.db,
1371            self.inference_id,
1372            concrete_trait_id,
1373            impl_var_trait_item_mappings,
1374        );
1375        // impl_type_bounds order is deterimend by the generic params of the function and therefore
1376        // is consistent.
1377        let solution_set = match self.db.canonic_trait_solutions(
1378            canonical_trait,
1379            lookup_context.intern(self.db),
1380            (*self.data.impl_type_bounds).clone(),
1381        ) {
1382            Ok(solution_set) => solution_set,
1383            Err(err) => return Err(self.set_error(err)),
1384        };
1385        match solution_set {
1386            SolutionSet::None => Ok(SolutionSet::None),
1387            SolutionSet::Unique(canonical_impl) => {
1388                Ok(SolutionSet::Unique((canonical_impl, canonicalizer)))
1389            }
1390            SolutionSet::Ambiguous(ambiguity) => Ok(SolutionSet::Ambiguous(ambiguity)),
1391        }
1392    }
1393
1394    // Error handling methods
1395    // ======================
1396
1397    /// Sets an error in the inference state.
1398    /// Does nothing if an error is already set.
1399    /// Returns an `ErrorSet` that can be used in reporting the error.
1400    pub fn set_error(&mut self, err: InferenceError<'db>) -> ErrorSet {
1401        self.set_error_ex(err, None)
1402    }
1403
1404    /// Sets an error in the inference state, with an optional location for the diagnostics
1405    /// reporting. Does nothing if an error is already set.
1406    /// Returns an `ErrorSet` that can be used in reporting the error.
1407    pub fn set_error_ex(
1408        &mut self,
1409        err: InferenceError<'db>,
1410        stable_ptr: Option<SyntaxStablePtrId<'db>>,
1411    ) -> ErrorSet {
1412        if self.error_status.is_err() {
1413            return ErrorSet;
1414        }
1415        self.error_status = Err(if let InferenceError::Reported(diag_added) = err {
1416            InferenceErrorStatus::Consumed(diag_added)
1417        } else {
1418            InferenceErrorStatus::Pending(PendingInferenceError { err, stable_ptr })
1419        });
1420        ErrorSet
1421    }
1422
1423    /// Sets an error in the inference state, with a var to fetch location for the diagnostics
1424    /// reporting. Does nothing if an error is already set.
1425    /// Returns an `ErrorSet` that can be used in reporting the error.
1426    pub fn set_error_on_var(&mut self, err: InferenceError<'db>, var: InferenceVar) -> ErrorSet {
1427        self.set_error_ex(err, self.stable_ptrs.get(&var).cloned())
1428    }
1429
1430    /// Returns whether an error is set (either pending or consumed).
1431    pub fn is_error_set(&self) -> InferenceResult<()> {
1432        self.error_status.as_ref().copied().map_err(|_| ErrorSet)
1433    }
1434
1435    /// If there is no stable ptr for the pending error, add it by the given var.
1436    fn add_error_stable_ptr(&mut self, var: InferenceVar) {
1437        let var_stable_ptr = self.stable_ptrs.get(&var).copied();
1438        if let Err(InferenceErrorStatus::Pending(PendingInferenceError { err: _, stable_ptr })) =
1439            &mut self.error_status
1440            && stable_ptr.is_none()
1441        {
1442            *stable_ptr = var_stable_ptr;
1443        }
1444    }
1445
1446    /// Consumes the error but doesn't report it. If there is no error, or the error is consumed,
1447    /// returns None. This should be used with caution. Always prefer to use
1448    /// (1) `report_on_pending_error` if possible, or (2) `consume_reported_error` which is safer.
1449    ///
1450    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1451    pub fn consume_error_without_reporting(
1452        &mut self,
1453        err_set: ErrorSet,
1454    ) -> Option<InferenceError<'db>> {
1455        Some(self.consume_error_inner(err_set, skip_diagnostic())?.err)
1456    }
1457
1458    /// Consumes the error that is already reported. If there is no error, or the error is consumed,
1459    /// does nothing. This should be used with caution. Always prefer to use
1460    /// `report_on_pending_error` if possible.
1461    ///
1462    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1463    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1464    pub fn consume_reported_error(&mut self, err_set: ErrorSet, diag_added: DiagnosticAdded) {
1465        self.consume_error_inner(err_set, diag_added);
1466    }
1467
1468    /// Consumes the error and returns it, but doesn't report it. If there is no error, or the error
1469    /// is already consumed, returns None. This should be used with caution. Always prefer to use
1470    /// `report_on_pending_error` if possible.
1471    ///
1472    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1473    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1474    fn consume_error_inner(
1475        &mut self,
1476        _err_set: ErrorSet,
1477        diag_added: DiagnosticAdded,
1478    ) -> Option<PendingInferenceError<'db>> {
1479        match &mut self.error_status {
1480            Err(InferenceErrorStatus::Pending(error)) => {
1481                let pending_error = std::mem::replace(
1482                    error,
1483                    PendingInferenceError {
1484                        err: InferenceError::Reported(diag_added),
1485                        stable_ptr: None,
1486                    },
1487                );
1488                self.error_status = Err(InferenceErrorStatus::Consumed(diag_added));
1489                Some(pending_error)
1490            }
1491            // TODO(orizi): `panic!("consume_error when there is no pending error")` instead.
1492            _ => None,
1493        }
1494    }
1495
1496    /// Consumes the pending error, if any, and reports it.
1497    /// Should only be called when an error is set, otherwise it panics.
1498    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1499    /// If an error was set but it's already consumed, it doesn't report it again but returns the
1500    /// stored `DiagnosticAdded`.
1501    pub fn report_on_pending_error(
1502        &mut self,
1503        _err_set: ErrorSet,
1504        diagnostics: &mut SemanticDiagnostics<'db>,
1505        stable_ptr: SyntaxStablePtrId<'db>,
1506    ) -> DiagnosticAdded {
1507        let Err(state_error) = &self.error_status else {
1508            panic!("report_on_pending_error should be called only on error");
1509        };
1510        match state_error {
1511            InferenceErrorStatus::Consumed(diag_added) => *diag_added,
1512            InferenceErrorStatus::Pending(pending) => {
1513                let diag_added = match &pending.err {
1514                    InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
1515                        // If we have other diagnostics, there is no need to TypeNotInferred.
1516
1517                        // Note that `diagnostics` is not empty, so it is safe to return
1518                        // 'DiagnosticAdded' here.
1519                        skip_diagnostic()
1520                    }
1521                    diag => diag.report(diagnostics, pending.stable_ptr.unwrap_or(stable_ptr)),
1522                };
1523                self.error_status = Err(InferenceErrorStatus::Consumed(diag_added));
1524                diag_added
1525            }
1526        }
1527    }
1528
1529    /// If the current status is of a pending error, reports an alternative diagnostic, by calling
1530    /// `report`, and consumes the error. Otherwise, does nothing.
1531    pub fn report_modified_if_pending(
1532        &mut self,
1533        err_set: ErrorSet,
1534        report: impl FnOnce() -> DiagnosticAdded,
1535    ) {
1536        if matches!(self.error_status, Err(InferenceErrorStatus::Pending { .. })) {
1537            self.consume_reported_error(err_set, report());
1538        }
1539    }
1540}
1541
/// Gives access to the database handle held by the inference (`self.db`).
impl<'a, 'mt> HasDb<&'a dyn Database> for Inference<'a, 'mt> {
    fn get_db(&self) -> &'a dyn Database {
        self.db
    }
}
// Presumably these macros provide the `SemanticRewriter` impls for the remaining semantic
// objects. The types listed after `@exclude` have hand-written impls for `Inference` below:
// they either short-circuit on var-free values or substitute the inference assignments.
add_basic_rewrites!(<'a, 'mt>, Inference<'a, 'mt>, NoError, @exclude TypeLongId TypeId ImplLongId ImplId ConstValue NegativeImplLongId NegativeImplId);
add_expr_rewrites!(<'a, 'mt>, Inference<'a, 'mt>, NoError, @exclude);
add_rewrite!(<'a, 'mt>, Inference<'a, 'mt>, NoError, Ambiguity<'a>);
1550impl<'db, 'mt> SemanticRewriter<TypeId<'db>, NoError> for Inference<'db, 'mt> {
1551    fn internal_rewrite(&mut self, value: &mut TypeId<'db>) -> Result<RewriteResult, NoError> {
1552        if value.is_var_free(self.db) {
1553            return Ok(RewriteResult::NoChange);
1554        }
1555        value.default_rewrite(self)
1556    }
1557}
1558impl<'db, 'mt> SemanticRewriter<ImplId<'db>, NoError> for Inference<'db, 'mt> {
1559    fn internal_rewrite(&mut self, value: &mut ImplId<'db>) -> Result<RewriteResult, NoError> {
1560        if value.is_var_free(self.db) {
1561            return Ok(RewriteResult::NoChange);
1562        }
1563        value.default_rewrite(self)
1564    }
1565}
1566impl<'db, 'mt> SemanticRewriter<NegativeImplId<'db>, NoError> for Inference<'db, 'mt> {
1567    fn internal_rewrite(
1568        &mut self,
1569        value: &mut NegativeImplId<'db>,
1570    ) -> Result<RewriteResult, NoError> {
1571        if value.is_var_free(self.db) {
1572            return Ok(RewriteResult::NoChange);
1573        }
1574        value.default_rewrite(self)
1575    }
1576}
1577
1578impl<'db, 'mt> SemanticRewriter<TypeLongId<'db>, NoError> for Inference<'db, 'mt> {
1579    fn internal_rewrite(&mut self, value: &mut TypeLongId<'db>) -> Result<RewriteResult, NoError> {
1580        match value {
1581            TypeLongId::Var(var) | TypeLongId::NumericLiteral(var) => {
1582                if let Some(type_id) = self.type_assignment.get(&var.id) {
1583                    let mut long_type_id = type_id.long(self.db).clone();
1584                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_type_id)? {
1585                        *self.type_assignment.get_mut(&var.id).unwrap() =
1586                            long_type_id.clone().intern(self.db);
1587                    }
1588                    *value = long_type_id;
1589                    return Ok(RewriteResult::Modified);
1590                }
1591            }
1592            TypeLongId::ImplType(impl_type_id) => {
1593                if let Some(type_id) = self.impl_type_bounds.get(&((*impl_type_id).into())) {
1594                    *value = type_id.long(self.db).clone();
1595                    self.internal_rewrite(value)?;
1596                    return Ok(RewriteResult::Modified);
1597                }
1598                let impl_type_id_rewrite_result = self.internal_rewrite(impl_type_id)?;
1599                let impl_id = impl_type_id.impl_id();
1600                let trait_ty = impl_type_id.ty();
1601                return Ok(match impl_id.long(self.db) {
1602                    ImplLongId::GenericParameter(_)
1603                    | ImplLongId::SelfImpl(_)
1604                    | ImplLongId::ImplImpl(_) => impl_type_id_rewrite_result,
1605                    ImplLongId::Concrete(_) => {
1606                        if let Ok(ty) = self.db.impl_type_concrete_implized(ImplTypeId::new(
1607                            impl_id, trait_ty, self.db,
1608                        )) {
1609                            *value = self.rewrite(ty).no_err().long(self.db).clone();
1610                            RewriteResult::Modified
1611                        } else {
1612                            impl_type_id_rewrite_result
1613                        }
1614                    }
1615                    ImplLongId::ImplVar(var) => {
1616                        *value = self.rewritten_impl_type(*var, trait_ty).long(self.db).clone();
1617                        return Ok(RewriteResult::Modified);
1618                    }
1619                    ImplLongId::GeneratedImpl(generated) => {
1620                        *value = self
1621                            .rewrite(
1622                                *generated
1623                                    .long(self.db)
1624                                    .impl_items
1625                                    .0
1626                                    .get(&impl_type_id.ty())
1627                                    .unwrap(),
1628                            )
1629                            .no_err()
1630                            .long(self.db)
1631                            .clone();
1632                        RewriteResult::Modified
1633                    }
1634                });
1635            }
1636            _ => {}
1637        }
1638        value.default_rewrite(self)
1639    }
1640}
1641impl<'db, 'mt> SemanticRewriter<ConstValue<'db>, NoError> for Inference<'db, 'mt> {
1642    fn internal_rewrite(&mut self, value: &mut ConstValue<'db>) -> Result<RewriteResult, NoError> {
1643        match value {
1644            ConstValue::Var(var, _) => {
1645                return Ok(if let Some(const_value_id) = self.const_assignment.get(&var.id) {
1646                    let mut const_value = const_value_id.long(self.db).clone();
1647                    if let RewriteResult::Modified = self.internal_rewrite(&mut const_value)? {
1648                        *self.const_assignment.get_mut(&var.id).unwrap() =
1649                            const_value.clone().intern(self.db);
1650                    }
1651                    *value = const_value;
1652                    RewriteResult::Modified
1653                } else {
1654                    RewriteResult::NoChange
1655                });
1656            }
1657            ConstValue::ImplConstant(impl_constant_id) => {
1658                let impl_constant_id_rewrite_result = self.internal_rewrite(impl_constant_id)?;
1659                let impl_id = impl_constant_id.impl_id();
1660                let trait_constant = impl_constant_id.trait_constant_id();
1661                return Ok(match impl_id.long(self.db) {
1662                    ImplLongId::GenericParameter(_)
1663                    | ImplLongId::SelfImpl(_)
1664                    | ImplLongId::GeneratedImpl(_)
1665                    | ImplLongId::ImplImpl(_) => impl_constant_id_rewrite_result,
1666                    ImplLongId::Concrete(_) => {
1667                        if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
1668                            ImplConstantId::new(impl_id, trait_constant, self.db),
1669                        ) {
1670                            *value = self.rewrite(constant).no_err().long(self.db).clone();
1671                            RewriteResult::Modified
1672                        } else {
1673                            impl_constant_id_rewrite_result
1674                        }
1675                    }
1676                    ImplLongId::ImplVar(var) => {
1677                        *value = self
1678                            .rewritten_impl_constant(*var, trait_constant)
1679                            .long(self.db)
1680                            .clone();
1681                        return Ok(RewriteResult::Modified);
1682                    }
1683                });
1684            }
1685            _ => {}
1686        }
1687        value.default_rewrite(self)
1688    }
1689}
/// Substitutes inference results into an [`ImplLongId`].
///
/// - `ImplVar`: replaced by its assigned impl, rewritten recursively. The stored assignment is
///   refreshed when the recursive rewrite changed it. Unassigned variables fall through to the
///   generic handling at the end.
/// - `GeneratedImpl`: its fields are rewritten first. Then, for a `Drop`/`Copy`/`Destruct`/
///   `PanicDestruct` impl whose first generic argument is a type other than a closure, variable,
///   numeric literal or impl type, the generated impl is replaced by the impl that
///   `get_impl_at_context` finds for that type (when the lookup succeeds).
/// - `ImplImpl`: the outer impl is rewritten and the impl-impl is resolved according to its kind.
/// - Otherwise: var-free values are left unchanged; the rest go through `default_rewrite`.
impl<'db, 'mt> SemanticRewriter<ImplLongId<'db>, NoError> for Inference<'db, 'mt> {
    fn internal_rewrite(&mut self, value: &mut ImplLongId<'db>) -> Result<RewriteResult, NoError> {
        match value {
            ImplLongId::ImplVar(var) => {
                let long_id = var.long(self.db);
                // Relax the candidates.
                let impl_var_id = long_id.id;
                if let Some(impl_id) = self.impl_assignment(impl_var_id) {
                    let mut long_impl_id = impl_id.long(self.db).clone();
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_impl_id)? {
                        // Store the further-resolved impl back into the assignment.
                        *self.impl_assignment.get_mut(&impl_var_id).unwrap() =
                            long_impl_id.clone().intern(self.db);
                    }
                    *value = long_impl_id;
                    return Ok(RewriteResult::Modified);
                }
                // Unassigned impl variable: continue to the generic handling after the `match`.
            }
            ImplLongId::GeneratedImpl(generated) => {
                // Recurse into the generated impl's fields first so the inner type is rewritten.
                // This arm always returns, so the `is_var_free` shortcut after the `match` is
                // never reached for generated impls.
                let inner_result = generated.default_rewrite(self)?;
                let concrete_trait = generated.concrete_trait(self.db);
                // NOTE(review): only the outermost shape of `first_ty` is inspected. A type with
                // nested inference variables (e.g. a tuple containing a variable) still reaches
                // `get_impl_at_context`; confirm that the lookup fails harmlessly in that case.
                if let Some(GenericArgumentId::Type(first_ty)) =
                    concrete_trait.generic_args(self.db).first().copied()
                    && !matches!(
                        first_ty.long(self.db),
                        TypeLongId::Closure(_)
                            | TypeLongId::Var(_)
                            | TypeLongId::NumericLiteral(_)
                            | TypeLongId::ImplType(_)
                    )
                {
                    let info = self.db.core_info();
                    let trait_id = concrete_trait.trait_id(self.db);
                    if [info.drop_trt, info.copy_trt, info.destruct_trt, info.panic_destruct_trt]
                        .contains(&trait_id)
                    {
                        // The synthetic NumericLiteral candidate produced this GeneratedImpl;
                        // now that the inner type is concrete, replace it with the natural impl.
                        let lookup_context =
                            ImplLookupContext::new_from_type(first_ty, self.db).intern(self.db);
                        if let Ok(real_impl) =
                            get_impl_at_context(self.db, lookup_context, concrete_trait, None)
                        {
                            *value = real_impl.long(self.db).clone();
                            return Ok(RewriteResult::Modified);
                        }
                    }
                }
                // No replacement was made: report only the result of rewriting the fields.
                return Ok(inner_result);
            }
            ImplLongId::ImplImpl(impl_impl_id) => {
                let impl_impl_id_rewrite_result = self.internal_rewrite(impl_impl_id)?;
                let impl_id = impl_impl_id.impl_id();
                return Ok(match impl_id.long(self.db) {
                    // These impl kinds are not resolved further here; only the rewriting of the
                    // outer impl is reported.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::GeneratedImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_impl_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        if let Ok(imp) = self.db.impl_impl_concrete_implized(*impl_impl_id) {
                            *value = self.rewrite(imp).no_err().long(self.db).clone();
                            RewriteResult::Modified
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        // Resolvable only when the concrete trait-impl can be computed; otherwise
                        // keep the result of rewriting the outer impl.
                        if let Ok(concrete_trait_impl) =
                            impl_impl_id.concrete_trait_impl_id(self.db)
                        {
                            *value = self
                                .rewritten_impl_impl(*var, concrete_trait_impl)
                                .long(self.db)
                                .clone();
                            return Ok(RewriteResult::Modified);
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                });
            }

            _ => {}
        }
        // Generic handling: var-free impls cannot change; everything else is rewritten
        // structurally.
        if value.is_var_free(self.db) {
            return Ok(RewriteResult::NoChange);
        }
        value.default_rewrite(self)
    }
}
1780
1781impl<'db, 'mt> SemanticRewriter<NegativeImplLongId<'db>, NoError> for Inference<'db, 'mt> {
1782    fn internal_rewrite(
1783        &mut self,
1784        value: &mut NegativeImplLongId<'db>,
1785    ) -> Result<RewriteResult, NoError> {
1786        if let NegativeImplLongId::NegativeImplVar(var) = value {
1787            let long_id = var.long(self.db);
1788            // Relax the candidates.
1789            let neg_impl_var_id = long_id.id;
1790            if let Some(impl_id) = self.negative_impl_assignment(neg_impl_var_id) {
1791                let mut long_neg_impl_id = impl_id.long(self.db).clone();
1792                if let RewriteResult::Modified = self.internal_rewrite(&mut long_neg_impl_id)? {
1793                    *self.negative_impl_assignment.get_mut(&neg_impl_var_id).unwrap() =
1794                        long_neg_impl_id.clone().intern(self.db);
1795                }
1796                *value = long_neg_impl_id;
1797                return Ok(RewriteResult::Modified);
1798            }
1799        }
1800
1801        if value.is_var_free(self.db) {
1802            return Ok(RewriteResult::NoChange);
1803        }
1804        value.default_rewrite(self)
1805    }
1806}
1807
1808struct InferenceIdReplacer<'a> {
1809    db: &'a dyn Database,
1810    from_inference_id: InferenceId<'a>,
1811    to_inference_id: InferenceId<'a>,
1812}
1813impl<'a> InferenceIdReplacer<'a> {
1814    fn new(
1815        db: &'a dyn Database,
1816        from_inference_id: InferenceId<'a>,
1817        to_inference_id: InferenceId<'a>,
1818    ) -> Self {
1819        Self { db, from_inference_id, to_inference_id }
1820    }
1821}
/// Gives access to the replacer's database handle (`self.db`).
impl<'a> HasDb<&'a dyn Database> for InferenceIdReplacer<'a> {
    fn get_db(&self) -> &'a dyn Database {
        self.db
    }
}
// Presumably these macros provide the `SemanticRewriter` impls that walk every other semantic
// object. `InferenceId` is excluded because its rewrite, which performs the actual replacement,
// is implemented by hand below.
add_basic_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude InferenceId);
add_expr_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude);
add_rewrite!(<'a>, InferenceIdReplacer<'a>, NoError, Ambiguity<'a>);
1830impl<'a> SemanticRewriter<InferenceId<'a>, NoError> for InferenceIdReplacer<'a> {
1831    fn internal_rewrite(&mut self, value: &mut InferenceId<'a>) -> Result<RewriteResult, NoError> {
1832        if value == &self.from_inference_id {
1833            *value = self.to_inference_id;
1834            Ok(RewriteResult::Modified)
1835        } else {
1836            Ok(RewriteResult::NoChange)
1837        }
1838    }
1839}