cairo_lang_semantic/expr/
inference.rs

1//! Bidirectional type inference.
2
3use std::collections::{BTreeMap, HashMap};
4use std::hash::Hash;
5use std::ops::{Deref, DerefMut};
6use std::sync::Arc;
7
8use cairo_lang_debug::DebugWithDb;
9use cairo_lang_defs::ids::{
10    ConstantId, EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId,
11    GlobalUseId, ImplAliasId, ImplDefId, ImplFunctionId, ImplImplDefId, LocalVarId, LookupItemId,
12    MacroCallId, MemberId, NamedLanguageElementId, ParamId, StructId, TraitConstantId,
13    TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
14};
15use cairo_lang_diagnostics::{DiagnosticAdded, skip_diagnostic};
16use cairo_lang_proc_macros::{DebugWithDb, HeapSize, SemanticObject};
17use cairo_lang_syntax::node::TypedStablePtr;
18use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
19use cairo_lang_utils::deque::Deque;
20use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
21use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
22use cairo_lang_utils::{Intern, define_short_id, extract_matches};
23use salsa::Database;
24
25use self::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, NoError};
26use self::solver::{Ambiguity, SolutionSet, enrich_lookup_context};
27use crate::corelib::CorelibSemantic;
28use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
29use crate::expr::inference::canonic::ResultNoErrEx;
30use crate::expr::inference::conform::InferenceConform;
31use crate::expr::inference::solver::SemanticSolver;
32use crate::expr::objects::*;
33use crate::expr::pattern::*;
34use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
35use crate::items::functions::{
36    ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
37    GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
38    ImplGenericFunctionWithBodyId,
39};
40use crate::items::generics::{
41    GenericParamConst, GenericParamImpl, GenericParamSemantic, GenericParamType,
42};
43use crate::items::imp::{
44    GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
45    ImplLookupContextId, ImplSemantic, NegativeImplId, NegativeImplLongId,
46    UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
47};
48use crate::items::trt::{
49    ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId, ConcreteTraitTypeId,
50    ConcreteTraitTypeLongId,
51};
52use crate::substitution::{GenericSubstitution, HasDb, RewriteResult, SemanticRewriter};
53use crate::types::{
54    ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
55    ImplTypeById, ImplTypeId,
56};
57use crate::{
58    ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
59    ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
60    FunctionId, FunctionLongId, GenericArgumentId, GenericParam, LocalVariable, MatchArmSelector,
61    Member, Parameter, SemanticObject, Signature, TypeId, TypeLongId, ValueSelectorArm,
62    add_basic_rewrites, add_expr_rewrites, add_rewrite, semantic_object_for_id,
63};
64
65pub mod canonic;
66pub mod conform;
67pub mod infers;
68pub mod solver;
69
70/// A type variable, created when a generic type argument is not passed, and thus is not known
71/// yet and needs to be inferred.
72#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, HeapSize, salsa::Update)]
73pub struct TypeVar<'db> {
74    pub inference_id: InferenceId<'db>,
75    pub id: LocalTypeVarId,
76}
77
78/// A const variable, created when a generic const argument is not passed, and thus is not known
79/// yet and needs to be inferred.
80#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, HeapSize, salsa::Update)]
81pub struct ConstVar<'db> {
82    pub inference_id: InferenceId<'db>,
83    pub id: LocalConstVarId,
84}
85
86/// An id for an inference context. Each inference variable is associated with an inference id.
87#[derive(
88    Copy, Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject, HeapSize, salsa::Update,
89)]
90#[debug_db(dyn Database)]
91pub enum InferenceId<'db> {
92    LookupItemDeclaration(LookupItemId<'db>),
93    LookupItemGenerics(LookupItemId<'db>),
94    LookupItemDefinition(LookupItemId<'db>),
95    ImplDefTrait(ImplDefId<'db>),
96    ImplAliasImplDef(ImplAliasId<'db>),
97    GenericParam(GenericParamId<'db>),
98    GenericImplParamTrait(GenericParamId<'db>),
99    GlobalUseStar(GlobalUseId<'db>),
100    MacroCall(MacroCallId<'db>),
101    Canonical,
102    /// For resolving that will not be used anywhere in the semantic model.
103    NoContext,
104}
105
106/// An impl variable, created when a generic type argument is not passed, and thus is not known
107/// yet and needs to be inferred.
108#[derive(
109    Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject, HeapSize, salsa::Update,
110)]
111#[debug_db(dyn Database)]
112pub struct ImplVar<'db> {
113    pub inference_id: InferenceId<'db>,
114    #[dont_rewrite]
115    pub id: LocalImplVarId,
116    pub concrete_trait_id: ConcreteTraitId<'db>,
117    #[dont_rewrite]
118    pub lookup_context: ImplLookupContextId<'db>,
119}
120impl<'db> ImplVar<'db> {
121    pub fn intern(&self, db: &'db dyn Database) -> ImplVarId<'db> {
122        self.clone().intern(db)
123    }
124}
125
126/// A negative impl variable
127#[derive(
128    Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject, HeapSize, salsa::Update,
129)]
130#[debug_db(dyn Database)]
131pub struct NegativeImplVar<'db> {
132    pub inference_id: InferenceId<'db>,
133    #[dont_rewrite]
134    pub id: LocalNegativeImplVarId,
135    pub concrete_trait_id: ConcreteTraitId<'db>,
136    #[dont_rewrite]
137    pub lookup_context: ImplLookupContextId<'db>,
138}
139
140#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
141pub struct LocalTypeVarId(pub usize);
142#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
143pub struct LocalImplVarId(pub usize);
144
145#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
146pub struct LocalNegativeImplVarId(pub usize);
147
148#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
149pub struct LocalConstVarId(pub usize);
150
151define_short_id!(ImplVarId, ImplVar<'db>);
152impl<'db> ImplVarId<'db> {
153    pub fn id(&self, db: &dyn Database) -> LocalImplVarId {
154        self.long(db).id
155    }
156    pub fn concrete_trait_id(&self, db: &'db dyn Database) -> ConcreteTraitId<'db> {
157        self.long(db).concrete_trait_id
158    }
159    pub fn lookup_context(&self, db: &'db dyn Database) -> ImplLookupContextId<'db> {
160        self.long(db).lookup_context
161    }
162}
163semantic_object_for_id!(ImplVarId, ImplVar<'a>);
164
165define_short_id!(NegativeImplVarId, NegativeImplVar<'db>);
166impl<'db> NegativeImplVarId<'db> {
167    pub fn id(&self, db: &dyn Database) -> LocalNegativeImplVarId {
168        self.long(db).id
169    }
170    pub fn concrete_trait_id(&self, db: &'db dyn Database) -> ConcreteTraitId<'db> {
171        self.long(db).concrete_trait_id
172    }
173    pub fn lookup_context(&self, db: &'db dyn Database) -> ImplLookupContextId<'db> {
174        self.long(db).lookup_context
175    }
176}
177semantic_object_for_id!(NegativeImplVarId, NegativeImplVar<'a>);
178
179#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, SemanticObject, salsa::Update)]
180pub enum InferenceVar {
181    Type(LocalTypeVarId),
182    Const(LocalConstVarId),
183    Impl(LocalImplVarId),
184    NegativeImpl(LocalNegativeImplVarId),
185}
186
187// TODO(spapini): Add to diagnostics.
188#[derive(Clone, Debug, Eq, Hash, PartialEq, DebugWithDb, salsa::Update)]
189#[debug_db(dyn Database)]
190pub enum InferenceError<'db> {
191    /// An inference error wrapping a previously reported error.
192    Reported(DiagnosticAdded),
193    Cycle(InferenceVar),
194    TypeKindMismatch {
195        ty0: TypeId<'db>,
196        ty1: TypeId<'db>,
197    },
198    ConstKindMismatch {
199        const0: ConstValueId<'db>,
200        const1: ConstValueId<'db>,
201    },
202    ImplKindMismatch {
203        impl0: ImplId<'db>,
204        impl1: ImplId<'db>,
205    },
206    NegativeImplKindMismatch {
207        impl0: NegativeImplId<'db>,
208        impl1: NegativeImplId<'db>,
209    },
210    GenericArgMismatch {
211        garg0: GenericArgumentId<'db>,
212        garg1: GenericArgumentId<'db>,
213    },
214    TraitMismatch {
215        trt0: TraitId<'db>,
216        trt1: TraitId<'db>,
217    },
218    ImplTypeMismatch {
219        impl_id: ImplId<'db>,
220        trait_type_id: TraitTypeId<'db>,
221        ty0: TypeId<'db>,
222        ty1: TypeId<'db>,
223    },
224    GenericFunctionMismatch {
225        func0: GenericFunctionId<'db>,
226        func1: GenericFunctionId<'db>,
227    },
228    ConstNotInferred,
229    // TODO(spapini): These are only used for external interface. Separate them along with the
230    // finalize() function to a wrapper.
231    NoImplsFound(ConcreteTraitId<'db>),
232    NoNegativeImplsFound(ConcreteTraitId<'db>),
233    Ambiguity(Ambiguity<'db>),
234    TypeNotInferred(TypeId<'db>),
235}
236impl<'db> InferenceError<'db> {
237    pub fn format(&self, db: &dyn Database) -> String {
238        match self {
239            InferenceError::Reported(_) => "Inference error occurred.".into(),
240            InferenceError::Cycle(_var) => "Inference cycle detected".into(),
241            InferenceError::TypeKindMismatch { ty0, ty1 } => {
242                format!("Type mismatch: `{:?}` and `{:?}`.", ty0.debug(db), ty1.debug(db))
243            }
244            InferenceError::ConstKindMismatch { const0, const1 } => {
245                format!("Const mismatch: `{:?}` and `{:?}`.", const0.debug(db), const1.debug(db))
246            }
247            InferenceError::ImplKindMismatch { impl0, impl1 } => {
248                format!("Impl mismatch: `{:?}` and `{:?}`.", impl0.debug(db), impl1.debug(db))
249            }
250            InferenceError::NegativeImplKindMismatch { impl0, impl1 } => {
251                format!(
252                    "Negative impl mismatch: `{:?}` and `{:?}`.",
253                    impl0.debug(db),
254                    impl1.debug(db)
255                )
256            }
257            InferenceError::GenericArgMismatch { garg0, garg1 } => {
258                format!(
259                    "Generic arg mismatch: `{:?}` and `{:?}`.",
260                    garg0.debug(db),
261                    garg1.debug(db)
262                )
263            }
264            InferenceError::TraitMismatch { trt0, trt1 } => {
265                format!("Trait mismatch: `{:?}` and `{:?}`.", trt0.debug(db), trt1.debug(db))
266            }
267            InferenceError::ConstNotInferred => "Failed to infer constant.".into(),
268            InferenceError::NoImplsFound(concrete_trait_id) => {
269                let info = db.core_info();
270                let trait_id = concrete_trait_id.trait_id(db);
271                if trait_id == info.numeric_literal_trt {
272                    let generic_type = extract_matches!(
273                        concrete_trait_id.generic_args(db)[0],
274                        GenericArgumentId::Type
275                    );
276                    return format!(
277                        "Mismatched types. The type `{:?}` cannot be created from a numeric \
278                         literal.",
279                        generic_type.debug(db)
280                    );
281                } else if trait_id == info.string_literal_trt {
282                    let generic_type = extract_matches!(
283                        concrete_trait_id.generic_args(db)[0],
284                        GenericArgumentId::Type
285                    );
286                    return format!(
287                        "Mismatched types. The type `{:?}` cannot be created from a string \
288                         literal.",
289                        generic_type.debug(db)
290                    );
291                }
292                format!(
293                    "Trait has no implementation in context: {:?}.",
294                    concrete_trait_id.debug(db)
295                )
296            }
297            InferenceError::NoNegativeImplsFound(concrete_trait_id) => {
298                format!("Trait has implementation in context: {:?}.", concrete_trait_id.debug(db))
299            }
300            InferenceError::Ambiguity(ambiguity) => ambiguity.format(db),
301            InferenceError::TypeNotInferred(ty) => {
302                format!("Type annotations needed. Failed to infer {:?}.", ty.debug(db))
303            }
304            InferenceError::GenericFunctionMismatch { func0, func1 } => {
305                format!("Function mismatch: `{}` and `{}`.", func0.format(db), func1.format(db))
306            }
307            InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 } => {
308                format!(
309                    "`{}::{}` type mismatch: `{:?}` and `{:?}`.",
310                    impl_id.format(db),
311                    trait_type_id.name(db).long(db),
312                    ty0.debug(db),
313                    ty1.debug(db)
314                )
315            }
316        }
317    }
318}
319
320impl<'db> InferenceError<'db> {
321    pub fn report(
322        &self,
323        diagnostics: &mut SemanticDiagnostics<'db>,
324        stable_ptr: SyntaxStablePtrId<'db>,
325    ) -> DiagnosticAdded {
326        match self {
327            InferenceError::Reported(diagnostic_added) => *diagnostic_added,
328            _ => diagnostics
329                .report(stable_ptr, SemanticDiagnosticKind::InternalInferenceError(self.clone())),
330        }
331    }
332}
333
/// A token proving that an inference error was recorded in the `Inference` object.
///
/// It ensures that when an inference error occurs, it is properly set in the `Inference` object
/// and later properly consumed.
///
/// It must not be constructed directly. Instead, it is returned by [Inference::set_error].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ErrorSet;

/// The result of a fallible inference operation. On failure, the error itself is stored in the
/// inference and only an [ErrorSet] token is returned.
pub type InferenceResult<T> = Result<T, ErrorSet>;
342
343#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
344enum InferenceErrorStatus<'db> {
345    /// There is a pending error.
346    Pending(PendingInferenceError<'db>),
347    /// There was an error but it was already consumed.
348    Consumed(DiagnosticAdded),
349}
350
351/// A pending inference error.
352#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
353struct PendingInferenceError<'db> {
354    /// The actual error.
355    err: InferenceError<'db>,
356    /// The optional location of the error.
357    stable_ptr: Option<SyntaxStablePtrId<'db>>,
358}
359
360/// A mapping of an impl var's trait items to concrete items
361#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, SemanticObject, salsa::Update)]
362pub struct ImplVarTraitItemMappings<'db> {
363    /// The trait types of the impl var.
364    types: OrderedHashMap<TraitTypeId<'db>, TypeId<'db>>,
365    /// The trait constants of the impl var.
366    constants: OrderedHashMap<TraitConstantId<'db>, ConstValueId<'db>>,
367    /// The trait impls of the impl var.
368    impls: OrderedHashMap<TraitImplId<'db>, ImplId<'db>>,
369}
370
371impl ImplVarTraitItemMappings<'_> {
372    /// Returns `true` if the impl var has no mappings.
373    pub fn is_empty(&self) -> bool {
374        self.types.is_empty() && self.constants.is_empty() && self.impls.is_empty()
375    }
376}
377
378/// State of inference.
379#[derive(Debug, DebugWithDb, PartialEq, Eq, salsa::Update)]
380#[debug_db(dyn Database)]
381pub struct InferenceData<'db> {
382    pub inference_id: InferenceId<'db>,
383    /// Current inferred assignment for type variables.
384    pub type_assignment: OrderedHashMap<LocalTypeVarId, TypeId<'db>>,
385    /// Current inferred assignment for const variables.
386    pub const_assignment: OrderedHashMap<LocalConstVarId, ConstValueId<'db>>,
387    /// Current inferred assignment for impl variables.
388    pub impl_assignment: OrderedHashMap<LocalImplVarId, ImplId<'db>>,
389    /// Current inferred assignment for negative impl variables.
390    pub negative_impl_assignment: OrderedHashMap<LocalNegativeImplVarId, NegativeImplId<'db>>,
391    /// Unsolved impl variables mapping to a maps of trait items to a corresponding item variable.
392    /// Upon solution of the trait conforms the fully known item to the variable.
393    pub impl_vars_trait_item_mappings: HashMap<LocalImplVarId, ImplVarTraitItemMappings<'db>>,
394    /// Type variables.
395    pub type_vars: Vec<TypeVar<'db>>,
396    /// Const variables.
397    pub const_vars: Vec<ConstVar<'db>>,
398    /// Impl variables.
399    pub impl_vars: Vec<ImplVar<'db>>,
400    /// Negative impl variables.
401    pub negative_impl_vars: Vec<NegativeImplVar<'db>>,
402    /// Mapping from variables to stable pointers, if exist.
403    pub stable_ptrs: HashMap<InferenceVar, SyntaxStablePtrId<'db>>,
404    /// Inference variables that are pending to be solved.
405    pending: Deque<LocalImplVarId>,
406    /// Inference negative variables that are pending to be solved.
407    negative_pending: Deque<LocalNegativeImplVarId>,
408    /// Inference variables that have been refuted - no solutions exist.
409    refuted: Vec<LocalImplVarId>,
410    /// Inference negative variables that have been refuted - no solutions exist.
411    negative_refuted: Vec<LocalNegativeImplVarId>,
412    /// Inference variables that have been solved.
413    solved: Vec<LocalImplVarId>,
414    /// Inference variables that are currently ambiguous. May be solved later.
415    ambiguous: Vec<(LocalImplVarId, Ambiguity<'db>)>,
416    /// Negative impl inference variables that are currently ambiguous. May be solved later.
417    negative_ambiguous: Vec<(LocalNegativeImplVarId, Ambiguity<'db>)>,
418    /// Mapping from impl types to type variables.
419    pub impl_type_bounds: Arc<BTreeMap<ImplTypeById<'db>, TypeId<'db>>>,
420
421    /// The current error status.
422    error_status: Result<(), InferenceErrorStatus<'db>>,
423}
424impl<'db> InferenceData<'db> {
425    pub fn new(inference_id: InferenceId<'db>) -> Self {
426        Self {
427            inference_id,
428            type_assignment: OrderedHashMap::default(),
429            impl_assignment: OrderedHashMap::default(),
430            const_assignment: OrderedHashMap::default(),
431            negative_impl_assignment: OrderedHashMap::default(),
432            impl_vars_trait_item_mappings: HashMap::new(),
433            type_vars: Vec::new(),
434            impl_vars: Vec::new(),
435            const_vars: Vec::new(),
436            negative_impl_vars: Vec::new(),
437            stable_ptrs: HashMap::new(),
438            pending: Deque::new(),
439            negative_pending: Deque::new(),
440            refuted: Vec::new(),
441            negative_refuted: Vec::new(),
442            solved: Vec::new(),
443            ambiguous: Vec::new(),
444            negative_ambiguous: Vec::new(),
445            impl_type_bounds: Default::default(),
446            error_status: Ok(()),
447        }
448    }
449    pub fn inference<'r>(&'r mut self, db: &'db dyn Database) -> Inference<'db, 'r> {
450        Inference::new(db, self)
451    }
452    pub fn clone_with_inference_id(
453        &self,
454        db: &'db dyn Database,
455        inference_id: InferenceId<'db>,
456    ) -> InferenceData<'db> {
457        let mut inference_id_replacer =
458            InferenceIdReplacer::new(db, self.inference_id, inference_id);
459        Self {
460            inference_id,
461            type_assignment: self
462                .type_assignment
463                .iter()
464                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
465                .collect(),
466            const_assignment: self
467                .const_assignment
468                .iter()
469                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
470                .collect(),
471            impl_assignment: self
472                .impl_assignment
473                .iter()
474                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
475                .collect(),
476            negative_impl_assignment: self
477                .negative_impl_assignment
478                .iter()
479                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
480                .collect(),
481            impl_vars_trait_item_mappings: self
482                .impl_vars_trait_item_mappings
483                .iter()
484                .map(|(k, mappings)| {
485                    (
486                        *k,
487                        ImplVarTraitItemMappings {
488                            types: mappings
489                                .types
490                                .iter()
491                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
492                                .collect(),
493                            constants: mappings
494                                .constants
495                                .iter()
496                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
497                                .collect(),
498                            impls: mappings
499                                .impls
500                                .iter()
501                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
502                                .collect(),
503                        },
504                    )
505                })
506                .collect(),
507            type_vars: inference_id_replacer.rewrite(self.type_vars.clone()).no_err(),
508            const_vars: inference_id_replacer.rewrite(self.const_vars.clone()).no_err(),
509            impl_vars: inference_id_replacer.rewrite(self.impl_vars.clone()).no_err(),
510            negative_impl_vars: inference_id_replacer
511                .rewrite(self.negative_impl_vars.clone())
512                .no_err(),
513            stable_ptrs: self.stable_ptrs.clone(),
514            pending: inference_id_replacer.rewrite(self.pending.clone()).no_err(),
515            negative_pending: inference_id_replacer.rewrite(self.negative_pending.clone()).no_err(),
516            refuted: inference_id_replacer.rewrite(self.refuted.clone()).no_err(),
517            negative_refuted: inference_id_replacer.rewrite(self.negative_refuted.clone()).no_err(),
518            solved: inference_id_replacer.rewrite(self.solved.clone()).no_err(),
519            ambiguous: inference_id_replacer.rewrite(self.ambiguous.clone()).no_err(),
520            negative_ambiguous: inference_id_replacer
521                .rewrite(self.negative_ambiguous.clone())
522                .no_err(),
523            // we do not need to rewrite the impl type bounds, as they all should be var free.
524            impl_type_bounds: self.impl_type_bounds.clone(),
525
526            error_status: self.error_status.clone(),
527        }
528    }
529    pub fn temporary_clone(&self) -> InferenceData<'db> {
530        Self {
531            inference_id: self.inference_id,
532            type_assignment: self.type_assignment.clone(),
533            const_assignment: self.const_assignment.clone(),
534            impl_assignment: self.impl_assignment.clone(),
535            negative_impl_assignment: self.negative_impl_assignment.clone(),
536            impl_vars_trait_item_mappings: self.impl_vars_trait_item_mappings.clone(),
537            type_vars: self.type_vars.clone(),
538            const_vars: self.const_vars.clone(),
539            impl_vars: self.impl_vars.clone(),
540            negative_impl_vars: self.negative_impl_vars.clone(),
541            stable_ptrs: self.stable_ptrs.clone(),
542            pending: self.pending.clone(),
543            negative_pending: self.negative_pending.clone(),
544            refuted: self.refuted.clone(),
545            negative_refuted: self.negative_refuted.clone(),
546            solved: self.solved.clone(),
547            ambiguous: self.ambiguous.clone(),
548            negative_ambiguous: self.negative_ambiguous.clone(),
549            impl_type_bounds: self.impl_type_bounds.clone(),
550            error_status: self.error_status.clone(),
551        }
552    }
553}
554
555/// State of inference. A system of inference constraints.
556pub struct Inference<'db, 'id> {
557    db: &'db dyn Database,
558    pub data: &'id mut InferenceData<'db>,
559}
560
561impl<'db, 'id> Deref for Inference<'db, 'id> {
562    type Target = InferenceData<'db>;
563
564    fn deref(&self) -> &Self::Target {
565        self.data
566    }
567}
568impl DerefMut for Inference<'_, '_> {
569    fn deref_mut(&mut self) -> &mut Self::Target {
570        self.data
571    }
572}
573
574impl std::fmt::Debug for Inference<'_, '_> {
575    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
576        let x = self.data.debug(self.db);
577        write!(f, "{x:?}")
578    }
579}
580
581impl<'db, 'id> Inference<'db, 'id> {
582    fn new(db: &'db dyn Database, data: &'id mut InferenceData<'db>) -> Self {
583        Self { db, data }
584    }
585
586    /// Getter for an [ImplVar].
587    fn impl_var(&self, var_id: LocalImplVarId) -> &ImplVar<'db> {
588        &self.impl_vars[var_id.0]
589    }
590
591    /// Getter for an impl var assignment.
592    pub fn impl_assignment(&self, var_id: LocalImplVarId) -> Option<ImplId<'db>> {
593        self.impl_assignment.get(&var_id).copied()
594    }
595
596    /// Getter for a [NegativeImplVar].
597    fn negative_impl_var(&self, var_id: LocalNegativeImplVarId) -> &NegativeImplVar<'db> {
598        &self.negative_impl_vars[var_id.0]
599    }
600
601    /// Getter for a negative impl var assignment.
602    pub fn negative_impl_assignment(
603        &self,
604        var_id: LocalNegativeImplVarId,
605    ) -> Option<NegativeImplId<'db>> {
606        self.negative_impl_assignment.get(&var_id).copied()
607    }
608
609    /// Getter for a type var assignment.
610    fn type_assignment(&self, var_id: LocalTypeVarId) -> Option<TypeId<'db>> {
611        self.type_assignment.get(&var_id).copied()
612    }
613
614    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
615    /// Returns a wrapping TypeId.
616    pub fn new_type_var(&mut self, stable_ptr: Option<SyntaxStablePtrId<'db>>) -> TypeId<'db> {
617        let var = self.new_type_var_raw(stable_ptr);
618
619        TypeLongId::Var(var).intern(self.db)
620    }
621
622    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
623    /// Returns the variable id.
624    pub fn new_type_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId<'db>>) -> TypeVar<'db> {
625        let var =
626            TypeVar { inference_id: self.inference_id, id: LocalTypeVarId(self.type_vars.len()) };
627        if let Some(stable_ptr) = stable_ptr {
628            self.stable_ptrs.insert(InferenceVar::Type(var.id), stable_ptr);
629        }
630        self.type_vars.push(var);
631        var
632    }
633
634    /// Sets the inference's impl type bounds to the given map, and rewrittes the types so all the
635    /// types are var free.
636    pub fn set_impl_type_bounds(
637        &mut self,
638        impl_type_bounds: OrderedHashMap<ImplTypeId<'db>, TypeId<'db>>,
639    ) {
640        let impl_type_bounds_finalized = impl_type_bounds
641            .iter()
642            .filter_map(|(impl_type, ty)| {
643                let rewritten_type = self.rewrite(ty.long(self.db).clone()).no_err();
644                if !matches!(rewritten_type, TypeLongId::Var(_)) {
645                    return Some(((*impl_type).into(), rewritten_type.intern(self.db)));
646                }
647                // conformed the var type to the original impl type to remove it from the pending
648                // list.
649                self.conform_ty(*ty, TypeLongId::ImplType(*impl_type).intern(self.db)).ok();
650                None
651            })
652            .collect();
653
654        self.data.impl_type_bounds = Arc::new(impl_type_bounds_finalized);
655    }
656
657    /// Allocates a new [ConstVar] for an unknown consts that needs to be inferred.
658    /// Returns a wrapping [ConstValueId].
659    pub fn new_const_var(
660        &mut self,
661        stable_ptr: Option<SyntaxStablePtrId<'db>>,
662        ty: TypeId<'db>,
663    ) -> ConstValueId<'db> {
664        let var = self.new_const_var_raw(stable_ptr);
665        ConstValue::Var(var, ty).intern(self.db)
666    }
667
668    /// Allocates a new [ConstVar] for an unknown type that needs to be inferred.
669    /// Returns the variable id.
670    pub fn new_const_var_raw(
671        &mut self,
672        stable_ptr: Option<SyntaxStablePtrId<'db>>,
673    ) -> ConstVar<'db> {
674        let var = ConstVar {
675            inference_id: self.inference_id,
676            id: LocalConstVarId(self.const_vars.len()),
677        };
678        if let Some(stable_ptr) = stable_ptr {
679            self.stable_ptrs.insert(InferenceVar::Const(var.id), stable_ptr);
680        }
681        self.const_vars.push(var);
682        var
683    }
684
685    /// Allocates a new [ImplVar] for an unknown type that needs to be inferred.
686    /// Returns a wrapping ImplId.
687    pub fn new_impl_var(
688        &mut self,
689        concrete_trait_id: ConcreteTraitId<'db>,
690        stable_ptr: Option<SyntaxStablePtrId<'db>>,
691        lookup_context: ImplLookupContextId<'db>,
692    ) -> ImplId<'db> {
693        let var = self.new_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
694        ImplLongId::ImplVar(self.impl_var(var).intern(self.db)).intern(self.db)
695    }
696
697    /// Allocates a new [ImplVar] for an unknown impl that needs to be inferred.
698    /// Returns the variable id.
699    fn new_impl_var_raw(
700        &mut self,
701        lookup_context: ImplLookupContextId<'db>,
702        concrete_trait_id: ConcreteTraitId<'db>,
703        stable_ptr: Option<SyntaxStablePtrId<'db>>,
704    ) -> LocalImplVarId {
705        let id = LocalImplVarId(self.impl_vars.len());
706        if let Some(stable_ptr) = stable_ptr {
707            self.stable_ptrs.insert(InferenceVar::Impl(id), stable_ptr);
708        }
709        let var =
710            ImplVar { inference_id: self.inference_id, id, concrete_trait_id, lookup_context };
711        self.impl_vars.push(var);
712        self.pending.push_back(id);
713        id
714    }
715
716    /// Allocates a new [NegativeImplVar] for an unknown negative impl that needs to be inferred.
717    /// Returns a wrapping NegativeImplId.
718    pub fn new_negative_impl_var(
719        &mut self,
720        concrete_trait_id: ConcreteTraitId<'db>,
721        stable_ptr: Option<SyntaxStablePtrId<'db>>,
722        lookup_context: ImplLookupContextId<'db>,
723    ) -> NegativeImplId<'db> {
724        let var = self.new_negative_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
725        NegativeImplLongId::NegativeImplVar(self.negative_impl_var(var).clone().intern(self.db))
726            .intern(self.db)
727    }
728
729    /// Allocates a new [NegativeImplVar] for an unknown type that needs to be inferred.
730    /// Returns the variable id.
731    fn new_negative_impl_var_raw(
732        &mut self,
733        lookup_context: ImplLookupContextId<'db>,
734        concrete_trait_id: ConcreteTraitId<'db>,
735        stable_ptr: Option<SyntaxStablePtrId<'db>>,
736    ) -> LocalNegativeImplVarId {
737        let id = LocalNegativeImplVarId(self.negative_impl_vars.len());
738        if let Some(stable_ptr) = stable_ptr {
739            self.stable_ptrs.insert(InferenceVar::NegativeImpl(id), stable_ptr);
740        }
741        let var = NegativeImplVar {
742            inference_id: self.inference_id,
743            id,
744            concrete_trait_id,
745            lookup_context,
746        };
747        self.negative_impl_vars.push(var);
748        self.negative_pending.push_back(id);
749        id
750    }
751
752    /// Solves the inference system. After a successful solve, there are no more pending impl
753    /// inferences.
754    /// Returns whether the inference was successful. If not, the error may be found by
755    /// `.error_state()`.
756    pub fn solve(&mut self) -> InferenceResult<()> {
757        let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
758        self.pending.extend(ambiguous.map(|(var, _)| var));
759        while let Some(var) = self.pending.pop_front() {
760            // First inference error stops inference.
761            self.solve_single_pending(var).inspect_err(|_err_set| {
762                self.add_error_stable_ptr(InferenceVar::Impl(var));
763            })?;
764        }
765        while let Some(var) = self.negative_pending.pop_front() {
766            // First inference error stops inference.
767            self.solve_single_negative_pending(var).inspect_err(|_err_set| {
768                self.add_error_stable_ptr(InferenceVar::NegativeImpl(var));
769            })?;
770        }
771        Ok(())
772    }
773
774    fn solve_single_pending(&mut self, var: LocalImplVarId) -> InferenceResult<()> {
775        if self.impl_assignment.contains_key(&var) {
776            return Ok(());
777        }
778        let solution = match self.impl_var_solution_set(var)? {
779            SolutionSet::None => {
780                self.refuted.push(var);
781                return Ok(());
782            }
783            SolutionSet::Ambiguous(ambiguity) => {
784                self.ambiguous.push((var, ambiguity));
785                return Ok(());
786            }
787            SolutionSet::Unique(solution) => solution,
788        };
789
790        // Solution found. Assign it.
791        self.assign_local_impl(var, solution)?;
792
793        // Something changed.
794        self.solved.push(var);
795        let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
796        self.pending.extend(ambiguous.map(|(var, _)| var));
797
798        let negative_ambiguous = std::mem::take(&mut self.negative_ambiguous).into_iter();
799        self.negative_pending.extend(negative_ambiguous.map(|(var, _)| var));
800
801        Ok(())
802    }
803
804    /// Solves a single negative impl pending variable.
805    fn solve_single_negative_pending(
806        &mut self,
807        var: LocalNegativeImplVarId,
808    ) -> InferenceResult<()> {
809        if self.negative_impl_assignment.contains_key(&var) {
810            return Ok(());
811        }
812
813        let solution = match self.negative_impl_var_solution_set(var)? {
814            SolutionSet::None => {
815                self.negative_refuted.push(var);
816                return Ok(());
817            }
818            SolutionSet::Ambiguous(ambiguity) => {
819                self.negative_ambiguous.push((var, ambiguity));
820                return Ok(());
821            }
822            SolutionSet::Unique(solution) => solution,
823        };
824
825        // Solution found. Assign it.
826        self.assign_local_negative_impl(var, solution)?;
827
828        Ok(())
829    }
830
831    /// Returns the solution set status for the inference:
832    /// Whether there is a unique solution, multiple solutions, no solutions or an error.
833    pub fn solution_set(&mut self) -> InferenceResult<SolutionSet<'db, ()>> {
834        self.solve()?;
835        if !self.refuted.is_empty() {
836            return Ok(SolutionSet::None);
837        }
838        if !self.negative_refuted.is_empty() {
839            return Ok(SolutionSet::None);
840        }
841        if let Some((_, ambiguity)) = self.ambiguous.first() {
842            return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
843        }
844        if let Some((_, ambiguity)) = self.negative_ambiguous.first() {
845            return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
846        }
847        assert!(self.pending.is_empty(), "solution() called on an unsolved solver");
848        assert!(self.negative_pending.is_empty(), "solution() called on an unsolved solver");
849        Ok(SolutionSet::Unique(()))
850    }
851
852    /// Finalizes the inference by inferring uninferred numeric literals as felt252.
853    /// Returns an error and does not report it.
854    pub fn finalize_without_reporting(&mut self) -> Result<(), ErrorSet> {
855        if self.error_status.is_err() {
856            return Err(ErrorSet);
857        }
858        let info = self.db.core_info();
859        let numeric_trait_id = info.numeric_literal_trt;
860        let felt_ty = info.felt252;
861
862        // Conform all uninferred numeric literals to felt252.
863        loop {
864            let mut changed = false;
865            self.solve()?;
866            for (var, _) in self.ambiguous.clone() {
867                let impl_var = self.impl_var(var).clone();
868                if impl_var.concrete_trait_id.trait_id(self.db) != numeric_trait_id {
869                    continue;
870                }
871                // Uninferred numeric trait. Resolve as felt252.
872                let ty = extract_matches!(
873                    impl_var.concrete_trait_id.generic_args(self.db)[0],
874                    GenericArgumentId::Type
875                );
876                if self.rewrite(ty).no_err() == felt_ty {
877                    continue;
878                }
879                self.conform_ty(ty, felt_ty).inspect_err(|_err_set| {
880                    self.add_error_stable_ptr(InferenceVar::Impl(impl_var.id));
881                })?;
882                changed = true;
883                break;
884            }
885            if !changed {
886                break;
887            }
888        }
889        assert!(
890            self.pending.is_empty(),
891            "pending should all be solved by this point. Guaranteed by solve()."
892        );
893
894        let Some((var, err)) = self.first_undetermined_variable() else {
895            return Ok(());
896        };
897        Err(self.set_error_on_var(err, var))
898    }
899
900    /// Finalizes the inference and report diagnostics if there are any errors.
901    /// All the remaining type vars are mapped to the `missing` type, to prevent additional
902    /// diagnostics.
903    pub fn finalize<'m>(
904        &'m mut self,
905        diagnostics: &mut SemanticDiagnostics<'db>,
906        stable_ptr: SyntaxStablePtrId<'db>,
907    ) {
908        if let Err(err_set) = self.finalize_without_reporting() {
909            let diag = self.report_on_pending_error(err_set, diagnostics, stable_ptr);
910
911            let ty_missing = TypeId::missing(self.db, diag);
912            for var in &self.data.type_vars {
913                self.data.type_assignment.entry(var.id).or_insert(ty_missing);
914            }
915        }
916    }
917
918    /// Retrieves the first variable that is still not inferred, or None, if everything is
919    /// inferred.
920    /// Does not set the error but return it, which is ok as this is a private helper function.
921    fn first_undetermined_variable(&mut self) -> Option<(InferenceVar, InferenceError<'db>)> {
922        if let Some(var) = self.refuted.first().copied() {
923            let impl_var = self.impl_var(var).clone();
924            let concrete_trait_id = impl_var.concrete_trait_id;
925            let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
926            return Some((
927                InferenceVar::Impl(var),
928                InferenceError::NoImplsFound(concrete_trait_id),
929            ));
930        }
931        if let Some(var) = self.negative_refuted.first().copied() {
932            let negative_impl_var = self.negative_impl_var(var).clone();
933            let concrete_trait_id = negative_impl_var.concrete_trait_id;
934            let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
935            return Some((
936                InferenceVar::NegativeImpl(var),
937                InferenceError::NoNegativeImplsFound(concrete_trait_id),
938            ));
939        }
940
941        let mut fallback_ret = None;
942        if let Some((var, ambiguity)) = self.ambiguous.first() {
943            // Note: do not rewrite `ambiguity`, since it is expressed in canonical variables.
944            let ret =
945                Some((InferenceVar::Impl(*var), InferenceError::Ambiguity(ambiguity.clone())));
946            if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
947                return ret;
948            } else {
949                fallback_ret = ret;
950            }
951        }
952        if let Some((var, ambiguity)) = self.negative_ambiguous.first() {
953            let ret = Some((
954                InferenceVar::NegativeImpl(*var),
955                InferenceError::Ambiguity(ambiguity.clone()),
956            ));
957            if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
958                return ret;
959            } else {
960                fallback_ret = ret;
961            }
962        }
963        for (id, var) in self.type_vars.iter().enumerate() {
964            if self.type_assignment(LocalTypeVarId(id)).is_none() {
965                let ty = TypeLongId::Var(*var).intern(self.db);
966                return Some((InferenceVar::Type(var.id), InferenceError::TypeNotInferred(ty)));
967            }
968        }
969        for (id, var) in self.const_vars.iter().enumerate() {
970            if !self.const_assignment.contains_key(&LocalConstVarId(id)) {
971                let infernence_var = InferenceVar::Const(var.id);
972                return Some((infernence_var, InferenceError::ConstNotInferred));
973            }
974        }
975        fallback_ret
976    }
977
978    /// Assigns a value to a local impl variable id. See assign_impl().
979    fn assign_local_impl(
980        &mut self,
981        var: LocalImplVarId,
982        impl_id: ImplId<'db>,
983    ) -> InferenceResult<ImplId<'db>> {
984        let concrete_trait = impl_id
985            .concrete_trait(self.db)
986            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
987        self.conform_traits(self.impl_var(var).concrete_trait_id, concrete_trait)?;
988        if let Some(other_impl) = self.impl_assignment(var) {
989            return self.conform_impl(impl_id, other_impl);
990        }
991        if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
992        {
993            let inference_var = InferenceVar::Impl(var);
994            return Err(self.set_error_on_var(InferenceError::Cycle(inference_var), inference_var));
995        }
996        self.impl_assignment.insert(var, impl_id);
997        if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
998            for (trait_type_id, ty) in mappings.types {
999                let impl_ty = self
1000                    .db
1001                    .impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_type_id, self.db))
1002                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1003                if let Err(err_set) = self.conform_ty(ty, impl_ty) {
1004                    // Override the error with ImplTypeMismatch.
1005                    let ty0 = self.rewrite(ty).no_err();
1006                    let ty1 = self.rewrite(impl_ty).no_err();
1007
1008                    let err = InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 };
1009                    self.error_status = Err(InferenceErrorStatus::Pending(PendingInferenceError {
1010                        err,
1011                        stable_ptr: self.stable_ptrs.get(&InferenceVar::Impl(var)).cloned(),
1012                    }));
1013                    return Err(err_set);
1014                }
1015            }
1016            for (trait_constant, constant_id) in mappings.constants {
1017                let concrete_impl_constant = self
1018                    .db
1019                    .impl_constant_concrete_implized_value(ImplConstantId::new(
1020                        impl_id,
1021                        trait_constant,
1022                        self.db,
1023                    ))
1024                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1025                self.conform_const(constant_id, concrete_impl_constant)?;
1026            }
1027            for (trait_impl, inner_impl_id) in mappings.impls {
1028                let concrete_impl_impl = self
1029                    .db
1030                    .impl_impl_concrete_implized(ImplImplId::new(impl_id, trait_impl, self.db))
1031                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1032                self.conform_impl(inner_impl_id, concrete_impl_impl)?;
1033            }
1034        }
1035        Ok(impl_id)
1036    }
1037
1038    /// Tries to assigns value to an [ImplVarId]. Return the assigned impl, or an error.
1039    fn assign_impl(
1040        &mut self,
1041        var_id: ImplVarId<'db>,
1042        impl_id: ImplId<'db>,
1043    ) -> InferenceResult<ImplId<'db>> {
1044        let var = var_id.long(self.db);
1045        if var.inference_id != self.inference_id {
1046            return Err(self.set_error(InferenceError::ImplKindMismatch {
1047                impl0: ImplLongId::ImplVar(var_id).intern(self.db),
1048                impl1: impl_id,
1049            }));
1050        }
1051        self.assign_local_impl(var.id, impl_id)
1052    }
1053
1054    /// Assigns a value to a local negative impl variable id. See assign_neg_impl().
1055    fn assign_local_negative_impl(
1056        &mut self,
1057        var: LocalNegativeImplVarId,
1058        neg_impl_id: NegativeImplId<'db>,
1059    ) -> InferenceResult<NegativeImplId<'db>> {
1060        let concrete_trait = neg_impl_id
1061            .concrete_trait(self.db)
1062            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1063        self.conform_traits(self.negative_impl_var(var).concrete_trait_id, concrete_trait)?;
1064        if let Some(other_impl) = self.negative_impl_assignment(var) {
1065            return self.conform_neg_impl(neg_impl_id, other_impl);
1066        }
1067        if !neg_impl_id.is_var_free(self.db)
1068            && self.negative_impl_contains_var(neg_impl_id, InferenceVar::NegativeImpl(var))
1069        {
1070            return Err(self.set_error(InferenceError::Cycle(InferenceVar::NegativeImpl(var))));
1071        }
1072        self.negative_impl_assignment.insert(var, neg_impl_id);
1073        Ok(neg_impl_id)
1074    }
1075
1076    /// Tries to assigns value to an [NegativeImplVarId]. Return the assigned negative impl, or an
1077    /// error.
1078    fn assign_neg_impl(
1079        &mut self,
1080        var_id: NegativeImplVarId<'db>,
1081        neg_impl_id: NegativeImplId<'db>,
1082    ) -> InferenceResult<NegativeImplId<'db>> {
1083        let var = var_id.long(self.db);
1084        if var.inference_id != self.inference_id {
1085            return Err(self.set_error(InferenceError::NegativeImplKindMismatch {
1086                impl0: NegativeImplLongId::NegativeImplVar(var_id).intern(self.db),
1087                impl1: neg_impl_id,
1088            }));
1089        }
1090        self.assign_local_negative_impl(var.id, neg_impl_id)
1091    }
1092
1093    /// Assigns a value to a [TypeVar]. Return the assigned type, or an error.
1094    /// Assumes the variable is not already assigned.
1095    fn assign_ty(&mut self, var: TypeVar<'db>, ty: TypeId<'db>) -> InferenceResult<TypeId<'db>> {
1096        if var.inference_id != self.inference_id {
1097            return Err(self.set_error(InferenceError::TypeKindMismatch {
1098                ty0: TypeLongId::Var(var).intern(self.db),
1099                ty1: ty,
1100            }));
1101        }
1102        assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
1103        let inference_var = InferenceVar::Type(var.id);
1104        if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
1105            return Err(self.set_error_on_var(InferenceError::Cycle(inference_var), inference_var));
1106        }
1107        // If assigning var to var - making sure assigning to the lower id for proper canonization.
1108        if let TypeLongId::Var(other) = ty.long(self.db)
1109            && other.inference_id == self.inference_id
1110            && other.id.0 > var.id.0
1111        {
1112            let var_ty = TypeLongId::Var(var).intern(self.db);
1113            self.type_assignment.insert(other.id, var_ty);
1114            return Ok(var_ty);
1115        }
1116        self.type_assignment.insert(var.id, ty);
1117        Ok(ty)
1118    }
1119
1120    /// Assigns a value to a [ConstVar]. Return the assigned const, or an error.
1121    /// Assumes the variable is not already assigned.
1122    fn assign_const(
1123        &mut self,
1124        var: ConstVar<'db>,
1125        id: ConstValueId<'db>,
1126    ) -> InferenceResult<ConstValueId<'db>> {
1127        if var.inference_id != self.inference_id {
1128            return Err(self.set_error(InferenceError::ConstKindMismatch {
1129                const0: ConstValue::Var(var, TypeId::missing(self.db, skip_diagnostic()))
1130                    .intern(self.db),
1131                const1: id,
1132            }));
1133        }
1134
1135        self.const_assignment.insert(var.id, id);
1136        Ok(id)
1137    }
1138
1139    /// Computes the solution set for an impl variable with a recursive query.
1140    fn impl_var_solution_set(
1141        &mut self,
1142        var: LocalImplVarId,
1143    ) -> InferenceResult<SolutionSet<'db, ImplId<'db>>> {
1144        let impl_var = self.impl_var(var).clone();
1145        // Update the concrete trait of the impl var.
1146        let concrete_trait_id = self.rewrite(impl_var.concrete_trait_id).no_err();
1147        self.impl_vars[impl_var.id.0].concrete_trait_id = concrete_trait_id;
1148        let impl_var_trait_item_mappings =
1149            self.impl_vars_trait_item_mappings.get(&var).cloned().unwrap_or_default();
1150        let solution_set = self.trait_solution_set(
1151            concrete_trait_id,
1152            impl_var_trait_item_mappings,
1153            impl_var.lookup_context,
1154        )?;
1155        Ok(match solution_set {
1156            SolutionSet::None => SolutionSet::None,
1157            SolutionSet::Unique((canonical_impl, canonicalizer)) => {
1158                SolutionSet::Unique(canonical_impl.embed(self, &canonicalizer))
1159            }
1160            SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
1161        })
1162    }
1163
1164    /// Computes the solution set for a negative impl variable with a recursive query.
1165    fn negative_impl_var_solution_set(
1166        &mut self,
1167        var: LocalNegativeImplVarId,
1168    ) -> InferenceResult<SolutionSet<'db, NegativeImplId<'db>>> {
1169        let negative_impl_var = self.negative_impl_var(var).clone();
1170        let concrete_trait_id = self.rewrite(negative_impl_var.concrete_trait_id).no_err();
1171
1172        let solution_set =
1173            self.validate_no_solution_set(concrete_trait_id, negative_impl_var.lookup_context)?;
1174        Ok(match solution_set {
1175            SolutionSet::Unique(concrete_trait_id) => {
1176                SolutionSet::Unique(NegativeImplLongId::Solved(concrete_trait_id).intern(self.db))
1177            }
1178            SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
1179            SolutionSet::None => SolutionSet::None,
1180        })
1181    }
1182
1183    /// Validates that no solution set is found for the negative impls.
1184    fn validate_no_solution_set(
1185        &mut self,
1186        concrete_trait_id: ConcreteTraitId<'db>,
1187        lookup_context: ImplLookupContextId<'db>,
1188    ) -> InferenceResult<SolutionSet<'db, ConcreteTraitId<'db>>> {
1189        for negative_impl in &lookup_context.long(self.db).negative_impls {
1190            let generic_param = self
1191                .db
1192                .generic_param_semantic(*negative_impl)
1193                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1194            if let GenericParam::NegImpl(neg_impl) = generic_param
1195                && Ok(concrete_trait_id) == neg_impl.concrete_trait
1196            {
1197                return Ok(SolutionSet::Unique(concrete_trait_id));
1198            }
1199        }
1200
1201        let mut neg_impl_generic_params = OrderedHashSet::default();
1202        for garg in concrete_trait_id.generic_args(self.db) {
1203            if garg.extract_generic_params(self.db, &mut neg_impl_generic_params).is_err() {
1204                return Ok(SolutionSet::Ambiguous(
1205                    Ambiguity::NegativeImplWithUnsupportedExtractedArgs(*garg),
1206                ));
1207            }
1208        }
1209
1210        let solution_set = if neg_impl_generic_params.is_empty() {
1211            self.trait_solution_set(
1212                concrete_trait_id,
1213                ImplVarTraitItemMappings::default(),
1214                lookup_context,
1215            )?
1216        } else {
1217            let substitution: OrderedHashMap<GenericParamId<'db>, GenericArgumentId<'db>> =
1218                neg_impl_generic_params
1219                    .into_iter()
1220                    .map(|param| {
1221                        (
1222                            param,
1223                            GenericArgumentId::Type(
1224                                self.new_type_var(Some(param.stable_ptr(self.db).untyped())),
1225                            ),
1226                        )
1227                    })
1228                    .collect();
1229            let rewritten_concrete_trait_id =
1230                GenericSubstitution { param_to_arg: substitution.clone(), self_impl: None }
1231                    .substitute(self.db, concrete_trait_id)
1232                    .unwrap();
1233
1234            let solution_set = self.trait_solution_set(
1235                rewritten_concrete_trait_id,
1236                ImplVarTraitItemMappings::default(),
1237                lookup_context,
1238            )?;
1239
1240            // Assign the newly created type variables with their corresponding generic parameters.
1241            // This step prevents leaving type variables unresolved.
1242            let db = self.db;
1243            for (generic_param, garg) in substitution {
1244                let GenericArgumentId::Type(ty) = garg else {
1245                    panic!("Expected a type variable");
1246                };
1247                let TypeLongId::Var(var) = ty.long(self.db) else {
1248                    panic!("Expected a type variable");
1249                };
1250                self.type_assignment
1251                    .entry(var.id)
1252                    .or_insert_with(|| TypeLongId::GenericParameter(generic_param).intern(db));
1253            }
1254
1255            solution_set
1256        };
1257
1258        if !matches!(solution_set, SolutionSet::None) {
1259            return Ok(SolutionSet::None);
1260        }
1261
1262        Ok(SolutionSet::Unique(concrete_trait_id))
1263    }
1264
1265    /// Computes the solution set for a trait with a recursive query.
1266    pub fn trait_solution_set(
1267        &mut self,
1268        concrete_trait_id: ConcreteTraitId<'db>,
1269        impl_var_trait_item_mappings: ImplVarTraitItemMappings<'db>,
1270        lookup_context_id: ImplLookupContextId<'db>,
1271    ) -> InferenceResult<SolutionSet<'db, (CanonicalImpl<'db>, CanonicalMapping<'db>)>> {
1272        let impl_var_trait_item_mappings = self.rewrite(impl_var_trait_item_mappings).no_err();
1273        // TODO(spapini): This is done twice. Consider doing it only here.
1274        let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
1275        let mut lookup_context = lookup_context_id.long(self.db).clone();
1276        enrich_lookup_context(self.db, concrete_trait_id, &mut lookup_context);
1277
1278        // Don't try to resolve impls if the first generic param is a variable.
1279        let generic_args = concrete_trait_id.generic_args(self.db);
1280        match generic_args.first() {
1281            Some(GenericArgumentId::Type(ty)) => {
1282                if let TypeLongId::Var(_) = ty.long(self.db) {
1283                    // Don't try to infer such impls.
1284                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1285                }
1286            }
1287            Some(GenericArgumentId::Impl(imp)) => {
1288                // Don't try to infer such impls.
1289                if let ImplLongId::ImplVar(_) = imp.long(self.db) {
1290                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1291                }
1292            }
1293            Some(GenericArgumentId::Constant(const_value)) => {
1294                if let ConstValue::Var(_, _) = const_value.long(self.db) {
1295                    // Don't try to infer such impls.
1296                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1297                }
1298            }
1299            _ => {}
1300        };
1301        let (canonical_trait, canonicalizer) = CanonicalTrait::canonicalize(
1302            self.db,
1303            self.inference_id,
1304            concrete_trait_id,
1305            impl_var_trait_item_mappings,
1306        );
1307        // impl_type_bounds order is deterimend by the generic params of the function and therefore
1308        // is consistent.
1309        let solution_set = match self.db.canonic_trait_solutions(
1310            canonical_trait,
1311            lookup_context.intern(self.db),
1312            (*self.data.impl_type_bounds).clone(),
1313        ) {
1314            Ok(solution_set) => solution_set,
1315            Err(err) => return Err(self.set_error(err)),
1316        };
1317        match solution_set {
1318            SolutionSet::None => Ok(SolutionSet::None),
1319            SolutionSet::Unique(canonical_impl) => {
1320                Ok(SolutionSet::Unique((canonical_impl, canonicalizer)))
1321            }
1322            SolutionSet::Ambiguous(ambiguity) => Ok(SolutionSet::Ambiguous(ambiguity)),
1323        }
1324    }
1325
1326    // Error handling methods
1327    // ======================
1328
1329    /// Sets an error in the inference state.
1330    /// Does nothing if an error is already set.
1331    /// Returns an `ErrorSet` that can be used in reporting the error.
1332    pub fn set_error(&mut self, err: InferenceError<'db>) -> ErrorSet {
1333        self.set_error_ex(err, None)
1334    }
1335
1336    /// Sets an error in the inference state, with an optional location for the diagnostics
1337    /// reporting. Does nothing if an error is already set.
1338    /// Returns an `ErrorSet` that can be used in reporting the error.
1339    pub fn set_error_ex(
1340        &mut self,
1341        err: InferenceError<'db>,
1342        stable_ptr: Option<SyntaxStablePtrId<'db>>,
1343    ) -> ErrorSet {
1344        if self.error_status.is_err() {
1345            return ErrorSet;
1346        }
1347        self.error_status = Err(if let InferenceError::Reported(diag_added) = err {
1348            InferenceErrorStatus::Consumed(diag_added)
1349        } else {
1350            InferenceErrorStatus::Pending(PendingInferenceError { err, stable_ptr })
1351        });
1352        ErrorSet
1353    }
1354
1355    /// Sets an error in the inference state, with a var to fetch location for the diagnostics
1356    /// reporting. Does nothing if an error is already set.
1357    /// Returns an `ErrorSet` that can be used in reporting the error.
1358    pub fn set_error_on_var(&mut self, err: InferenceError<'db>, var: InferenceVar) -> ErrorSet {
1359        self.set_error_ex(err, self.stable_ptrs.get(&var).cloned())
1360    }
1361
1362    /// Returns whether an error is set (either pending or consumed).
1363    pub fn is_error_set(&self) -> InferenceResult<()> {
1364        self.error_status.as_ref().copied().map_err(|_| ErrorSet)
1365    }
1366
1367    /// If there is no stable ptr for the pending error, add it by the given var.
1368    fn add_error_stable_ptr(&mut self, var: InferenceVar) {
1369        let var_stable_ptr = self.stable_ptrs.get(&var).copied();
1370        if let Err(InferenceErrorStatus::Pending(PendingInferenceError { err: _, stable_ptr })) =
1371            &mut self.error_status
1372            && stable_ptr.is_none()
1373        {
1374            *stable_ptr = var_stable_ptr;
1375        }
1376    }
1377
1378    /// Consumes the error but doesn't report it. If there is no error, or the error is consumed,
1379    /// returns None. This should be used with caution. Always prefer to use
1380    /// (1) `report_on_pending_error` if possible, or (2) `consume_reported_error` which is safer.
1381    ///
1382    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1383    pub fn consume_error_without_reporting(
1384        &mut self,
1385        err_set: ErrorSet,
1386    ) -> Option<InferenceError<'db>> {
1387        Some(self.consume_error_inner(err_set, skip_diagnostic())?.err)
1388    }
1389
1390    /// Consumes the error that is already reported. If there is no error, or the error is consumed,
1391    /// does nothing. This should be used with caution. Always prefer to use
1392    /// `report_on_pending_error` if possible.
1393    ///
1394    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1395    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1396    pub fn consume_reported_error(&mut self, err_set: ErrorSet, diag_added: DiagnosticAdded) {
1397        self.consume_error_inner(err_set, diag_added);
1398    }
1399
1400    /// Consumes the error and returns it, but doesn't report it. If there is no error, or the error
1401    /// is already consumed, returns None. This should be used with caution. Always prefer to use
1402    /// `report_on_pending_error` if possible.
1403    ///
1404    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1405    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1406    fn consume_error_inner(
1407        &mut self,
1408        _err_set: ErrorSet,
1409        diag_added: DiagnosticAdded,
1410    ) -> Option<PendingInferenceError<'db>> {
1411        match &mut self.error_status {
1412            Err(InferenceErrorStatus::Pending(error)) => {
1413                let pending_error = std::mem::replace(
1414                    error,
1415                    PendingInferenceError {
1416                        err: InferenceError::Reported(diag_added),
1417                        stable_ptr: None,
1418                    },
1419                );
1420                self.error_status = Err(InferenceErrorStatus::Consumed(diag_added));
1421                Some(pending_error)
1422            }
1423            // TODO(orizi): `panic!("consume_error when there is no pending error")` instead.
1424            _ => None,
1425        }
1426    }
1427
1428    /// Consumes the pending error, if any, and reports it.
1429    /// Should only be called when an error is set, otherwise it panics.
1430    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1431    /// If an error was set but it's already consumed, it doesn't report it again but returns the
1432    /// stored `DiagnosticAdded`.
1433    pub fn report_on_pending_error(
1434        &mut self,
1435        _err_set: ErrorSet,
1436        diagnostics: &mut SemanticDiagnostics<'db>,
1437        stable_ptr: SyntaxStablePtrId<'db>,
1438    ) -> DiagnosticAdded {
1439        let Err(state_error) = &self.error_status else {
1440            panic!("report_on_pending_error should be called only on error");
1441        };
1442        match state_error {
1443            InferenceErrorStatus::Consumed(diag_added) => *diag_added,
1444            InferenceErrorStatus::Pending(pending) => {
1445                let diag_added = match &pending.err {
1446                    InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
1447                        // If we have other diagnostics, there is no need to TypeNotInferred.
1448
1449                        // Note that `diagnostics` is not empty, so it is safe to return
1450                        // 'DiagnosticAdded' here.
1451                        skip_diagnostic()
1452                    }
1453                    diag => diag.report(diagnostics, pending.stable_ptr.unwrap_or(stable_ptr)),
1454                };
1455                self.error_status = Err(InferenceErrorStatus::Consumed(diag_added));
1456                diag_added
1457            }
1458        }
1459    }
1460
1461    /// If the current status is of a pending error, reports an alternative diagnostic, by calling
1462    /// `report`, and consumes the error. Otherwise, does nothing.
1463    pub fn report_modified_if_pending(
1464        &mut self,
1465        err_set: ErrorSet,
1466        report: impl FnOnce() -> DiagnosticAdded,
1467    ) {
1468        if matches!(self.error_status, Err(InferenceErrorStatus::Pending { .. })) {
1469            self.consume_reported_error(err_set, report());
1470        }
1471    }
1472}
1473
/// Exposes the database this inference operates on (`self.db`) through the `HasDb` trait.
impl<'a, 'mt> HasDb<&'a dyn Database> for Inference<'a, 'mt> {
    fn get_db(&self) -> &'a dyn Database {
        self.db
    }
}
// Macro-generated rewrite support for `Inference`. Every type listed after `@exclude` instead
// gets one of the hand-written `SemanticRewriter` impls that follow these invocations.
add_basic_rewrites!(<'a, 'mt>, Inference<'a, 'mt>, NoError, @exclude TypeLongId TypeId ImplLongId ImplId ConstValue NegativeImplLongId NegativeImplId);
add_expr_rewrites!(<'a, 'mt>, Inference<'a, 'mt>, NoError, @exclude);
add_rewrite!(<'a, 'mt>, Inference<'a, 'mt>, NoError, Ambiguity<'a>);
1482impl<'db, 'mt> SemanticRewriter<TypeId<'db>, NoError> for Inference<'db, 'mt> {
1483    fn internal_rewrite(&mut self, value: &mut TypeId<'db>) -> Result<RewriteResult, NoError> {
1484        if value.is_var_free(self.db) {
1485            return Ok(RewriteResult::NoChange);
1486        }
1487        value.default_rewrite(self)
1488    }
1489}
1490impl<'db, 'mt> SemanticRewriter<ImplId<'db>, NoError> for Inference<'db, 'mt> {
1491    fn internal_rewrite(&mut self, value: &mut ImplId<'db>) -> Result<RewriteResult, NoError> {
1492        if value.is_var_free(self.db) {
1493            return Ok(RewriteResult::NoChange);
1494        }
1495        value.default_rewrite(self)
1496    }
1497}
1498impl<'db, 'mt> SemanticRewriter<NegativeImplId<'db>, NoError> for Inference<'db, 'mt> {
1499    fn internal_rewrite(
1500        &mut self,
1501        value: &mut NegativeImplId<'db>,
1502    ) -> Result<RewriteResult, NoError> {
1503        if value.is_var_free(self.db) {
1504            return Ok(RewriteResult::NoChange);
1505        }
1506        value.default_rewrite(self)
1507    }
1508}
1509
/// Rewrites a `TypeLongId` through this inference, substituting what is known about it.
///
/// - `Var`: if `type_assignment` holds an assignment for the variable, the value is replaced by
///   the assigned type, itself rewritten recursively. When that recursive rewrite changed the
///   assignment, the stored assignment is updated as well. Unassigned variables fall through to
///   the default rewrite.
/// - `ImplType`: if `impl_type_bounds` records a type for it, that type is substituted and
///   rewritten. Otherwise the impl type id is rewritten in place first, then resolved by the kind
///   of its (rewritten) impl:
///   - generic parameter, `Self` impl or impl-impl: kept, returning the result of rewriting the
///     impl type id;
///   - concrete impl: replaced by the rewritten result of `impl_type_concrete_implized` when
///     that query succeeds, otherwise kept;
///   - impl variable: replaced by `rewritten_impl_type` for the trait type;
///   - generated impl: replaced by the rewritten type stored in the generated impl's
///     `impl_items` (the entry is `unwrap`ped, so it must exist).
/// - Anything else: the default rewrite.
impl<'db, 'mt> SemanticRewriter<TypeLongId<'db>, NoError> for Inference<'db, 'mt> {
    fn internal_rewrite(&mut self, value: &mut TypeLongId<'db>) -> Result<RewriteResult, NoError> {
        match value {
            TypeLongId::Var(var) => {
                if let Some(type_id) = self.type_assignment.get(&var.id) {
                    let mut long_type_id = type_id.long(self.db).clone();
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_type_id)? {
                        // Store the further-resolved type back into the assignment.
                        *self.type_assignment.get_mut(&var.id).unwrap() =
                            long_type_id.clone().intern(self.db);
                    }
                    *value = long_type_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            TypeLongId::ImplType(impl_type_id) => {
                if let Some(type_id) = self.impl_type_bounds.get(&((*impl_type_id).into())) {
                    // A recorded bound takes precedence. The bound type is rewritten again, and
                    // `Modified` is reported regardless of that inner rewrite's result.
                    *value = type_id.long(self.db).clone();
                    self.internal_rewrite(value)?;
                    return Ok(RewriteResult::Modified);
                }
                // Rewritten in place, so `impl_id` below refers to the rewritten impl.
                let impl_type_id_rewrite_result = self.internal_rewrite(impl_type_id)?;
                let impl_id = impl_type_id.impl_id();
                let trait_ty = impl_type_id.ty();
                return Ok(match impl_id.long(self.db) {
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_type_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        if let Ok(ty) = self.db.impl_type_concrete_implized(ImplTypeId::new(
                            impl_id, trait_ty, self.db,
                        )) {
                            *value = self.rewrite(ty).no_err().long(self.db).clone();
                            RewriteResult::Modified
                        } else {
                            impl_type_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        *value = self.rewritten_impl_type(*var, trait_ty).long(self.db).clone();
                        return Ok(RewriteResult::Modified);
                    }
                    ImplLongId::GeneratedImpl(generated) => {
                        // The generated impl is expected to contain an entry for this trait type.
                        *value = self
                            .rewrite(
                                *generated
                                    .long(self.db)
                                    .impl_items
                                    .0
                                    .get(&impl_type_id.ty())
                                    .unwrap(),
                            )
                            .no_err()
                            .long(self.db)
                            .clone();
                        RewriteResult::Modified
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
/// Rewrites a `ConstValue` through this inference, substituting what is known about it.
///
/// - `Var`: if `const_assignment` holds an assignment for the variable, the value is replaced by
///   the assigned constant, itself rewritten recursively. When that recursive rewrite changed the
///   assignment, the stored assignment is updated as well. An unassigned variable returns
///   `NoChange` immediately.
/// - `ImplConstant`: the impl constant id is rewritten in place first, then resolved by the kind
///   of its (rewritten) impl:
///   - generic parameter, `Self` impl, generated impl or impl-impl: kept, returning the result
///     of rewriting the impl constant id;
///   - concrete impl: replaced by the rewritten result of
///     `impl_constant_concrete_implized_value` when that query succeeds, otherwise kept;
///   - impl variable: replaced by `rewritten_impl_constant` for the trait constant.
/// - Anything else: the default rewrite.
impl<'db, 'mt> SemanticRewriter<ConstValue<'db>, NoError> for Inference<'db, 'mt> {
    fn internal_rewrite(&mut self, value: &mut ConstValue<'db>) -> Result<RewriteResult, NoError> {
        match value {
            // NOTE(review): unlike the `TypeLongId` rewrite, an unassigned variable returns
            // `NoChange` without a default rewrite of the variable's second field; presumably
            // intentional — confirm.
            ConstValue::Var(var, _) => {
                return Ok(if let Some(const_value_id) = self.const_assignment.get(&var.id) {
                    let mut const_value = const_value_id.long(self.db).clone();
                    if let RewriteResult::Modified = self.internal_rewrite(&mut const_value)? {
                        // Store the further-resolved constant back into the assignment.
                        *self.const_assignment.get_mut(&var.id).unwrap() =
                            const_value.clone().intern(self.db);
                    }
                    *value = const_value;
                    RewriteResult::Modified
                } else {
                    RewriteResult::NoChange
                });
            }
            ConstValue::ImplConstant(impl_constant_id) => {
                // Rewritten in place, so `impl_id` below refers to the rewritten impl.
                let impl_constant_id_rewrite_result = self.internal_rewrite(impl_constant_id)?;
                let impl_id = impl_constant_id.impl_id();
                let trait_constant = impl_constant_id.trait_constant_id();
                return Ok(match impl_id.long(self.db) {
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::GeneratedImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_constant_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
                            ImplConstantId::new(impl_id, trait_constant, self.db),
                        ) {
                            *value = self.rewrite(constant).no_err().long(self.db).clone();
                            RewriteResult::Modified
                        } else {
                            impl_constant_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        *value = self
                            .rewritten_impl_constant(*var, trait_constant)
                            .long(self.db)
                            .clone();
                        return Ok(RewriteResult::Modified);
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
/// Rewrites an `ImplLongId` through this inference, substituting what is known about it.
///
/// - `ImplVar`: if `impl_assignment` yields an impl for the variable, the value is replaced by
///   that impl, itself rewritten recursively. When that recursive rewrite changed it, the stored
///   assignment is updated as well. Unassigned variables fall through to the checks at the end.
/// - `ImplImpl`: the impl-impl id is rewritten in place first, then resolved by the kind of its
///   (rewritten) parent impl:
///   - generic parameter, `Self` impl, generated impl or impl-impl: kept, returning the result
///     of rewriting the impl-impl id;
///   - concrete impl: replaced by the rewritten result of `impl_impl_concrete_implized` when
///     that query succeeds, otherwise kept;
///   - impl variable: replaced by `rewritten_impl_impl` when the concrete trait impl id can be
///     computed, otherwise kept.
/// - Otherwise: var-free values are reported as unchanged; everything else gets the default
///   rewrite.
impl<'db, 'mt> SemanticRewriter<ImplLongId<'db>, NoError> for Inference<'db, 'mt> {
    fn internal_rewrite(&mut self, value: &mut ImplLongId<'db>) -> Result<RewriteResult, NoError> {
        match value {
            ImplLongId::ImplVar(var) => {
                let long_id = var.long(self.db);
                // Relax the candidates.
                // NOTE(review): nothing in this branch visibly relaxes candidates; presumably this
                // refers to work done inside `impl_assignment` — confirm.
                let impl_var_id = long_id.id;
                if let Some(impl_id) = self.impl_assignment(impl_var_id) {
                    let mut long_impl_id = impl_id.long(self.db).clone();
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_impl_id)? {
                        // Store the further-resolved impl back into the assignment.
                        *self.impl_assignment.get_mut(&impl_var_id).unwrap() =
                            long_impl_id.clone().intern(self.db);
                    }
                    *value = long_impl_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            ImplLongId::ImplImpl(impl_impl_id) => {
                // Rewritten in place, so `impl_id` and the lookups below use the rewritten id.
                let impl_impl_id_rewrite_result = self.internal_rewrite(impl_impl_id)?;
                let impl_id = impl_impl_id.impl_id();
                return Ok(match impl_id.long(self.db) {
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::GeneratedImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_impl_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        if let Ok(imp) = self.db.impl_impl_concrete_implized(*impl_impl_id) {
                            *value = self.rewrite(imp).no_err().long(self.db).clone();
                            RewriteResult::Modified
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        if let Ok(concrete_trait_impl) =
                            impl_impl_id.concrete_trait_impl_id(self.db)
                        {
                            *value = self
                                .rewritten_impl_impl(*var, concrete_trait_impl)
                                .long(self.db)
                                .clone();
                            return Ok(RewriteResult::Modified);
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                });
            }

            _ => {}
        }
        if value.is_var_free(self.db) {
            return Ok(RewriteResult::NoChange);
        }
        value.default_rewrite(self)
    }
}
1679
1680impl<'db, 'mt> SemanticRewriter<NegativeImplLongId<'db>, NoError> for Inference<'db, 'mt> {
1681    fn internal_rewrite(
1682        &mut self,
1683        value: &mut NegativeImplLongId<'db>,
1684    ) -> Result<RewriteResult, NoError> {
1685        if let NegativeImplLongId::NegativeImplVar(var) = value {
1686            let long_id = var.long(self.db);
1687            // Relax the candidates.
1688            let neg_impl_var_id = long_id.id;
1689            if let Some(impl_id) = self.negative_impl_assignment(neg_impl_var_id) {
1690                let mut long_neg_impl_id = impl_id.long(self.db).clone();
1691                if let RewriteResult::Modified = self.internal_rewrite(&mut long_neg_impl_id)? {
1692                    *self.negative_impl_assignment.get_mut(&neg_impl_var_id).unwrap() =
1693                        long_neg_impl_id.clone().intern(self.db);
1694                }
1695                *value = long_neg_impl_id;
1696                return Ok(RewriteResult::Modified);
1697            }
1698        }
1699
1700        if value.is_var_free(self.db) {
1701            return Ok(RewriteResult::NoChange);
1702        }
1703        value.default_rewrite(self)
1704    }
1705}
1706
/// A rewriter that substitutes `to_inference_id` for `from_inference_id` wherever an
/// `InferenceId` is rewritten; all other inference ids are left untouched.
struct InferenceIdReplacer<'a> {
    db: &'a dyn Database,
    /// The inference id to be replaced.
    from_inference_id: InferenceId<'a>,
    /// The inference id written in place of `from_inference_id`.
    to_inference_id: InferenceId<'a>,
}
impl<'a> InferenceIdReplacer<'a> {
    /// Creates a replacer that maps `from_inference_id` to `to_inference_id`.
    fn new(
        db: &'a dyn Database,
        from_inference_id: InferenceId<'a>,
        to_inference_id: InferenceId<'a>,
    ) -> Self {
        Self { db, from_inference_id, to_inference_id }
    }
}
/// Exposes the replacer's database (`self.db`) through the `HasDb` trait.
impl<'a> HasDb<&'a dyn Database> for InferenceIdReplacer<'a> {
    fn get_db(&self) -> &'a dyn Database {
        self.db
    }
}
// Macro-generated rewrite support for the replacer. `InferenceId` is excluded because it gets
// the hand-written `SemanticRewriter` impl below, which performs the actual substitution.
add_basic_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude InferenceId);
add_expr_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude);
add_rewrite!(<'a>, InferenceIdReplacer<'a>, NoError, Ambiguity<'a>);
1729impl<'a> SemanticRewriter<InferenceId<'a>, NoError> for InferenceIdReplacer<'a> {
1730    fn internal_rewrite(&mut self, value: &mut InferenceId<'a>) -> Result<RewriteResult, NoError> {
1731        if value == &self.from_inference_id {
1732            *value = self.to_inference_id;
1733            Ok(RewriteResult::Modified)
1734        } else {
1735            Ok(RewriteResult::NoChange)
1736        }
1737    }
1738}