cairo_lang_semantic/expr/
inference.rs

1//! Bidirectional type inference.
2
3use std::collections::{BTreeMap, HashMap};
4use std::hash::Hash;
5use std::ops::{Deref, DerefMut};
6use std::sync::Arc;
7
8use cairo_lang_debug::DebugWithDb;
9use cairo_lang_defs::ids::{
10    ConstantId, EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId,
11    GlobalUseId, ImplAliasId, ImplDefId, ImplFunctionId, ImplImplDefId, LocalVarId, LookupItemId,
12    MacroCallId, MemberId, NamedLanguageElementId, ParamId, StructId, TraitConstantId,
13    TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
14};
15use cairo_lang_diagnostics::{DiagnosticAdded, skip_diagnostic};
16use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
17use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
18use cairo_lang_utils::deque::Deque;
19use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
20use cairo_lang_utils::{Intern, define_short_id, extract_matches};
21use salsa::Database;
22
23use self::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, NoError};
24use self::solver::{Ambiguity, SolutionSet, enrich_lookup_context};
25use crate::corelib::CorelibSemantic;
26use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
27use crate::expr::inference::canonic::ResultNoErrEx;
28use crate::expr::inference::conform::InferenceConform;
29use crate::expr::inference::solver::SemanticSolver;
30use crate::expr::objects::*;
31use crate::expr::pattern::*;
32use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
33use crate::items::functions::{
34    ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
35    GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
36    ImplGenericFunctionWithBodyId,
37};
38use crate::items::generics::{
39    GenericParamConst, GenericParamImpl, GenericParamSemantic, GenericParamType,
40};
41use crate::items::imp::{
42    GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
43    ImplLookupContextId, ImplSemantic, NegativeImplId, NegativeImplLongId,
44    UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
45};
46use crate::items::trt::{
47    ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId, ConcreteTraitTypeId,
48    ConcreteTraitTypeLongId,
49};
50use crate::substitution::{HasDb, RewriteResult, SemanticRewriter};
51use crate::types::{
52    ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
53    ImplTypeById, ImplTypeId,
54};
55use crate::{
56    ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
57    ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
58    FunctionId, FunctionLongId, GenericArgumentId, GenericParam, LocalVariable, MatchArmSelector,
59    Member, Parameter, SemanticObject, Signature, TypeId, TypeLongId, ValueSelectorArm,
60    add_basic_rewrites, add_expr_rewrites, add_rewrite, semantic_object_for_id,
61};
62
63pub mod canonic;
64pub mod conform;
65pub mod infers;
66pub mod solver;
67
/// A type variable, created when a generic type argument is not passed, and thus is not known
/// yet and needs to be inferred.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, salsa::Update)]
pub struct TypeVar<'db> {
    /// The inference context that allocated this variable.
    pub inference_id: InferenceId<'db>,
    /// Index of this variable within the owning context's `type_vars` list.
    pub id: LocalTypeVarId,
}
75
/// A const variable, created when a generic const argument is not passed, and thus is not known
/// yet and needs to be inferred.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, salsa::Update)]
pub struct ConstVar<'db> {
    /// The inference context that allocated this variable.
    pub inference_id: InferenceId<'db>,
    /// Index of this variable within the owning context's `const_vars` list.
    pub id: LocalConstVarId,
}
83
/// An id for an inference context. Each inference variable is associated with an inference id.
///
/// Variables record the id of the context that allocated them (e.g. `TypeVar::inference_id`),
/// and `InferenceData::clone_with_inference_id` re-homes a whole inference state to another id.
/// NOTE(review): the per-variant meaning (which item / item part a context belongs to) is only
/// suggested by the variant names here — confirm at the construction sites.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject, salsa::Update)]
#[debug_db(dyn Database)]
pub enum InferenceId<'db> {
    LookupItemDeclaration(LookupItemId<'db>),
    LookupItemGenerics(LookupItemId<'db>),
    LookupItemDefinition(LookupItemId<'db>),
    ImplDefTrait(ImplDefId<'db>),
    ImplAliasImplDef(ImplAliasId<'db>),
    GenericParam(GenericParamId<'db>),
    GenericImplParamTrait(GenericParamId<'db>),
    GlobalUseStar(GlobalUseId<'db>),
    MacroCall(MacroCallId<'db>),
    Canonical,
    /// For resolving that will not be used anywhere in the semantic model.
    NoContext,
}
101
/// An impl variable, created when a generic impl argument is not passed, and thus the impl is
/// not known yet and needs to be inferred.
#[derive(Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject, salsa::Update)]
#[debug_db(dyn Database)]
pub struct ImplVar<'db> {
    /// The inference context that allocated this variable.
    pub inference_id: InferenceId<'db>,
    /// Index of this variable within the owning context's `impl_vars` list.
    #[dont_rewrite]
    pub id: LocalImplVarId,
    /// The concrete trait that this variable stands for an impl of.
    pub concrete_trait_id: ConcreteTraitId<'db>,
    /// The lookup context recorded when the variable was allocated. Marked `#[dont_rewrite]`,
    /// so presumably left untouched by semantic rewriters.
    #[dont_rewrite]
    pub lookup_context: ImplLookupContextId<'db>,
}
114impl<'db> ImplVar<'db> {
115    pub fn intern(&self, db: &'db dyn Database) -> ImplVarId<'db> {
116        self.clone().intern(db)
117    }
118}
119
/// A negative impl variable, created for an unknown negative impl that needs to be inferred.
#[derive(Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject, salsa::Update)]
#[debug_db(dyn Database)]
pub struct NegativeImplVar<'db> {
    /// The inference context that allocated this variable.
    pub inference_id: InferenceId<'db>,
    /// Index of this variable within the owning context's `negative_impl_vars` list.
    #[dont_rewrite]
    pub id: LocalNegativeImplVarId,
    /// The concrete trait this negative impl variable refers to.
    pub concrete_trait_id: ConcreteTraitId<'db>,
    /// The lookup context recorded when the variable was allocated. Marked `#[dont_rewrite]`,
    /// so presumably left untouched by semantic rewriters.
    #[dont_rewrite]
    pub lookup_context: ImplLookupContextId<'db>,
}
131
/// Index of a type variable within an inference context's `type_vars` list.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, salsa::Update)]
pub struct LocalTypeVarId(pub usize);
/// Index of an impl variable within an inference context's `impl_vars` list.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, salsa::Update)]
pub struct LocalImplVarId(pub usize);

/// Index of a negative impl variable within an inference context's `negative_impl_vars` list.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, salsa::Update)]
pub struct LocalNegativeImplVarId(pub usize);

/// Index of a const variable within an inference context's `const_vars` list.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, salsa::Update)]
pub struct LocalConstVarId(pub usize);
142
143define_short_id!(ImplVarId, ImplVar<'db>);
144impl<'db> ImplVarId<'db> {
145    pub fn id(&self, db: &dyn Database) -> LocalImplVarId {
146        self.long(db).id
147    }
148    pub fn concrete_trait_id(&self, db: &'db dyn Database) -> ConcreteTraitId<'db> {
149        self.long(db).concrete_trait_id
150    }
151    pub fn lookup_context(&self, db: &'db dyn Database) -> ImplLookupContextId<'db> {
152        self.long(db).lookup_context
153    }
154}
155semantic_object_for_id!(ImplVarId, ImplVar<'a>);
156
157define_short_id!(NegativeImplVarId, NegativeImplVar<'db>);
158impl<'db> NegativeImplVarId<'db> {
159    pub fn id(&self, db: &dyn Database) -> LocalNegativeImplVarId {
160        self.long(db).id
161    }
162    pub fn concrete_trait_id(&self, db: &'db dyn Database) -> ConcreteTraitId<'db> {
163        self.long(db).concrete_trait_id
164    }
165    pub fn lookup_context(&self, db: &'db dyn Database) -> ImplLookupContextId<'db> {
166        self.long(db).lookup_context
167    }
168}
169semantic_object_for_id!(NegativeImplVarId, NegativeImplVar<'a>);
170
/// Identifies an inference variable of any kind by its local index within one inference
/// context. Used, for example, as the key of `InferenceData::stable_ptrs`.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, SemanticObject, salsa::Update)]
pub enum InferenceVar {
    Type(LocalTypeVarId),
    Const(LocalConstVarId),
    Impl(LocalImplVarId),
    NegativeImpl(LocalNegativeImplVarId),
}
178
// TODO(spapini): Add to diagnostics.
/// An error produced while solving an inference system.
///
/// [InferenceError::format] renders it as text, and [InferenceError::report] emits it as a
/// diagnostic (unless it already wraps a reported one).
#[derive(Clone, Debug, Eq, Hash, PartialEq, DebugWithDb, salsa::Update)]
#[debug_db(dyn Database)]
pub enum InferenceError<'db> {
    /// An inference error wrapping a previously reported error.
    Reported(DiagnosticAdded),
    /// An inference cycle was detected; carries the inference variable it was detected at
    /// (NOTE(review): presumably the variable that would depend on itself — confirm).
    Cycle(InferenceVar),
    /// Two types did not match.
    TypeKindMismatch {
        ty0: TypeId<'db>,
        ty1: TypeId<'db>,
    },
    /// Two const values did not match.
    ConstKindMismatch {
        const0: ConstValueId<'db>,
        const1: ConstValueId<'db>,
    },
    /// Two impls did not match.
    ImplKindMismatch {
        impl0: ImplId<'db>,
        impl1: ImplId<'db>,
    },
    /// Two negative impls did not match.
    NegativeImplKindMismatch {
        impl0: NegativeImplId<'db>,
        impl1: NegativeImplId<'db>,
    },
    /// Two generic arguments did not match.
    GenericArgMismatch {
        garg0: GenericArgumentId<'db>,
        garg1: GenericArgumentId<'db>,
    },
    /// Two traits did not match.
    TraitMismatch {
        trt0: TraitId<'db>,
        trt1: TraitId<'db>,
    },
    /// Two types given for the trait type `trait_type_id` of `impl_id` did not match.
    ImplTypeMismatch {
        impl_id: ImplId<'db>,
        trait_type_id: TraitTypeId<'db>,
        ty0: TypeId<'db>,
        ty1: TypeId<'db>,
    },
    /// Two generic functions did not match.
    GenericFunctionMismatch {
        func0: GenericFunctionId<'db>,
        func1: GenericFunctionId<'db>,
    },
    /// A constant could not be inferred.
    ConstNotInferred,
    // TODO(spapini): These are only used for external interface. Separate them along with the
    // finalize() function to a wrapper.
    /// The concrete trait has no implementation in context.
    NoImplsFound(ConcreteTraitId<'db>),
    /// A negative impl was required, but the concrete trait has an implementation in context.
    NoNegativeImplsFound(ConcreteTraitId<'db>),
    /// The inference could not pick a unique solution; details are in the [Ambiguity].
    Ambiguity(Ambiguity<'db>),
    /// The given type could not be inferred.
    TypeNotInferred(TypeId<'db>),
}
228impl<'db> InferenceError<'db> {
229    pub fn format(&self, db: &dyn Database) -> String {
230        match self {
231            InferenceError::Reported(_) => "Inference error occurred.".into(),
232            InferenceError::Cycle(_var) => "Inference cycle detected".into(),
233            InferenceError::TypeKindMismatch { ty0, ty1 } => {
234                format!("Type mismatch: `{:?}` and `{:?}`.", ty0.debug(db), ty1.debug(db))
235            }
236            InferenceError::ConstKindMismatch { const0, const1 } => {
237                format!("Const mismatch: `{:?}` and `{:?}`.", const0.debug(db), const1.debug(db))
238            }
239            InferenceError::ImplKindMismatch { impl0, impl1 } => {
240                format!("Impl mismatch: `{:?}` and `{:?}`.", impl0.debug(db), impl1.debug(db))
241            }
242            InferenceError::NegativeImplKindMismatch { impl0, impl1 } => {
243                format!(
244                    "Negative impl mismatch: `{:?}` and `{:?}`.",
245                    impl0.debug(db),
246                    impl1.debug(db)
247                )
248            }
249            InferenceError::GenericArgMismatch { garg0, garg1 } => {
250                format!(
251                    "Generic arg mismatch: `{:?}` and `{:?}`.",
252                    garg0.debug(db),
253                    garg1.debug(db)
254                )
255            }
256            InferenceError::TraitMismatch { trt0, trt1 } => {
257                format!("Trait mismatch: `{:?}` and `{:?}`.", trt0.debug(db), trt1.debug(db))
258            }
259            InferenceError::ConstNotInferred => "Failed to infer constant.".into(),
260            InferenceError::NoImplsFound(concrete_trait_id) => {
261                let info = db.core_info();
262                let trait_id = concrete_trait_id.trait_id(db);
263                if trait_id == info.numeric_literal_trt {
264                    let generic_type = extract_matches!(
265                        concrete_trait_id.generic_args(db)[0],
266                        GenericArgumentId::Type
267                    );
268                    return format!(
269                        "Mismatched types. The type `{:?}` cannot be created from a numeric \
270                         literal.",
271                        generic_type.debug(db)
272                    );
273                } else if trait_id == info.string_literal_trt {
274                    let generic_type = extract_matches!(
275                        concrete_trait_id.generic_args(db)[0],
276                        GenericArgumentId::Type
277                    );
278                    return format!(
279                        "Mismatched types. The type `{:?}` cannot be created from a string \
280                         literal.",
281                        generic_type.debug(db)
282                    );
283                }
284                format!(
285                    "Trait has no implementation in context: {:?}.",
286                    concrete_trait_id.debug(db)
287                )
288            }
289            InferenceError::NoNegativeImplsFound(concrete_trait_id) => {
290                format!("Trait has implementation in context: {:?}.", concrete_trait_id.debug(db))
291            }
292            InferenceError::Ambiguity(ambiguity) => ambiguity.format(db),
293            InferenceError::TypeNotInferred(ty) => {
294                format!("Type annotations needed. Failed to infer {:?}.", ty.debug(db))
295            }
296            InferenceError::GenericFunctionMismatch { func0, func1 } => {
297                format!("Function mismatch: `{}` and `{}`.", func0.format(db), func1.format(db))
298            }
299            InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 } => {
300                format!(
301                    "`{}::{}` type mismatch: `{:?}` and `{:?}`.",
302                    impl_id.format(db),
303                    trait_type_id.name(db).long(db),
304                    ty0.debug(db),
305                    ty1.debug(db)
306                )
307            }
308        }
309    }
310}
311
312impl<'db> InferenceError<'db> {
313    pub fn report(
314        &self,
315        diagnostics: &mut SemanticDiagnostics<'db>,
316        stable_ptr: SyntaxStablePtrId<'db>,
317    ) -> DiagnosticAdded {
318        match self {
319            InferenceError::Reported(diagnostic_added) => *diagnostic_added,
320            _ => diagnostics
321                .report(stable_ptr, SemanticDiagnosticKind::InternalInferenceError(self.clone())),
322        }
323    }
324}
325
/// This struct is used to ensure that when an inference error occurs, it is properly set in the
/// `Inference` object, and then properly consumed.
///
/// It must not be constructed directly. Instead, it is returned by [Inference::set_error].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ErrorSet;

/// Result of a fallible inference operation. On failure the actual [InferenceError] is not
/// carried here; it is recorded in the inference state, and [ErrorSet] only marks that it was set.
pub type InferenceResult<T> = Result<T, ErrorSet>;

/// Status of an error recorded in the inference state.
#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
enum InferenceErrorStatus<'db> {
    /// An error has been set and has not been consumed yet.
    Pending(InferenceError<'db>),
    /// The error has been consumed, producing the given diagnostic.
    Consumed(DiagnosticAdded),
}
340
/// A mapping of an impl var's trait items to concrete items: the values already known for the
/// trait types, constants and impls of an impl variable.
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, SemanticObject, salsa::Update)]
pub struct ImplVarTraitItemMappings<'db> {
    /// The trait types of the impl var.
    types: OrderedHashMap<TraitTypeId<'db>, TypeId<'db>>,
    /// The trait constants of the impl var.
    constants: OrderedHashMap<TraitConstantId<'db>, ConstValueId<'db>>,
    /// The trait impls of the impl var.
    impls: OrderedHashMap<TraitImplId<'db>, ImplId<'db>>,
}
351
352impl ImplVarTraitItemMappings<'_> {
353    /// Returns `true` if the impl var has no mappings.
354    pub fn is_empty(&self) -> bool {
355        self.types.is_empty() && self.constants.is_empty() && self.impls.is_empty()
356    }
357}
358
/// State of inference.
///
/// Holds every variable allocated by one inference context, their current assignments, and the
/// solver's bookkeeping queues.
#[derive(Debug, DebugWithDb, PartialEq, Eq, salsa::Update)]
#[debug_db(dyn Database)]
pub struct InferenceData<'db> {
    /// The id of the inference context this data belongs to.
    pub inference_id: InferenceId<'db>,
    /// Current inferred assignment for type variables.
    pub type_assignment: OrderedHashMap<LocalTypeVarId, TypeId<'db>>,
    /// Current inferred assignment for const variables.
    pub const_assignment: OrderedHashMap<LocalConstVarId, ConstValueId<'db>>,
    /// Current inferred assignment for impl variables.
    pub impl_assignment: OrderedHashMap<LocalImplVarId, ImplId<'db>>,
    /// Current inferred assignment for negative impl variables.
    pub negative_impl_assignment: OrderedHashMap<LocalNegativeImplVarId, NegativeImplId<'db>>,
    /// Unsolved impl variables mapped to maps of their trait items to corresponding item
    /// variables. When the impl is solved, each known item is conformed to its variable.
    pub impl_vars_trait_item_mappings: HashMap<LocalImplVarId, ImplVarTraitItemMappings<'db>>,
    /// Type variables, indexed by `LocalTypeVarId`.
    pub type_vars: Vec<TypeVar<'db>>,
    /// Const variables, indexed by `LocalConstVarId`.
    pub const_vars: Vec<ConstVar<'db>>,
    /// Impl variables, indexed by `LocalImplVarId`.
    pub impl_vars: Vec<ImplVar<'db>>,
    /// Negative impl variables, indexed by `LocalNegativeImplVarId`.
    pub negative_impl_vars: Vec<NegativeImplVar<'db>>,
    /// Mapping from variables to stable pointers, if exist.
    pub stable_ptrs: HashMap<InferenceVar, SyntaxStablePtrId<'db>>,
    /// Impl inference variables that are pending to be solved.
    pending: Deque<LocalImplVarId>,
    /// Negative impl inference variables that are pending to be solved.
    negative_pending: Deque<LocalNegativeImplVarId>,
    /// Impl inference variables that have been refuted - no solutions exist.
    refuted: Vec<LocalImplVarId>,
    /// Negative impl inference variables that have been refuted - no solutions exist.
    negative_refuted: Vec<LocalNegativeImplVarId>,
    /// Impl inference variables that have been solved.
    solved: Vec<LocalImplVarId>,
    /// Impl inference variables that are currently ambiguous. May be solved later.
    ambiguous: Vec<(LocalImplVarId, Ambiguity<'db>)>,
    /// Negative impl inference variables that are currently ambiguous. May be solved later.
    negative_ambiguous: Vec<(LocalNegativeImplVarId, Ambiguity<'db>)>,
    /// Mapping from impl types to the types they are bound to. `set_impl_type_bounds` does not
    /// store bounds that are still bare type variables.
    pub impl_type_bounds: Arc<BTreeMap<ImplTypeById<'db>, TypeId<'db>>>,

    /// The current error status; `Ok(())` while no error is set.
    error_status: Result<(), InferenceErrorStatus<'db>>,
}
405impl<'db> InferenceData<'db> {
406    pub fn new(inference_id: InferenceId<'db>) -> Self {
407        Self {
408            inference_id,
409            type_assignment: OrderedHashMap::default(),
410            impl_assignment: OrderedHashMap::default(),
411            const_assignment: OrderedHashMap::default(),
412            negative_impl_assignment: OrderedHashMap::default(),
413            impl_vars_trait_item_mappings: HashMap::new(),
414            type_vars: Vec::new(),
415            impl_vars: Vec::new(),
416            const_vars: Vec::new(),
417            negative_impl_vars: Vec::new(),
418            stable_ptrs: HashMap::new(),
419            pending: Deque::new(),
420            negative_pending: Deque::new(),
421            refuted: Vec::new(),
422            negative_refuted: Vec::new(),
423            solved: Vec::new(),
424            ambiguous: Vec::new(),
425            negative_ambiguous: Vec::new(),
426            impl_type_bounds: Default::default(),
427            error_status: Ok(()),
428        }
429    }
430    pub fn inference<'r>(&'r mut self, db: &'db dyn Database) -> Inference<'db, 'r> {
431        Inference::new(db, self)
432    }
433    pub fn clone_with_inference_id(
434        &self,
435        db: &'db dyn Database,
436        inference_id: InferenceId<'db>,
437    ) -> InferenceData<'db> {
438        let mut inference_id_replacer =
439            InferenceIdReplacer::new(db, self.inference_id, inference_id);
440        Self {
441            inference_id,
442            type_assignment: self
443                .type_assignment
444                .iter()
445                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
446                .collect(),
447            const_assignment: self
448                .const_assignment
449                .iter()
450                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
451                .collect(),
452            impl_assignment: self
453                .impl_assignment
454                .iter()
455                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
456                .collect(),
457            negative_impl_assignment: self
458                .negative_impl_assignment
459                .iter()
460                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
461                .collect(),
462            impl_vars_trait_item_mappings: self
463                .impl_vars_trait_item_mappings
464                .iter()
465                .map(|(k, mappings)| {
466                    (
467                        *k,
468                        ImplVarTraitItemMappings {
469                            types: mappings
470                                .types
471                                .iter()
472                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
473                                .collect(),
474                            constants: mappings
475                                .constants
476                                .iter()
477                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
478                                .collect(),
479                            impls: mappings
480                                .impls
481                                .iter()
482                                .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
483                                .collect(),
484                        },
485                    )
486                })
487                .collect(),
488            type_vars: inference_id_replacer.rewrite(self.type_vars.clone()).no_err(),
489            const_vars: inference_id_replacer.rewrite(self.const_vars.clone()).no_err(),
490            impl_vars: inference_id_replacer.rewrite(self.impl_vars.clone()).no_err(),
491            negative_impl_vars: inference_id_replacer
492                .rewrite(self.negative_impl_vars.clone())
493                .no_err(),
494            stable_ptrs: self.stable_ptrs.clone(),
495            pending: inference_id_replacer.rewrite(self.pending.clone()).no_err(),
496            negative_pending: inference_id_replacer.rewrite(self.negative_pending.clone()).no_err(),
497            refuted: inference_id_replacer.rewrite(self.refuted.clone()).no_err(),
498            negative_refuted: inference_id_replacer.rewrite(self.negative_refuted.clone()).no_err(),
499            solved: inference_id_replacer.rewrite(self.solved.clone()).no_err(),
500            ambiguous: inference_id_replacer.rewrite(self.ambiguous.clone()).no_err(),
501            negative_ambiguous: inference_id_replacer
502                .rewrite(self.negative_ambiguous.clone())
503                .no_err(),
504            // we do not need to rewrite the impl type bounds, as they all should be var free.
505            impl_type_bounds: self.impl_type_bounds.clone(),
506
507            error_status: self.error_status.clone(),
508        }
509    }
510    pub fn temporary_clone(&self) -> InferenceData<'db> {
511        Self {
512            inference_id: self.inference_id,
513            type_assignment: self.type_assignment.clone(),
514            const_assignment: self.const_assignment.clone(),
515            impl_assignment: self.impl_assignment.clone(),
516            negative_impl_assignment: self.negative_impl_assignment.clone(),
517            impl_vars_trait_item_mappings: self.impl_vars_trait_item_mappings.clone(),
518            type_vars: self.type_vars.clone(),
519            const_vars: self.const_vars.clone(),
520            impl_vars: self.impl_vars.clone(),
521            negative_impl_vars: self.negative_impl_vars.clone(),
522            stable_ptrs: self.stable_ptrs.clone(),
523            pending: self.pending.clone(),
524            negative_pending: self.negative_pending.clone(),
525            refuted: self.refuted.clone(),
526            negative_refuted: self.negative_refuted.clone(),
527            solved: self.solved.clone(),
528            ambiguous: self.ambiguous.clone(),
529            negative_ambiguous: self.negative_ambiguous.clone(),
530            impl_type_bounds: self.impl_type_bounds.clone(),
531            error_status: self.error_status.clone(),
532        }
533    }
534}
535
/// State of inference. A system of inference constraints.
///
/// Pairs a mutable borrow of an [InferenceData] with the database. It dereferences (mutably) to
/// the underlying data.
pub struct Inference<'db, 'id> {
    /// The database, used e.g. for interning new ids.
    db: &'db dyn Database,
    /// The inference state being operated on.
    pub data: &'id mut InferenceData<'db>,
}
541
542impl<'db, 'id> Deref for Inference<'db, 'id> {
543    type Target = InferenceData<'db>;
544
545    fn deref(&self) -> &Self::Target {
546        self.data
547    }
548}
549impl DerefMut for Inference<'_, '_> {
550    fn deref_mut(&mut self) -> &mut Self::Target {
551        self.data
552    }
553}
554
555impl std::fmt::Debug for Inference<'_, '_> {
556    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
557        let x = self.data.debug(self.db);
558        write!(f, "{x:?}")
559    }
560}
561
562impl<'db, 'id> Inference<'db, 'id> {
563    fn new(db: &'db dyn Database, data: &'id mut InferenceData<'db>) -> Self {
564        Self { db, data }
565    }
566
567    /// Getter for an [ImplVar].
568    fn impl_var(&self, var_id: LocalImplVarId) -> &ImplVar<'db> {
569        &self.impl_vars[var_id.0]
570    }
571
572    /// Getter for an impl var assignment.
573    pub fn impl_assignment(&self, var_id: LocalImplVarId) -> Option<ImplId<'db>> {
574        self.impl_assignment.get(&var_id).copied()
575    }
576
577    /// Getter for a [NegativeImplVar].
578    fn negative_impl_var(&self, var_id: LocalNegativeImplVarId) -> &NegativeImplVar<'db> {
579        &self.negative_impl_vars[var_id.0]
580    }
581
582    /// Getter for a negative impl var assignment.
583    pub fn negative_impl_assignment(
584        &self,
585        var_id: LocalNegativeImplVarId,
586    ) -> Option<NegativeImplId<'db>> {
587        self.negative_impl_assignment.get(&var_id).copied()
588    }
589
590    /// Getter for a type var assignment.
591    fn type_assignment(&self, var_id: LocalTypeVarId) -> Option<TypeId<'db>> {
592        self.type_assignment.get(&var_id).copied()
593    }
594
595    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
596    /// Returns a wrapping TypeId.
597    pub fn new_type_var(&mut self, stable_ptr: Option<SyntaxStablePtrId<'db>>) -> TypeId<'db> {
598        let var = self.new_type_var_raw(stable_ptr);
599
600        TypeLongId::Var(var).intern(self.db)
601    }
602
603    /// Allocates a new [TypeVar] for an unknown type that needs to be inferred.
604    /// Returns the variable id.
605    pub fn new_type_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId<'db>>) -> TypeVar<'db> {
606        let var =
607            TypeVar { inference_id: self.inference_id, id: LocalTypeVarId(self.type_vars.len()) };
608        if let Some(stable_ptr) = stable_ptr {
609            self.stable_ptrs.insert(InferenceVar::Type(var.id), stable_ptr);
610        }
611        self.type_vars.push(var);
612        var
613    }
614
615    /// Sets the infrence's impl type bounds to the given map, and rewrittes the types so all the
616    /// types are var free.
617    pub fn set_impl_type_bounds(
618        &mut self,
619        impl_type_bounds: OrderedHashMap<ImplTypeId<'db>, TypeId<'db>>,
620    ) {
621        let impl_type_bounds_finalized = impl_type_bounds
622            .iter()
623            .filter_map(|(impl_type, ty)| {
624                let rewritten_type = self.rewrite(ty.long(self.db).clone()).no_err();
625                if !matches!(rewritten_type, TypeLongId::Var(_)) {
626                    return Some(((*impl_type).into(), rewritten_type.intern(self.db)));
627                }
628                // conformed the var type to the original impl type to remove it from the pending
629                // list.
630                self.conform_ty(*ty, TypeLongId::ImplType(*impl_type).intern(self.db)).ok();
631                None
632            })
633            .collect();
634
635        self.data.impl_type_bounds = Arc::new(impl_type_bounds_finalized);
636    }
637
638    /// Allocates a new [ConstVar] for an unknown consts that needs to be inferred.
639    /// Returns a wrapping [ConstValueId].
640    pub fn new_const_var(
641        &mut self,
642        stable_ptr: Option<SyntaxStablePtrId<'db>>,
643        ty: TypeId<'db>,
644    ) -> ConstValueId<'db> {
645        let var = self.new_const_var_raw(stable_ptr);
646        ConstValue::Var(var, ty).intern(self.db)
647    }
648
649    /// Allocates a new [ConstVar] for an unknown type that needs to be inferred.
650    /// Returns the variable id.
651    pub fn new_const_var_raw(
652        &mut self,
653        stable_ptr: Option<SyntaxStablePtrId<'db>>,
654    ) -> ConstVar<'db> {
655        let var = ConstVar {
656            inference_id: self.inference_id,
657            id: LocalConstVarId(self.const_vars.len()),
658        };
659        if let Some(stable_ptr) = stable_ptr {
660            self.stable_ptrs.insert(InferenceVar::Const(var.id), stable_ptr);
661        }
662        self.const_vars.push(var);
663        var
664    }
665
666    /// Allocates a new [ImplVar] for an unknown type that needs to be inferred.
667    /// Returns a wrapping ImplId.
668    pub fn new_impl_var(
669        &mut self,
670        concrete_trait_id: ConcreteTraitId<'db>,
671        stable_ptr: Option<SyntaxStablePtrId<'db>>,
672        lookup_context: ImplLookupContextId<'db>,
673    ) -> ImplId<'db> {
674        let var = self.new_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
675        ImplLongId::ImplVar(self.impl_var(var).intern(self.db)).intern(self.db)
676    }
677
678    /// Allocates a new [ImplVar] for an unknown impl that needs to be inferred.
679    /// Returns the variable id.
680    fn new_impl_var_raw(
681        &mut self,
682        lookup_context: ImplLookupContextId<'db>,
683        concrete_trait_id: ConcreteTraitId<'db>,
684        stable_ptr: Option<SyntaxStablePtrId<'db>>,
685    ) -> LocalImplVarId {
686        let id = LocalImplVarId(self.impl_vars.len());
687        if let Some(stable_ptr) = stable_ptr {
688            self.stable_ptrs.insert(InferenceVar::Impl(id), stable_ptr);
689        }
690        let var =
691            ImplVar { inference_id: self.inference_id, id, concrete_trait_id, lookup_context };
692        self.impl_vars.push(var);
693        self.pending.push_back(id);
694        id
695    }
696
697    /// Allocates a new [NegativeImplVar] for an unknown negative impl that needs to be inferred.
698    /// Returns a wrapping NegativeImplId.
699    pub fn new_negative_impl_var(
700        &mut self,
701        concrete_trait_id: ConcreteTraitId<'db>,
702        stable_ptr: Option<SyntaxStablePtrId<'db>>,
703        lookup_context: ImplLookupContextId<'db>,
704    ) -> NegativeImplId<'db> {
705        let var = self.new_negative_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
706        NegativeImplLongId::NegativeImplVar(self.negative_impl_var(var).clone().intern(self.db))
707            .intern(self.db)
708    }
709
710    /// Allocates a new [NegativeImplVar] for an unknown type that needs to be inferred.
711    /// Returns the variable id.
712    fn new_negative_impl_var_raw(
713        &mut self,
714        lookup_context: ImplLookupContextId<'db>,
715        concrete_trait_id: ConcreteTraitId<'db>,
716        stable_ptr: Option<SyntaxStablePtrId<'db>>,
717    ) -> LocalNegativeImplVarId {
718        let id = LocalNegativeImplVarId(self.negative_impl_vars.len());
719        if let Some(stable_ptr) = stable_ptr {
720            self.stable_ptrs.insert(InferenceVar::NegativeImpl(id), stable_ptr);
721        }
722        let var = NegativeImplVar {
723            inference_id: self.inference_id,
724            id,
725            concrete_trait_id,
726            lookup_context,
727        };
728        self.negative_impl_vars.push(var);
729        self.negative_pending.push_back(id);
730        id
731    }
732
733    /// Solves the inference system. After a successful solve, there are no more pending impl
734    /// inferences.
735    /// Returns whether the inference was successful. If not, the error may be found by
736    /// `.error_state()`.
737    pub fn solve(&mut self) -> InferenceResult<()> {
738        self.solve_ex().map_err(|(err_set, _)| err_set)
739    }
740
741    /// Same as `solve`, but returns the error stable pointer if an error occurred.
742    fn solve_ex(&mut self) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId<'db>>)> {
743        let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
744        self.pending.extend(ambiguous.map(|(var, _)| var));
745        while let Some(var) = self.pending.pop_front() {
746            // First inference error stops inference.
747            self.solve_single_pending(var).map_err(|err_set| {
748                (err_set, self.stable_ptrs.get(&InferenceVar::Impl(var)).copied())
749            })?;
750        }
751        while let Some(var) = self.negative_pending.pop_front() {
752            // First inference error stops inference.
753            self.solve_single_negative_pending(var).map_err(|err_set| {
754                (err_set, self.stable_ptrs.get(&InferenceVar::NegativeImpl(var)).copied())
755            })?;
756        }
757        Ok(())
758    }
759
760    fn solve_single_pending(&mut self, var: LocalImplVarId) -> InferenceResult<()> {
761        if self.impl_assignment.contains_key(&var) {
762            return Ok(());
763        }
764        let solution = match self.impl_var_solution_set(var)? {
765            SolutionSet::None => {
766                self.refuted.push(var);
767                return Ok(());
768            }
769            SolutionSet::Ambiguous(ambiguity) => {
770                self.ambiguous.push((var, ambiguity));
771                return Ok(());
772            }
773            SolutionSet::Unique(solution) => solution,
774        };
775
776        // Solution found. Assign it.
777        self.assign_local_impl(var, solution)?;
778
779        // Something changed.
780        self.solved.push(var);
781        let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
782        self.pending.extend(ambiguous.map(|(var, _)| var));
783
784        let negative_ambiguous = std::mem::take(&mut self.negative_ambiguous).into_iter();
785        self.negative_pending.extend(negative_ambiguous.map(|(var, _)| var));
786
787        Ok(())
788    }
789
790    /// Solves a single negative impl pending variable.
791    fn solve_single_negative_pending(
792        &mut self,
793        var: LocalNegativeImplVarId,
794    ) -> InferenceResult<()> {
795        if self.negative_impl_assignment.contains_key(&var) {
796            return Ok(());
797        }
798
799        let solution = match self.negative_impl_var_solution_set(var)? {
800            SolutionSet::None => {
801                self.negative_refuted.push(var);
802                return Ok(());
803            }
804            SolutionSet::Ambiguous(ambiguity) => {
805                self.negative_ambiguous.push((var, ambiguity));
806                return Ok(());
807            }
808            SolutionSet::Unique(solution) => solution,
809        };
810
811        // Solution found. Assign it.
812        self.assign_local_negative_impl(var, solution)?;
813
814        Ok(())
815    }
816
817    /// Returns the solution set status for the inference:
818    /// Whether there is a unique solution, multiple solutions, no solutions or an error.
819    pub fn solution_set(&mut self) -> InferenceResult<SolutionSet<'db, ()>> {
820        self.solve()?;
821        if !self.refuted.is_empty() {
822            return Ok(SolutionSet::None);
823        }
824        if !self.negative_refuted.is_empty() {
825            return Ok(SolutionSet::None);
826        }
827        if let Some((_, ambiguity)) = self.ambiguous.first() {
828            return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
829        }
830        if let Some((_, ambiguity)) = self.negative_ambiguous.first() {
831            return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
832        }
833        assert!(self.pending.is_empty(), "solution() called on an unsolved solver");
834        assert!(self.negative_pending.is_empty(), "solution() called on an unsolved solver");
835        Ok(SolutionSet::Unique(()))
836    }
837
838    /// Finalizes the inference by inferring uninferred numeric literals as felt252.
839    /// Returns an error and does not report it.
840    pub fn finalize_without_reporting(
841        &mut self,
842    ) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId<'db>>)> {
843        if self.error_status.is_err() {
844            // TODO(yuval): consider adding error location to the set error.
845            return Err((ErrorSet, None));
846        }
847        let info = self.db.core_info();
848        let numeric_trait_id = info.numeric_literal_trt;
849        let felt_ty = info.felt252;
850
851        // Conform all uninferred numeric literals to felt252.
852        loop {
853            let mut changed = false;
854            self.solve_ex()?;
855            for (var, _) in self.ambiguous.clone() {
856                let impl_var = self.impl_var(var).clone();
857                if impl_var.concrete_trait_id.trait_id(self.db) != numeric_trait_id {
858                    continue;
859                }
860                // Uninferred numeric trait. Resolve as felt252.
861                let ty = extract_matches!(
862                    impl_var.concrete_trait_id.generic_args(self.db)[0],
863                    GenericArgumentId::Type
864                );
865                if self.rewrite(ty).no_err() == felt_ty {
866                    continue;
867                }
868                self.conform_ty(ty, felt_ty).map_err(|err_set| {
869                    (err_set, self.stable_ptrs.get(&InferenceVar::Impl(impl_var.id)).copied())
870                })?;
871                changed = true;
872                break;
873            }
874            if !changed {
875                break;
876            }
877        }
878        assert!(
879            self.pending.is_empty(),
880            "pending should all be solved by this point. Guaranteed by solve()."
881        );
882
883        let Some((var, err)) = self.first_undetermined_variable() else {
884            return Ok(());
885        };
886        Err((self.set_error(err), self.stable_ptrs.get(&var).copied()))
887    }
888
889    /// Finalizes the inference and report diagnostics if there are any errors.
890    /// All the remaining type vars are mapped to the `missing` type, to prevent additional
891    /// diagnostics.
892    pub fn finalize<'m>(
893        &'m mut self,
894        diagnostics: &mut SemanticDiagnostics<'db>,
895        stable_ptr: SyntaxStablePtrId<'db>,
896    ) {
897        if let Err((err_set, err_stable_ptr)) = self.finalize_without_reporting() {
898            let diag = self.report_on_pending_error(
899                err_set,
900                diagnostics,
901                err_stable_ptr.unwrap_or(stable_ptr),
902            );
903
904            let ty_missing = TypeId::missing(self.db, diag);
905            for var in &self.data.type_vars {
906                self.data.type_assignment.entry(var.id).or_insert(ty_missing);
907            }
908        }
909    }
910
911    /// Retrieves the first variable that is still not inferred, or None, if everything is
912    /// inferred.
913    /// Does not set the error but return it, which is ok as this is a private helper function.
914    fn first_undetermined_variable(&mut self) -> Option<(InferenceVar, InferenceError<'db>)> {
915        if let Some(var) = self.refuted.first().copied() {
916            let impl_var = self.impl_var(var).clone();
917            let concrete_trait_id = impl_var.concrete_trait_id;
918            let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
919            return Some((
920                InferenceVar::Impl(var),
921                InferenceError::NoImplsFound(concrete_trait_id),
922            ));
923        }
924        if let Some(var) = self.negative_refuted.first().copied() {
925            let negative_impl_var = self.negative_impl_var(var).clone();
926            let concrete_trait_id = negative_impl_var.concrete_trait_id;
927            let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
928            return Some((
929                InferenceVar::NegativeImpl(var),
930                InferenceError::NoNegativeImplsFound(concrete_trait_id),
931            ));
932        }
933
934        let mut fallback_ret = None;
935        if let Some((var, ambiguity)) = self.ambiguous.first() {
936            // Note: do not rewrite `ambiguity`, since it is expressed in canonical variables.
937            let ret =
938                Some((InferenceVar::Impl(*var), InferenceError::Ambiguity(ambiguity.clone())));
939            if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
940                return ret;
941            } else {
942                fallback_ret = ret;
943            }
944        }
945        if let Some((var, ambiguity)) = self.negative_ambiguous.first() {
946            let ret = Some((
947                InferenceVar::NegativeImpl(*var),
948                InferenceError::Ambiguity(ambiguity.clone()),
949            ));
950            if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
951                return ret;
952            } else {
953                fallback_ret = ret;
954            }
955        }
956        for (id, var) in self.type_vars.iter().enumerate() {
957            if self.type_assignment(LocalTypeVarId(id)).is_none() {
958                let ty = TypeLongId::Var(*var).intern(self.db);
959                return Some((InferenceVar::Type(var.id), InferenceError::TypeNotInferred(ty)));
960            }
961        }
962        for (id, var) in self.const_vars.iter().enumerate() {
963            if !self.const_assignment.contains_key(&LocalConstVarId(id)) {
964                let infernence_var = InferenceVar::Const(var.id);
965                return Some((infernence_var, InferenceError::ConstNotInferred));
966            }
967        }
968        fallback_ret
969    }
970
    /// Assigns a value to a local impl variable id. See assign_impl().
    ///
    /// Steps, in order:
    /// 1. Conforms the variable's concrete trait with the concrete trait of `impl_id`.
    /// 2. If the variable is already assigned, conforms `impl_id` with the existing assignment and
    ///    returns that result instead of reassigning.
    /// 3. Fails with `InferenceError::Cycle` if `impl_id` contains the variable itself.
    /// 4. Records the assignment. Then any trait item mappings (types, constants, impls) that were
    ///    registered for this variable are checked against the concrete items of `impl_id`.
    ///
    /// Returns `impl_id` on success. On failure an error is recorded and its `ErrorSet` is
    /// returned. A type mapping mismatch is recorded as a pending `ImplTypeMismatch` error.
    fn assign_local_impl(
        &mut self,
        var: LocalImplVarId,
        impl_id: ImplId<'db>,
    ) -> InferenceResult<ImplId<'db>> {
        let concrete_trait = impl_id
            .concrete_trait(self.db)
            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
        self.conform_traits(self.impl_var(var).concrete_trait_id, concrete_trait)?;
        if let Some(other_impl) = self.impl_assignment(var) {
            return self.conform_impl(impl_id, other_impl);
        }
        // Occurs check: skipped for var-free impls, which cannot contain `var`.
        if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
        {
            return Err(self.set_error(InferenceError::Cycle(InferenceVar::Impl(var))));
        }
        // Note: the assignment is recorded before the mappings below are checked, so it is kept
        // even if one of those checks fails.
        self.impl_assignment.insert(var, impl_id);
        // The mappings are removed here, so each set is checked once, on the first assignment.
        if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
            // Each mapped trait type must conform to the impl's concrete type for that item.
            for (trait_type_id, ty) in mappings.types {
                let impl_ty = self
                    .db
                    .impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_type_id, self.db))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                if let Err(err_set) = self.conform_ty(ty, impl_ty) {
                    // Override the error with ImplTypeMismatch.
                    // `error_status` is written directly because `set_error` does nothing once an
                    // error is already set (which `conform_ty` has just done).
                    let ty0 = self.rewrite(ty).no_err();
                    let ty1 = self.rewrite(impl_ty).no_err();

                    let error =
                        InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 };
                    self.error_status = Err(InferenceErrorStatus::Pending(error));
                    return Err(err_set);
                }
            }
            // Each mapped trait constant must conform to the impl's concrete constant value.
            for (trait_constant, constant_id) in mappings.constants {
                let concrete_impl_constant = self
                    .db
                    .impl_constant_concrete_implized_value(ImplConstantId::new(
                        impl_id,
                        trait_constant,
                        self.db,
                    ))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                self.conform_const(constant_id, concrete_impl_constant)?;
            }
            // Each mapped trait impl must conform to the impl's concrete inner impl.
            for (trait_impl, inner_impl_id) in mappings.impls {
                let concrete_impl_impl = self
                    .db
                    .impl_impl_concrete_implized(ImplImplId::new(impl_id, trait_impl, self.db))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                self.conform_impl(inner_impl_id, concrete_impl_impl)?;
            }
        }
        Ok(impl_id)
    }
1027
1028    /// Tries to assigns value to an [ImplVarId]. Return the assigned impl, or an error.
1029    fn assign_impl(
1030        &mut self,
1031        var_id: ImplVarId<'db>,
1032        impl_id: ImplId<'db>,
1033    ) -> InferenceResult<ImplId<'db>> {
1034        let var = var_id.long(self.db);
1035        if var.inference_id != self.inference_id {
1036            return Err(self.set_error(InferenceError::ImplKindMismatch {
1037                impl0: ImplLongId::ImplVar(var_id).intern(self.db),
1038                impl1: impl_id,
1039            }));
1040        }
1041        self.assign_local_impl(var.id, impl_id)
1042    }
1043
1044    /// Assigns a value to a local negative impl variable id. See assign_neg_impl().
1045    fn assign_local_negative_impl(
1046        &mut self,
1047        var: LocalNegativeImplVarId,
1048        neg_impl_id: NegativeImplId<'db>,
1049    ) -> InferenceResult<NegativeImplId<'db>> {
1050        let concrete_trait = neg_impl_id
1051            .concrete_trait(self.db)
1052            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1053        self.conform_traits(self.negative_impl_var(var).concrete_trait_id, concrete_trait)?;
1054        if let Some(other_impl) = self.negative_impl_assignment(var) {
1055            return self.conform_neg_impl(neg_impl_id, other_impl);
1056        }
1057        if !neg_impl_id.is_var_free(self.db)
1058            && self.negative_impl_contains_var(neg_impl_id, InferenceVar::NegativeImpl(var))
1059        {
1060            return Err(self.set_error(InferenceError::Cycle(InferenceVar::NegativeImpl(var))));
1061        }
1062        self.negative_impl_assignment.insert(var, neg_impl_id);
1063        Ok(neg_impl_id)
1064    }
1065
1066    /// Tries to assigns value to an [NegativeImplVarId]. Return the assigned negative impl, or an
1067    /// error.
1068    fn assign_neg_impl(
1069        &mut self,
1070        var_id: NegativeImplVarId<'db>,
1071        neg_impl_id: NegativeImplId<'db>,
1072    ) -> InferenceResult<NegativeImplId<'db>> {
1073        let var = var_id.long(self.db);
1074        if var.inference_id != self.inference_id {
1075            return Err(self.set_error(InferenceError::NegativeImplKindMismatch {
1076                impl0: NegativeImplLongId::NegativeImplVar(var_id).intern(self.db),
1077                impl1: neg_impl_id,
1078            }));
1079        }
1080        self.assign_local_negative_impl(var.id, neg_impl_id)
1081    }
1082
1083    /// Assigns a value to a [TypeVar]. Return the assigned type, or an error.
1084    /// Assumes the variable is not already assigned.
1085    fn assign_ty(&mut self, var: TypeVar<'db>, ty: TypeId<'db>) -> InferenceResult<TypeId<'db>> {
1086        if var.inference_id != self.inference_id {
1087            return Err(self.set_error(InferenceError::TypeKindMismatch {
1088                ty0: TypeLongId::Var(var).intern(self.db),
1089                ty1: ty,
1090            }));
1091        }
1092        assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
1093        let inference_var = InferenceVar::Type(var.id);
1094        if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
1095            return Err(self.set_error(InferenceError::Cycle(inference_var)));
1096        }
1097        // If assigning var to var - making sure assigning to the lower id for proper canonization.
1098        if let TypeLongId::Var(other) = ty.long(self.db)
1099            && other.inference_id == self.inference_id
1100            && other.id.0 > var.id.0
1101        {
1102            let var_ty = TypeLongId::Var(var).intern(self.db);
1103            self.type_assignment.insert(other.id, var_ty);
1104            return Ok(var_ty);
1105        }
1106        self.type_assignment.insert(var.id, ty);
1107        Ok(ty)
1108    }
1109
1110    /// Assigns a value to a [ConstVar]. Return the assigned const, or an error.
1111    /// Assumes the variable is not already assigned.
1112    fn assign_const(
1113        &mut self,
1114        var: ConstVar<'db>,
1115        id: ConstValueId<'db>,
1116    ) -> InferenceResult<ConstValueId<'db>> {
1117        if var.inference_id != self.inference_id {
1118            return Err(self.set_error(InferenceError::ConstKindMismatch {
1119                const0: ConstValue::Var(var, TypeId::missing(self.db, skip_diagnostic()))
1120                    .intern(self.db),
1121                const1: id,
1122            }));
1123        }
1124
1125        self.const_assignment.insert(var.id, id);
1126        Ok(id)
1127    }
1128
1129    /// Computes the solution set for an impl variable with a recursive query.
1130    fn impl_var_solution_set(
1131        &mut self,
1132        var: LocalImplVarId,
1133    ) -> InferenceResult<SolutionSet<'db, ImplId<'db>>> {
1134        let impl_var = self.impl_var(var).clone();
1135        // Update the concrete trait of the impl var.
1136        let concrete_trait_id = self.rewrite(impl_var.concrete_trait_id).no_err();
1137        self.impl_vars[impl_var.id.0].concrete_trait_id = concrete_trait_id;
1138        let impl_var_trait_item_mappings =
1139            self.impl_vars_trait_item_mappings.get(&var).cloned().unwrap_or_default();
1140        let solution_set = self.trait_solution_set(
1141            concrete_trait_id,
1142            impl_var_trait_item_mappings,
1143            impl_var.lookup_context,
1144        )?;
1145        Ok(match solution_set {
1146            SolutionSet::None => SolutionSet::None,
1147            SolutionSet::Unique((canonical_impl, canonicalizer)) => {
1148                SolutionSet::Unique(canonical_impl.embed(self, &canonicalizer))
1149            }
1150            SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
1151        })
1152    }
1153
1154    /// Computes the solution set for an negative impl variable with a recursive query.
1155    fn negative_impl_var_solution_set(
1156        &mut self,
1157        var: LocalNegativeImplVarId,
1158    ) -> InferenceResult<SolutionSet<'db, NegativeImplId<'db>>> {
1159        let negative_impl_var = self.negative_impl_var(var).clone();
1160        let concrete_trait_id = self.rewrite(negative_impl_var.concrete_trait_id).no_err();
1161
1162        let solution_set =
1163            self.validate_no_solution_set(concrete_trait_id, negative_impl_var.lookup_context)?;
1164        Ok(match solution_set {
1165            SolutionSet::Unique(concrete_trait_id) => {
1166                SolutionSet::Unique(NegativeImplLongId::Solved(concrete_trait_id).intern(self.db))
1167            }
1168            SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
1169            SolutionSet::None => SolutionSet::None,
1170        })
1171    }
1172
1173    /// Validates that no solution set is found for the negative impls.
1174    fn validate_no_solution_set(
1175        &mut self,
1176        concrete_trait_id: ConcreteTraitId<'db>,
1177        lookup_context: ImplLookupContextId<'db>,
1178    ) -> InferenceResult<SolutionSet<'db, ConcreteTraitId<'db>>> {
1179        for negative_impl in &lookup_context.long(self.db).negative_impls {
1180            let generic_param = self
1181                .db
1182                .generic_param_semantic(*negative_impl)
1183                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1184            if let GenericParam::NegImpl(neg_impl) = generic_param
1185                && Ok(concrete_trait_id) == neg_impl.concrete_trait
1186            {
1187                return Ok(SolutionSet::Unique(concrete_trait_id));
1188            }
1189        }
1190        for garg in concrete_trait_id.generic_args(self.db) {
1191            let GenericArgumentId::Type(ty) = garg else {
1192                continue;
1193            };
1194            let ty = self.rewrite(*ty).no_err();
1195            // If the negative impl has a generic argument that is not fully
1196            // concrete we can't tell if we should rule out the candidate impl.
1197            // For example if we have -TypeEqual<S, T> we can't tell if S and
1198            // T are going to be assigned the same concrete type.
1199            // We return `SolutionSet::Ambiguous` here to indicate that more
1200            // information is needed.
1201            // Closure can only have one type, even if it's not fully concrete, so can use
1202            // it and not get ambiguity.
1203            if !matches!(ty.long(self.db), TypeLongId::Closure(_)) && !ty.is_fully_concrete(self.db)
1204            {
1205                // TODO(ilya): Try to detect the ambiguity earlier in the
1206                // inference process.
1207                return Ok(SolutionSet::Ambiguous(
1208                    Ambiguity::NegativeImplWithUnresolvedGenericArgs2 { concrete_trait_id, ty },
1209                ));
1210            }
1211        }
1212
1213        if !matches!(
1214            self.trait_solution_set(
1215                concrete_trait_id,
1216                ImplVarTraitItemMappings::default(),
1217                lookup_context,
1218            )?,
1219            SolutionSet::None
1220        ) {
1221            // If a negative impl has an impl, then we should skip it.
1222            return Ok(SolutionSet::None);
1223        }
1224
1225        Ok(SolutionSet::Unique(concrete_trait_id))
1226    }
1227
1228    /// Computes the solution set for a trait with a recursive query.
1229    pub fn trait_solution_set(
1230        &mut self,
1231        concrete_trait_id: ConcreteTraitId<'db>,
1232        impl_var_trait_item_mappings: ImplVarTraitItemMappings<'db>,
1233        lookup_context_id: ImplLookupContextId<'db>,
1234    ) -> InferenceResult<SolutionSet<'db, (CanonicalImpl<'db>, CanonicalMapping<'db>)>> {
1235        let impl_var_trait_item_mappings = self.rewrite(impl_var_trait_item_mappings).no_err();
1236        // TODO(spapini): This is done twice. Consider doing it only here.
1237        let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
1238        let mut lookup_context = lookup_context_id.long(self.db).clone();
1239        enrich_lookup_context(self.db, concrete_trait_id, &mut lookup_context);
1240
1241        // Don't try to resolve impls if the first generic param is a variable.
1242        let generic_args = concrete_trait_id.generic_args(self.db);
1243        match generic_args.first() {
1244            Some(GenericArgumentId::Type(ty)) => {
1245                if let TypeLongId::Var(_) = ty.long(self.db) {
1246                    // Don't try to infer such impls.
1247                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1248                }
1249            }
1250            Some(GenericArgumentId::Impl(imp)) => {
1251                // Don't try to infer such impls.
1252                if let ImplLongId::ImplVar(_) = imp.long(self.db) {
1253                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1254                }
1255            }
1256            Some(GenericArgumentId::Constant(const_value)) => {
1257                if let ConstValue::Var(_, _) = const_value.long(self.db) {
1258                    // Don't try to infer such impls.
1259                    return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1260                }
1261            }
1262            _ => {}
1263        };
1264        let (canonical_trait, canonicalizer) = CanonicalTrait::canonicalize(
1265            self.db,
1266            self.inference_id,
1267            concrete_trait_id,
1268            impl_var_trait_item_mappings,
1269        );
1270        // impl_type_bounds order is deterimend by the generic params of the function and therefore
1271        // is consistent.
1272        let solution_set = match self.db.canonic_trait_solutions(
1273            canonical_trait,
1274            lookup_context.intern(self.db),
1275            (*self.data.impl_type_bounds).clone(),
1276        ) {
1277            Ok(solution_set) => solution_set,
1278            Err(err) => return Err(self.set_error(err)),
1279        };
1280        match solution_set {
1281            SolutionSet::None => Ok(SolutionSet::None),
1282            SolutionSet::Unique(canonical_impl) => {
1283                Ok(SolutionSet::Unique((canonical_impl, canonicalizer)))
1284            }
1285            SolutionSet::Ambiguous(ambiguity) => Ok(SolutionSet::Ambiguous(ambiguity)),
1286        }
1287    }
1288
1289    // Error handling methods
1290    // ======================
1291
1292    /// Sets an error in the inference state.
1293    /// Does nothing if an error is already set.
1294    /// Returns an `ErrorSet` that can be used in reporting the error.
1295    pub fn set_error(&mut self, err: InferenceError<'db>) -> ErrorSet {
1296        if self.error_status.is_err() {
1297            return ErrorSet;
1298        }
1299        self.error_status = Err(if let InferenceError::Reported(diag_added) = err {
1300            InferenceErrorStatus::Consumed(diag_added)
1301        } else {
1302            InferenceErrorStatus::Pending(err)
1303        });
1304        ErrorSet
1305    }
1306
1307    /// Returns whether an error is set (either pending or consumed).
1308    pub fn is_error_set(&self) -> InferenceResult<()> {
1309        self.error_status.as_ref().copied().map_err(|_| ErrorSet)
1310    }
1311
1312    /// Consumes the error but doesn't report it. If there is no error, or the error is consumed,
1313    /// returns None. This should be used with caution. Always prefer to use
1314    /// (1) `report_on_pending_error` if possible, or (2) `consume_reported_error` which is safer.
1315    ///
1316    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1317    pub fn consume_error_without_reporting(
1318        &mut self,
1319        err_set: ErrorSet,
1320    ) -> Option<InferenceError<'db>> {
1321        self.consume_error_inner(err_set, skip_diagnostic())
1322    }
1323
1324    /// Consumes the error that is already reported. If there is no error, or the error is consumed,
1325    /// does nothing. This should be used with caution. Always prefer to use
1326    /// `report_on_pending_error` if possible.
1327    ///
1328    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1329    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1330    pub fn consume_reported_error(&mut self, err_set: ErrorSet, diag_added: DiagnosticAdded) {
1331        self.consume_error_inner(err_set, diag_added);
1332    }
1333
1334    /// Consumes the error and returns it, but doesn't report it. If there is no error, or the error
1335    /// is already consumed, returns None. This should be used with caution. Always prefer to use
1336    /// `report_on_pending_error` if possible.
1337    ///
1338    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1339    /// Gets an `DiagnosticAdded` to "enforce" it is only called when a diagnostic was reported.
1340    fn consume_error_inner(
1341        &mut self,
1342        _err_set: ErrorSet,
1343        diag_added: DiagnosticAdded,
1344    ) -> Option<InferenceError<'db>> {
1345        match &mut self.error_status {
1346            Err(InferenceErrorStatus::Pending(error)) => {
1347                let pending_error = std::mem::replace(error, InferenceError::Reported(diag_added));
1348                self.error_status = Err(InferenceErrorStatus::Consumed(diag_added));
1349                Some(pending_error)
1350            }
1351            // TODO(orizi): `panic!("consume_error when there is no pending error")` instead.
1352            _ => None,
1353        }
1354    }
1355
1356    /// Consumes the pending error, if any, and reports it.
1357    /// Should only be called when an error is set, otherwise it panics.
1358    /// Gets an `ErrorSet` to "enforce" it is only called when an error is set.
1359    /// If an error was set but it's already consumed, it doesn't report it again but returns the
1360    /// stored `DiagnosticAdded`.
1361    pub fn report_on_pending_error(
1362        &mut self,
1363        _err_set: ErrorSet,
1364        diagnostics: &mut SemanticDiagnostics<'db>,
1365        stable_ptr: SyntaxStablePtrId<'db>,
1366    ) -> DiagnosticAdded {
1367        let Err(state_error) = &self.error_status else {
1368            panic!("report_on_pending_error should be called only on error");
1369        };
1370        match state_error {
1371            InferenceErrorStatus::Consumed(diag_added) => *diag_added,
1372            InferenceErrorStatus::Pending(error) => {
1373                let diag_added = match error {
1374                    InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
1375                        // If we have other diagnostics, there is no need to TypeNotInferred.
1376
1377                        // Note that `diagnostics` is not empty, so it is safe to return
1378                        // 'DiagnosticAdded' here.
1379                        skip_diagnostic()
1380                    }
1381                    diag => diag.report(diagnostics, stable_ptr),
1382                };
1383                self.error_status = Err(InferenceErrorStatus::Consumed(diag_added));
1384                diag_added
1385            }
1386        }
1387    }
1388
1389    /// If the current status is of a pending error, reports an alternative diagnostic, by calling
1390    /// `report`, and consumes the error. Otherwise, does nothing.
1391    pub fn report_modified_if_pending(
1392        &mut self,
1393        err_set: ErrorSet,
1394        report: impl FnOnce() -> DiagnosticAdded,
1395    ) {
1396        if matches!(self.error_status, Err(InferenceErrorStatus::Pending(_))) {
1397            self.consume_reported_error(err_set, report());
1398        }
1399    }
1400}
1401
1402impl<'a, 'mt> HasDb<&'a dyn Database> for Inference<'a, 'mt> {
1403    fn get_db(&self) -> &'a dyn Database {
1404        self.db
1405    }
1406}
// Generate the macro-provided `SemanticRewriter` impls for `Inference` (NOTE(review): assumed to
// be the default structural rewrites; confirm against the macro definitions). The types listed
// after `@exclude` are skipped by the macro. `TypeId`, `ImplId`, `NegativeImplId` and
// `TypeLongId` get hand-written impls right below this; presumably the remaining excluded types
// are handled the same way further down in this file.
add_basic_rewrites!(<'a, 'mt>, Inference<'a, 'mt>, NoError, @exclude TypeLongId TypeId ImplLongId ImplId ConstValue NegativeImplLongId NegativeImplId);
add_expr_rewrites!(<'a, 'mt>, Inference<'a, 'mt>, NoError, @exclude);
add_rewrite!(<'a, 'mt>, Inference<'a, 'mt>, NoError, Ambiguity<'a>);
1410impl<'db, 'mt> SemanticRewriter<TypeId<'db>, NoError> for Inference<'db, 'mt> {
1411    fn internal_rewrite(&mut self, value: &mut TypeId<'db>) -> Result<RewriteResult, NoError> {
1412        if value.is_var_free(self.db) {
1413            return Ok(RewriteResult::NoChange);
1414        }
1415        value.default_rewrite(self)
1416    }
1417}
1418impl<'db, 'mt> SemanticRewriter<ImplId<'db>, NoError> for Inference<'db, 'mt> {
1419    fn internal_rewrite(&mut self, value: &mut ImplId<'db>) -> Result<RewriteResult, NoError> {
1420        if value.is_var_free(self.db) {
1421            return Ok(RewriteResult::NoChange);
1422        }
1423        value.default_rewrite(self)
1424    }
1425}
1426impl<'db, 'mt> SemanticRewriter<NegativeImplId<'db>, NoError> for Inference<'db, 'mt> {
1427    fn internal_rewrite(
1428        &mut self,
1429        value: &mut NegativeImplId<'db>,
1430    ) -> Result<RewriteResult, NoError> {
1431        if value.is_var_free(self.db) {
1432            return Ok(RewriteResult::NoChange);
1433        }
1434        value.default_rewrite(self)
1435    }
1436}
1437
1438impl<'db, 'mt> SemanticRewriter<TypeLongId<'db>, NoError> for Inference<'db, 'mt> {
1439    fn internal_rewrite(&mut self, value: &mut TypeLongId<'db>) -> Result<RewriteResult, NoError> {
1440        match value {
1441            TypeLongId::Var(var) => {
1442                if let Some(type_id) = self.type_assignment.get(&var.id) {
1443                    let mut long_type_id = type_id.long(self.db).clone();
1444                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_type_id)? {
1445                        *self.type_assignment.get_mut(&var.id).unwrap() =
1446                            long_type_id.clone().intern(self.db);
1447                    }
1448                    *value = long_type_id;
1449                    return Ok(RewriteResult::Modified);
1450                }
1451            }
1452            TypeLongId::ImplType(impl_type_id) => {
1453                if let Some(type_id) = self.impl_type_bounds.get(&((*impl_type_id).into())) {
1454                    *value = type_id.long(self.db).clone();
1455                    self.internal_rewrite(value)?;
1456                    return Ok(RewriteResult::Modified);
1457                }
1458                let impl_type_id_rewrite_result = self.internal_rewrite(impl_type_id)?;
1459                let impl_id = impl_type_id.impl_id();
1460                let trait_ty = impl_type_id.ty();
1461                return Ok(match impl_id.long(self.db) {
1462                    ImplLongId::GenericParameter(_)
1463                    | ImplLongId::SelfImpl(_)
1464                    | ImplLongId::ImplImpl(_) => impl_type_id_rewrite_result,
1465                    ImplLongId::Concrete(_) => {
1466                        if let Ok(ty) = self.db.impl_type_concrete_implized(ImplTypeId::new(
1467                            impl_id, trait_ty, self.db,
1468                        )) {
1469                            *value = self.rewrite(ty).no_err().long(self.db).clone();
1470                            RewriteResult::Modified
1471                        } else {
1472                            impl_type_id_rewrite_result
1473                        }
1474                    }
1475                    ImplLongId::ImplVar(var) => {
1476                        *value = self.rewritten_impl_type(*var, trait_ty).long(self.db).clone();
1477                        return Ok(RewriteResult::Modified);
1478                    }
1479                    ImplLongId::GeneratedImpl(generated) => {
1480                        *value = self
1481                            .rewrite(
1482                                *generated
1483                                    .long(self.db)
1484                                    .impl_items
1485                                    .0
1486                                    .get(&impl_type_id.ty())
1487                                    .unwrap(),
1488                            )
1489                            .no_err()
1490                            .long(self.db)
1491                            .clone();
1492                        RewriteResult::Modified
1493                    }
1494                });
1495            }
1496            _ => {}
1497        }
1498        value.default_rewrite(self)
1499    }
1500}
1501impl<'db, 'mt> SemanticRewriter<ConstValue<'db>, NoError> for Inference<'db, 'mt> {
1502    fn internal_rewrite(&mut self, value: &mut ConstValue<'db>) -> Result<RewriteResult, NoError> {
1503        match value {
1504            ConstValue::Var(var, _) => {
1505                return Ok(if let Some(const_value_id) = self.const_assignment.get(&var.id) {
1506                    let mut const_value = const_value_id.long(self.db).clone();
1507                    if let RewriteResult::Modified = self.internal_rewrite(&mut const_value)? {
1508                        *self.const_assignment.get_mut(&var.id).unwrap() =
1509                            const_value.clone().intern(self.db);
1510                    }
1511                    *value = const_value;
1512                    RewriteResult::Modified
1513                } else {
1514                    RewriteResult::NoChange
1515                });
1516            }
1517            ConstValue::ImplConstant(impl_constant_id) => {
1518                let impl_constant_id_rewrite_result = self.internal_rewrite(impl_constant_id)?;
1519                let impl_id = impl_constant_id.impl_id();
1520                let trait_constant = impl_constant_id.trait_constant_id();
1521                return Ok(match impl_id.long(self.db) {
1522                    ImplLongId::GenericParameter(_)
1523                    | ImplLongId::SelfImpl(_)
1524                    | ImplLongId::GeneratedImpl(_)
1525                    | ImplLongId::ImplImpl(_) => impl_constant_id_rewrite_result,
1526                    ImplLongId::Concrete(_) => {
1527                        if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
1528                            ImplConstantId::new(impl_id, trait_constant, self.db),
1529                        ) {
1530                            *value = self.rewrite(constant).no_err().long(self.db).clone();
1531                            RewriteResult::Modified
1532                        } else {
1533                            impl_constant_id_rewrite_result
1534                        }
1535                    }
1536                    ImplLongId::ImplVar(var) => {
1537                        *value = self
1538                            .rewritten_impl_constant(*var, trait_constant)
1539                            .long(self.db)
1540                            .clone();
1541                        return Ok(RewriteResult::Modified);
1542                    }
1543                });
1544            }
1545            _ => {}
1546        }
1547        value.default_rewrite(self)
1548    }
1549}
impl<'db, 'mt> SemanticRewriter<ImplLongId<'db>, NoError> for Inference<'db, 'mt> {
    /// Rewrites a long impl id, substituting what the inference has learned so far.
    ///
    /// - `ImplVar`: an assigned impl variable is replaced by its recursively rewritten
    ///   assignment. When that rewrite changed it, the stored assignment is updated too.
    ///   An unassigned variable falls through to the var-free check at the end.
    /// - `ImplImpl`: the inner impl-impl id is rewritten first, then resolved by the kind of
    ///   its impl:
    ///   - Concrete impls go through the database.
    ///   - Impl variables go through `rewritten_impl_impl`.
    ///   - Other kinds keep the rewritten id.
    /// - Otherwise, values for which `is_var_free` holds are reported unchanged. The rest use
    ///   the default structural rewrite.
    fn internal_rewrite(&mut self, value: &mut ImplLongId<'db>) -> Result<RewriteResult, NoError> {
        match value {
            ImplLongId::ImplVar(var) => {
                let long_id = var.long(self.db);
                // Relax the candidates.
                let impl_var_id = long_id.id;
                if let Some(impl_id) = self.impl_assignment(impl_var_id) {
                    let mut long_impl_id = impl_id.long(self.db).clone();
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_impl_id)? {
                        // Store the rewritten form back so later lookups start from it.
                        *self.impl_assignment.get_mut(&impl_var_id).unwrap() =
                            long_impl_id.clone().intern(self.db);
                    }
                    *value = long_impl_id;
                    return Ok(RewriteResult::Modified);
                }
                // Unassigned: fall through to the var-free check / default rewrite below.
            }
            ImplLongId::ImplImpl(impl_impl_id) => {
                // Rewrite the inner impl-impl id first; its impl decides how to resolve it.
                let impl_impl_id_rewrite_result = self.internal_rewrite(impl_impl_id)?;
                let impl_id = impl_impl_id.impl_id();
                return Ok(match impl_id.long(self.db) {
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::GeneratedImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_impl_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        // Ask the database for the implized impl; keep the rewritten id on failure.
                        if let Ok(imp) = self.db.impl_impl_concrete_implized(*impl_impl_id) {
                            *value = self.rewrite(imp).no_err().long(self.db).clone();
                            RewriteResult::Modified
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        // Resolve through the impl variable when the concrete trait impl is known.
                        if let Ok(concrete_trait_impl) =
                            impl_impl_id.concrete_trait_impl_id(self.db)
                        {
                            *value = self
                                .rewritten_impl_impl(*var, concrete_trait_impl)
                                .long(self.db)
                                .clone();
                            return Ok(RewriteResult::Modified);
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                });
            }

            _ => {}
        }
        // Skip the structural rewrite when `is_var_free` reports nothing to substitute.
        if value.is_var_free(self.db) {
            return Ok(RewriteResult::NoChange);
        }
        value.default_rewrite(self)
    }
}
1607
1608impl<'db, 'mt> SemanticRewriter<NegativeImplLongId<'db>, NoError> for Inference<'db, 'mt> {
1609    fn internal_rewrite(
1610        &mut self,
1611        value: &mut NegativeImplLongId<'db>,
1612    ) -> Result<RewriteResult, NoError> {
1613        if let NegativeImplLongId::NegativeImplVar(var) = value {
1614            let long_id = var.long(self.db);
1615            // Relax the candidates.
1616            let neg_impl_var_id = long_id.id;
1617            if let Some(impl_id) = self.negative_impl_assignment(neg_impl_var_id) {
1618                let mut long_neg_impl_id = impl_id.long(self.db).clone();
1619                if let RewriteResult::Modified = self.internal_rewrite(&mut long_neg_impl_id)? {
1620                    *self.negative_impl_assignment.get_mut(&neg_impl_var_id).unwrap() =
1621                        long_neg_impl_id.clone().intern(self.db);
1622                }
1623                *value = long_neg_impl_id;
1624                return Ok(RewriteResult::Modified);
1625            }
1626        }
1627
1628        if value.is_var_free(self.db) {
1629            return Ok(RewriteResult::NoChange);
1630        }
1631        value.default_rewrite(self)
1632    }
1633}
1634
/// A `SemanticRewriter` that swaps `from_inference_id` for `to_inference_id` wherever the
/// rewrite reaches an [`InferenceId`]; all other inference ids are left as they are.
struct InferenceIdReplacer<'a> {
    /// Database handle returned through `HasDb`.
    db: &'a dyn Database,
    /// The inference id to be replaced.
    from_inference_id: InferenceId<'a>,
    /// The inference id written in place of `from_inference_id`.
    to_inference_id: InferenceId<'a>,
}
1640impl<'a> InferenceIdReplacer<'a> {
1641    fn new(
1642        db: &'a dyn Database,
1643        from_inference_id: InferenceId<'a>,
1644        to_inference_id: InferenceId<'a>,
1645    ) -> Self {
1646        Self { db, from_inference_id, to_inference_id }
1647    }
1648}
1649impl<'a> HasDb<&'a dyn Database> for InferenceIdReplacer<'a> {
1650    fn get_db(&self) -> &'a dyn Database {
1651        self.db
1652    }
1653}
// Macro-generated rewrites for `InferenceIdReplacer`. `InferenceId` is excluded here because
// it has a hand-written `SemanticRewriter` impl below that performs the actual replacement.
add_basic_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude InferenceId);
add_expr_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude);
// `Ambiguity` values are rewritable by `InferenceIdReplacer` as well.
add_rewrite!(<'a>, InferenceIdReplacer<'a>, NoError, Ambiguity<'a>);
1657impl<'a> SemanticRewriter<InferenceId<'a>, NoError> for InferenceIdReplacer<'a> {
1658    fn internal_rewrite(&mut self, value: &mut InferenceId<'a>) -> Result<RewriteResult, NoError> {
1659        if value == &self.from_inference_id {
1660            *value = self.to_inference_id;
1661            Ok(RewriteResult::Modified)
1662        } else {
1663            Ok(RewriteResult::NoChange)
1664        }
1665    }
1666}