1use std::collections::{BTreeMap, HashMap, VecDeque};
4use std::hash::Hash;
5use std::mem;
6use std::ops::{Deref, DerefMut};
7use std::sync::Arc;
8
9use cairo_lang_debug::DebugWithDb;
10use cairo_lang_defs::ids::{
11 ConstantId, EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId,
12 GlobalUseId, ImplAliasId, ImplDefId, ImplFunctionId, ImplImplDefId, LanguageElementId,
13 LocalVarId, LookupItemId, MemberId, NamedLanguageElementId, ParamId, StructId, TraitConstantId,
14 TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
15};
16use cairo_lang_diagnostics::{DiagnosticAdded, Maybe, skip_diagnostic};
17use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
18use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
19use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
20use cairo_lang_utils::{
21 Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches,
22};
23
24use self::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, NoError};
25use self::solver::{Ambiguity, SolutionSet, enrich_lookup_context};
26use crate::db::SemanticGroup;
27use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
28use crate::expr::inference::canonic::ResultNoErrEx;
29use crate::expr::inference::conform::InferenceConform;
30use crate::expr::objects::*;
31use crate::expr::pattern::*;
32use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
33use crate::items::functions::{
34 ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
35 GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
36 ImplGenericFunctionWithBodyId,
37};
38use crate::items::generics::{GenericParamConst, GenericParamImpl, GenericParamType};
39use crate::items::imp::{
40 GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
41 ImplLookupContext, UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
42};
43use crate::items::trt::{
44 ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId, ConcreteTraitTypeId,
45 ConcreteTraitTypeLongId,
46};
47use crate::substitution::{HasDb, RewriteResult, SemanticRewriter};
48use crate::types::{
49 ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
50 ImplTypeById, ImplTypeId,
51};
52use crate::{
53 ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
54 ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
55 FunctionId, FunctionLongId, GenericArgumentId, GenericParam, LocalVariable, MatchArmSelector,
56 Member, Parameter, SemanticObject, Signature, TypeId, TypeLongId, ValueSelectorArm,
57 add_basic_rewrites, add_expr_rewrites, add_rewrite, semantic_object_for_id,
58};
59
60pub mod canonic;
61pub mod conform;
62pub mod infers;
63pub mod solver;
64
/// A type inference variable: slot `id` in the `type_vars` of the inference session identified
/// by `inference_id`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct TypeVar {
    /// The inference session that created this variable.
    pub inference_id: InferenceId,
    /// Index of this variable in that session's `type_vars`.
    pub id: LocalTypeVarId,
}

/// A const inference variable: slot `id` in the `const_vars` of the inference session
/// identified by `inference_id`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ConstVar {
    /// The inference session that created this variable.
    pub inference_id: InferenceId,
    /// Index of this variable in that session's `const_vars`.
    pub id: LocalConstVarId,
}
80
/// Identifies the inference session that owns a set of inference variables.
///
/// Every variable (`TypeVar`, `ConstVar`, `ImplVar`) records the id of the session that created
/// it. The assignment functions of `Inference` reject variables whose id differs from the current
/// session's.
// NOTE(review): the meaning of each variant is only suggested by its name (e.g. `Canonical` for
// canonicalized trait solving, `NoContext` for ad-hoc inference) — confirm before relying on it.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
#[debug_db(dyn SemanticGroup + 'static)]
pub enum InferenceId {
    LookupItemDeclaration(LookupItemId),
    LookupItemGenerics(LookupItemId),
    LookupItemDefinition(LookupItemId),
    ImplDefTrait(ImplDefId),
    ImplAliasImplDef(ImplAliasId),
    GenericParam(GenericParamId),
    GenericImplParamTrait(GenericParamId),
    GlobalUseStar(GlobalUseId),
    Canonical,
    NoContext,
}
97
/// An impl inference variable: a placeholder for an impl of `concrete_trait_id` that inference
/// has to find.
#[derive(Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
#[debug_db(dyn SemanticGroup + 'static)]
pub struct ImplVar {
    /// The inference session that created this variable.
    pub inference_id: InferenceId,
    /// Index of this variable in that session's `impl_vars`.
    #[dont_rewrite]
    pub id: LocalImplVarId,
    /// The trait that the impl assigned to this variable must implement.
    pub concrete_trait_id: ConcreteTraitId,
    /// The context in which candidate impls are searched when solving this variable.
    #[dont_rewrite]
    pub lookup_context: ImplLookupContext,
}
110impl ImplVar {
111 pub fn intern(&self, db: &dyn SemanticGroup) -> ImplVarId {
112 self.clone().intern(db)
113 }
114}
115
/// Index of a type variable in `InferenceData::type_vars`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
pub struct LocalTypeVarId(pub usize);
/// Index of an impl variable in `InferenceData::impl_vars`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
pub struct LocalImplVarId(pub usize);

/// Index of a const variable in `InferenceData::const_vars`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
pub struct LocalConstVarId(pub usize);
123
124define_short_id!(ImplVarId, ImplVar, SemanticGroup, lookup_intern_impl_var, intern_impl_var);
125impl ImplVarId {
126 pub fn id(&self, db: &dyn SemanticGroup) -> LocalImplVarId {
127 self.lookup_intern(db).id
128 }
129 pub fn concrete_trait_id(&self, db: &dyn SemanticGroup) -> ConcreteTraitId {
130 self.lookup_intern(db).concrete_trait_id
131 }
132 pub fn lookup_context(&self, db: &dyn SemanticGroup) -> ImplLookupContext {
133 self.lookup_intern(db).lookup_context
134 }
135}
136semantic_object_for_id!(ImplVarId, lookup_intern_impl_var, intern_impl_var, ImplVar);
137
/// A session-local inference variable of any kind.
///
/// Used as the key of `InferenceData::stable_ptrs` and to identify the variable in errors such
/// as `InferenceError::Cycle`.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, SemanticObject)]
pub enum InferenceVar {
    Type(LocalTypeVarId),
    Const(LocalConstVarId),
    Impl(LocalImplVarId),
}
144
/// An error raised during inference. `InferenceError::format` renders each variant as a
/// user-facing message.
#[derive(Clone, Debug, Eq, Hash, PartialEq, DebugWithDb)]
#[debug_db(dyn SemanticGroup + 'static)]
pub enum InferenceError {
    /// An error whose diagnostic was already reported; `set_error` records it as consumed.
    Reported(DiagnosticAdded),
    /// Assigning the variable would make it contain itself.
    Cycle(InferenceVar),
    /// Two types did not match. Also raised when assigning a type variable that belongs to a
    /// different inference session.
    TypeKindMismatch {
        ty0: TypeId,
        ty1: TypeId,
    },
    /// Two constants did not match. Also raised when assigning a const variable that belongs to
    /// a different inference session.
    ConstKindMismatch {
        const0: ConstValueId,
        const1: ConstValueId,
    },
    /// Two impls did not match. Also raised when assigning an impl variable that belongs to a
    /// different inference session.
    ImplKindMismatch {
        impl0: ImplId,
        impl1: ImplId,
    },
    /// Two generic arguments did not match.
    GenericArgMismatch {
        garg0: GenericArgumentId,
        garg1: GenericArgumentId,
    },
    /// Two traits did not match.
    TraitMismatch {
        trt0: TraitId,
        trt1: TraitId,
    },
    /// The associated type `trait_type_id` of `impl_id` (`ty1`) did not match the type
    /// previously required for it (`ty0`).
    ImplTypeMismatch {
        impl_id: ImplId,
        trait_type_id: TraitTypeId,
        ty0: TypeId,
        ty1: TypeId,
    },
    /// Two generic functions did not match.
    GenericFunctionMismatch {
        func0: GenericFunctionId,
        func1: GenericFunctionId,
    },
    /// Inference of const generic arguments is not supported yet.
    ConstInferenceNotSupported,

    /// No impl of the concrete trait was found (the impl variable was refuted).
    NoImplsFound(ConcreteTraitId),
    /// The impl variable could not be solved uniquely.
    Ambiguity(Ambiguity),
    /// A type variable was never assigned.
    TypeNotInferred(TypeId),
}
190impl InferenceError {
191 pub fn format(&self, db: &(dyn SemanticGroup + 'static)) -> String {
192 match self {
193 InferenceError::Reported(_) => "Inference error occurred.".into(),
194 InferenceError::Cycle(_var) => "Inference cycle detected".into(),
195 InferenceError::TypeKindMismatch { ty0, ty1 } => {
196 format!("Type mismatch: `{:?}` and `{:?}`.", ty0.debug(db), ty1.debug(db))
197 }
198 InferenceError::ConstKindMismatch { const0, const1 } => {
199 format!("Const mismatch: `{:?}` and `{:?}`.", const0.debug(db), const1.debug(db))
200 }
201 InferenceError::ImplKindMismatch { impl0, impl1 } => {
202 format!("Impl mismatch: `{:?}` and `{:?}`.", impl0.debug(db), impl1.debug(db))
203 }
204 InferenceError::GenericArgMismatch { garg0, garg1 } => {
205 format!(
206 "Generic arg mismatch: `{:?}` and `{:?}`.",
207 garg0.debug(db),
208 garg1.debug(db)
209 )
210 }
211 InferenceError::TraitMismatch { trt0, trt1 } => {
212 format!("Trait mismatch: `{:?}` and `{:?}`.", trt0.debug(db), trt1.debug(db))
213 }
214 InferenceError::ConstInferenceNotSupported => {
215 "Const generic inference not yet supported.".into()
216 }
217 InferenceError::NoImplsFound(concrete_trait_id) => {
218 let info = db.core_info();
219 let trait_id = concrete_trait_id.trait_id(db);
220 if trait_id == info.numeric_literal_trt {
221 let generic_type = extract_matches!(
222 concrete_trait_id.generic_args(db)[0],
223 GenericArgumentId::Type
224 );
225 return format!(
226 "Mismatched types. The type `{:?}` cannot be created from a numeric \
227 literal.",
228 generic_type.debug(db)
229 );
230 } else if trait_id == info.string_literal_trt {
231 let generic_type = extract_matches!(
232 concrete_trait_id.generic_args(db)[0],
233 GenericArgumentId::Type
234 );
235 return format!(
236 "Mismatched types. The type `{:?}` cannot be created from a string \
237 literal.",
238 generic_type.debug(db)
239 );
240 }
241 format!(
242 "Trait has no implementation in context: {:?}.",
243 concrete_trait_id.debug(db)
244 )
245 }
246 InferenceError::Ambiguity(ambiguity) => ambiguity.format(db),
247 InferenceError::TypeNotInferred(ty) => {
248 format!("Type annotations needed. Failed to infer {:?}.", ty.debug(db))
249 }
250 InferenceError::GenericFunctionMismatch { func0, func1 } => {
251 format!("Function mismatch: `{}` and `{}`.", func0.format(db), func1.format(db))
252 }
253 InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 } => {
254 format!(
255 "`{}::{}` type mismatch: `{:?}` and `{:?}`.",
256 impl_id.format(db),
257 trait_type_id.name(db),
258 ty0.debug(db),
259 ty1.debug(db)
260 )
261 }
262 }
263 }
264}
265
266impl InferenceError {
267 pub fn report(
268 &self,
269 diagnostics: &mut SemanticDiagnostics,
270 stable_ptr: SyntaxStablePtrId,
271 ) -> DiagnosticAdded {
272 match self {
273 InferenceError::Reported(diagnostic_added) => *diagnostic_added,
274 _ => diagnostics
275 .report(stable_ptr, SemanticDiagnosticKind::InternalInferenceError(self.clone())),
276 }
277 }
278}
279
/// Marker for a failed inference operation. It carries no data: the details live in the
/// inference state (`error_status`, `error`, `consumed_error`). `Inference::set_error` returns it.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ErrorSet;

/// Result of an inference operation; the error side is only the [`ErrorSet`] marker.
pub type InferenceResult<T> = Result<T, ErrorSet>;

/// State of the error recorded in an inference session.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum InferenceErrorStatus {
    /// An error is stored in `InferenceData::error` and has not been consumed yet.
    Pending,
    /// The error was consumed; `InferenceData::consumed_error` holds its diagnostic.
    Consumed,
}
294
/// Requirements on the items of the impl an impl variable will be assigned to, keyed by trait
/// item.
///
/// When the variable is assigned, each entry is conformed against the corresponding item of the
/// assigned impl.
#[derive(Debug, Default, PartialEq, Eq, Clone, SemanticObject)]
pub struct ImplVarTraitItemMappings {
    /// Required values of the impl's associated types.
    types: OrderedHashMap<TraitTypeId, TypeId>,
    /// Required values of the impl's associated constants.
    constants: OrderedHashMap<TraitConstantId, ConstValueId>,
    /// Required values of the impl's associated impls.
    impls: OrderedHashMap<TraitImplId, ImplId>,
}
305impl Hash for ImplVarTraitItemMappings {
306 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
307 self.types.iter().for_each(|(trait_type_id, type_id)| {
308 trait_type_id.hash(state);
309 type_id.hash(state);
310 });
311 self.constants.iter().for_each(|(trait_const_id, const_id)| {
312 trait_const_id.hash(state);
313 const_id.hash(state);
314 });
315 self.impls.iter().for_each(|(trait_impl_id, impl_id)| {
316 trait_impl_id.hash(state);
317 impl_id.hash(state);
318 });
319 }
320}
321
/// The complete state of one inference session: its variables, their assignments, the impl
/// variables' solving queues, and the error state.
#[derive(Debug, DebugWithDb, PartialEq, Eq)]
#[debug_db(dyn SemanticGroup + 'static)]
pub struct InferenceData {
    /// The session these variables belong to.
    pub inference_id: InferenceId,
    /// Assigned type variables.
    pub type_assignment: OrderedHashMap<LocalTypeVarId, TypeId>,
    /// Assigned const variables.
    pub const_assignment: OrderedHashMap<LocalConstVarId, ConstValueId>,
    /// Assigned impl variables.
    pub impl_assignment: OrderedHashMap<LocalImplVarId, ImplId>,
    /// Item requirements of not-yet-assigned impl variables; consumed when the variable is
    /// assigned.
    pub impl_vars_trait_item_mappings: HashMap<LocalImplVarId, ImplVarTraitItemMappings>,
    /// All type variables, indexed by `LocalTypeVarId`.
    pub type_vars: Vec<TypeVar>,
    /// All const variables, indexed by `LocalConstVarId`.
    pub const_vars: Vec<ConstVar>,
    /// All impl variables, indexed by `LocalImplVarId`.
    pub impl_vars: Vec<ImplVar>,
    /// Source locations of variables, used as the location of errors about them.
    pub stable_ptrs: HashMap<InferenceVar, SyntaxStablePtrId>,
    /// Impl variables waiting to be solved.
    pending: VecDeque<LocalImplVarId>,
    /// Impl variables for which no impl was found.
    refuted: Vec<LocalImplVarId>,
    /// Impl variables that were solved uniquely.
    solved: Vec<LocalImplVarId>,
    /// Impl variables that could not be solved uniquely, with the reason.
    ambiguous: Vec<(LocalImplVarId, Ambiguity)>,
    /// Known values of impl types, passed to the trait solver.
    pub impl_type_bounds: Arc<BTreeMap<ImplTypeById, TypeId>>,

    /// `Ok(())` while no error was recorded; otherwise whether the error is pending or consumed.
    pub error_status: Result<(), InferenceErrorStatus>,
    /// The recorded error while `error_status` is `Err(Pending)`.
    error: Option<InferenceError>,
    /// The diagnostic of the error once it was consumed.
    consumed_error: Option<DiagnosticAdded>,
}
363impl InferenceData {
364 pub fn new(inference_id: InferenceId) -> Self {
365 Self {
366 inference_id,
367 type_assignment: OrderedHashMap::default(),
368 impl_assignment: OrderedHashMap::default(),
369 const_assignment: OrderedHashMap::default(),
370 impl_vars_trait_item_mappings: HashMap::new(),
371 type_vars: Vec::new(),
372 impl_vars: Vec::new(),
373 const_vars: Vec::new(),
374 stable_ptrs: HashMap::new(),
375 pending: VecDeque::new(),
376 refuted: Vec::new(),
377 solved: Vec::new(),
378 ambiguous: Vec::new(),
379 impl_type_bounds: Default::default(),
380 error_status: Ok(()),
381 error: None,
382 consumed_error: None,
383 }
384 }
385 pub fn inference<'db, 'b: 'db>(&'db mut self, db: &'b dyn SemanticGroup) -> Inference<'db> {
386 Inference::new(db, self)
387 }
388 pub fn clone_with_inference_id(
389 &self,
390 db: &dyn SemanticGroup,
391 inference_id: InferenceId,
392 ) -> InferenceData {
393 let mut inference_id_replacer =
394 InferenceIdReplacer::new(db, self.inference_id, inference_id);
395 Self {
396 inference_id,
397 type_assignment: self
398 .type_assignment
399 .iter()
400 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
401 .collect(),
402 const_assignment: self
403 .const_assignment
404 .iter()
405 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
406 .collect(),
407 impl_assignment: self
408 .impl_assignment
409 .iter()
410 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
411 .collect(),
412 impl_vars_trait_item_mappings: self
413 .impl_vars_trait_item_mappings
414 .iter()
415 .map(|(k, mappings)| {
416 (
417 *k,
418 ImplVarTraitItemMappings {
419 types: mappings
420 .types
421 .iter()
422 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
423 .collect(),
424 constants: mappings
425 .constants
426 .iter()
427 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
428 .collect(),
429 impls: mappings
430 .impls
431 .iter()
432 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
433 .collect(),
434 },
435 )
436 })
437 .collect(),
438 type_vars: inference_id_replacer.rewrite(self.type_vars.clone()).no_err(),
439 const_vars: inference_id_replacer.rewrite(self.const_vars.clone()).no_err(),
440 impl_vars: inference_id_replacer.rewrite(self.impl_vars.clone()).no_err(),
441 stable_ptrs: self.stable_ptrs.clone(),
442 pending: inference_id_replacer.rewrite(self.pending.clone()).no_err(),
443 refuted: inference_id_replacer.rewrite(self.refuted.clone()).no_err(),
444 solved: inference_id_replacer.rewrite(self.solved.clone()).no_err(),
445 ambiguous: inference_id_replacer.rewrite(self.ambiguous.clone()).no_err(),
446 impl_type_bounds: self.impl_type_bounds.clone(),
448
449 error_status: self.error_status,
450 error: self.error.clone(),
451 consumed_error: self.consumed_error,
452 }
453 }
454 pub fn temporary_clone(&self) -> InferenceData {
455 Self {
456 inference_id: self.inference_id,
457 type_assignment: self.type_assignment.clone(),
458 const_assignment: self.const_assignment.clone(),
459 impl_assignment: self.impl_assignment.clone(),
460 impl_vars_trait_item_mappings: self.impl_vars_trait_item_mappings.clone(),
461 type_vars: self.type_vars.clone(),
462 const_vars: self.const_vars.clone(),
463 impl_vars: self.impl_vars.clone(),
464 stable_ptrs: self.stable_ptrs.clone(),
465 pending: self.pending.clone(),
466 refuted: self.refuted.clone(),
467 solved: self.solved.clone(),
468 ambiguous: self.ambiguous.clone(),
469 impl_type_bounds: self.impl_type_bounds.clone(),
470 error_status: self.error_status,
471 error: self.error.clone(),
472 consumed_error: self.consumed_error,
473 }
474 }
475}
476
/// A handle that pairs an [`InferenceData`] with the database used to resolve it. It
/// dereferences to the underlying data.
pub struct Inference<'db> {
    db: &'db dyn SemanticGroup,
    /// The inference state this handle operates on.
    pub data: &'db mut InferenceData,
}
482
483impl Deref for Inference<'_> {
484 type Target = InferenceData;
485
486 fn deref(&self) -> &Self::Target {
487 self.data
488 }
489}
490impl DerefMut for Inference<'_> {
491 fn deref_mut(&mut self) -> &mut Self::Target {
492 self.data
493 }
494}
495
496impl std::fmt::Debug for Inference<'_> {
497 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
498 let x = self.data.debug(self.db.elongate());
499 write!(f, "{x:?}")
500 }
501}
502
503impl<'db> Inference<'db> {
504 fn new(db: &'db dyn SemanticGroup, data: &'db mut InferenceData) -> Self {
505 Self { db, data }
506 }
507
508 fn impl_var(&self, var_id: LocalImplVarId) -> &ImplVar {
510 &self.impl_vars[var_id.0]
511 }
512
513 pub fn impl_assignment(&self, var_id: LocalImplVarId) -> Option<ImplId> {
515 self.impl_assignment.get(&var_id).copied()
516 }
517
518 fn type_assignment(&self, var_id: LocalTypeVarId) -> Option<TypeId> {
520 self.type_assignment.get(&var_id).copied()
521 }
522
523 pub fn new_type_var(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeId {
526 let var = self.new_type_var_raw(stable_ptr);
527
528 TypeLongId::Var(var).intern(self.db)
529 }
530
531 pub fn new_type_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeVar {
534 let var =
535 TypeVar { inference_id: self.inference_id, id: LocalTypeVarId(self.type_vars.len()) };
536 if let Some(stable_ptr) = stable_ptr {
537 self.stable_ptrs.insert(InferenceVar::Type(var.id), stable_ptr);
538 }
539 self.type_vars.push(var);
540 var
541 }
542
543 pub fn set_impl_type_bounds(&mut self, impl_type_bounds: OrderedHashMap<ImplTypeId, TypeId>) {
546 let impl_type_bounds_finalized = impl_type_bounds
547 .iter()
548 .filter_map(|(impl_type, ty)| {
549 let rewritten_type = self.rewrite(ty.lookup_intern(self.db)).no_err();
550 if !matches!(rewritten_type, TypeLongId::Var(_)) {
551 return Some(((*impl_type).into(), rewritten_type.intern(self.db)));
552 }
553 self.conform_ty(*ty, TypeLongId::ImplType(*impl_type).intern(self.db)).ok();
556 None
557 })
558 .collect();
559
560 self.data.impl_type_bounds = Arc::new(impl_type_bounds_finalized);
561 }
562
563 pub fn new_const_var(
566 &mut self,
567 stable_ptr: Option<SyntaxStablePtrId>,
568 ty: TypeId,
569 ) -> ConstValueId {
570 let var = self.new_const_var_raw(stable_ptr);
571 ConstValue::Var(var, ty).intern(self.db)
572 }
573
574 pub fn new_const_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> ConstVar {
577 let var = ConstVar {
578 inference_id: self.inference_id,
579 id: LocalConstVarId(self.const_vars.len()),
580 };
581 if let Some(stable_ptr) = stable_ptr {
582 self.stable_ptrs.insert(InferenceVar::Const(var.id), stable_ptr);
583 }
584 self.const_vars.push(var);
585 var
586 }
587
588 pub fn new_impl_var(
591 &mut self,
592 concrete_trait_id: ConcreteTraitId,
593 stable_ptr: Option<SyntaxStablePtrId>,
594 lookup_context: ImplLookupContext,
595 ) -> ImplId {
596 let var = self.new_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
597 ImplLongId::ImplVar(self.impl_var(var).intern(self.db)).intern(self.db)
598 }
599
600 fn new_impl_var_raw(
603 &mut self,
604 lookup_context: ImplLookupContext,
605 concrete_trait_id: ConcreteTraitId,
606 stable_ptr: Option<SyntaxStablePtrId>,
607 ) -> LocalImplVarId {
608 let mut lookup_context = lookup_context;
609 lookup_context.insert_module(concrete_trait_id.trait_id(self.db).module_file_id(self.db).0);
610
611 let id = LocalImplVarId(self.impl_vars.len());
612 if let Some(stable_ptr) = stable_ptr {
613 self.stable_ptrs.insert(InferenceVar::Impl(id), stable_ptr);
614 }
615 let var =
616 ImplVar { inference_id: self.inference_id, id, concrete_trait_id, lookup_context };
617 self.impl_vars.push(var);
618 self.pending.push_back(id);
619 id
620 }
621
622 pub fn solve(&mut self) -> InferenceResult<()> {
627 self.solve_ex().map_err(|(err_set, _)| err_set)
628 }
629
630 fn solve_ex(&mut self) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
632 let mut ambiguous = std::mem::take(&mut self.ambiguous);
633 self.pending.extend(ambiguous.drain(..).map(|(var, _)| var));
634 while let Some(var) = self.pending.pop_front() {
635 self.solve_single_pending(var).map_err(|err_set| {
637 (err_set, self.stable_ptrs.get(&InferenceVar::Impl(var)).copied())
638 })?;
639 }
640 Ok(())
641 }
642
643 fn solve_single_pending(&mut self, var: LocalImplVarId) -> InferenceResult<()> {
644 if self.impl_assignment.contains_key(&var) {
645 return Ok(());
646 }
647 let solution = match self.impl_var_solution_set(var)? {
648 SolutionSet::None => {
649 self.refuted.push(var);
650 return Ok(());
651 }
652 SolutionSet::Ambiguous(ambiguity) => {
653 self.ambiguous.push((var, ambiguity));
654 return Ok(());
655 }
656 SolutionSet::Unique(solution) => solution,
657 };
658
659 self.assign_local_impl(var, solution)?;
661
662 self.solved.push(var);
664 let mut ambiguous = std::mem::take(&mut self.ambiguous);
665 self.pending.extend(ambiguous.drain(..).map(|(var, _)| var));
666
667 Ok(())
668 }
669
670 pub fn solution_set(&mut self) -> InferenceResult<SolutionSet<()>> {
673 self.solve()?;
674 if !self.refuted.is_empty() {
675 return Ok(SolutionSet::None);
676 }
677 if let Some((_, ambiguity)) = self.ambiguous.first() {
678 return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
679 }
680 assert!(self.pending.is_empty(), "solution() called on an unsolved solver");
681 Ok(SolutionSet::Unique(()))
682 }
683
684 pub fn finalize_without_reporting(
687 &mut self,
688 ) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
689 if self.error_status.is_err() {
690 return Err((ErrorSet, None));
692 }
693 let info = self.db.core_info();
694 let numeric_trait_id = info.numeric_literal_trt;
695 let felt_ty = info.felt252;
696
697 loop {
699 let mut changed = false;
700 self.solve_ex()?;
701 for (var, _) in self.ambiguous.clone() {
702 let impl_var = self.impl_var(var).clone();
703 if impl_var.concrete_trait_id.trait_id(self.db) != numeric_trait_id {
704 continue;
705 }
706 let ty = extract_matches!(
708 impl_var.concrete_trait_id.generic_args(self.db)[0],
709 GenericArgumentId::Type
710 );
711 if self.rewrite(ty).no_err() == felt_ty {
712 continue;
713 }
714 self.conform_ty(ty, felt_ty).map_err(|err_set| {
715 (err_set, self.stable_ptrs.get(&InferenceVar::Impl(impl_var.id)).copied())
716 })?;
717 changed = true;
718 break;
719 }
720 if !changed {
721 break;
722 }
723 }
724 assert!(
725 self.pending.is_empty(),
726 "pending should all be solved by this point. Guaranteed by solve()."
727 );
728
729 let Some((var, err)) = self.first_undetermined_variable() else {
730 return Ok(());
731 };
732 Err((self.set_error(err), self.stable_ptrs.get(&var).copied()))
733 }
734
735 pub fn finalize(
739 &mut self,
740 diagnostics: &mut SemanticDiagnostics,
741 stable_ptr: SyntaxStablePtrId,
742 ) {
743 if let Err((err_set, err_stable_ptr)) = self.finalize_without_reporting() {
744 let diag = self.report_on_pending_error(
745 err_set,
746 diagnostics,
747 err_stable_ptr.unwrap_or(stable_ptr),
748 );
749
750 let ty_missing = TypeId::missing(self.db, diag);
751 for var in &self.data.type_vars {
752 self.data.type_assignment.entry(var.id).or_insert(ty_missing);
753 }
754 }
755 }
756
757 fn first_undetermined_variable(&mut self) -> Option<(InferenceVar, InferenceError)> {
761 if let Some(var) = self.refuted.first().copied() {
762 let impl_var = self.impl_var(var).clone();
763 let concrete_trait_id = impl_var.concrete_trait_id;
764 let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
765 return Some((
766 InferenceVar::Impl(var),
767 InferenceError::NoImplsFound(concrete_trait_id),
768 ));
769 }
770 let mut fallback_ret = None;
771 if let Some((var, ambiguity)) = self.ambiguous.first() {
772 let ret =
774 Some((InferenceVar::Impl(*var), InferenceError::Ambiguity(ambiguity.clone())));
775 if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
776 return ret;
777 } else {
778 fallback_ret = ret;
779 }
780 }
781 for (id, var) in self.type_vars.iter().enumerate() {
782 if self.type_assignment(LocalTypeVarId(id)).is_none() {
783 let ty = TypeLongId::Var(*var).intern(self.db);
784 return Some((InferenceVar::Type(var.id), InferenceError::TypeNotInferred(ty)));
785 }
786 }
787 fallback_ret
788 }
789
790 fn assign_local_impl(
792 &mut self,
793 var: LocalImplVarId,
794 impl_id: ImplId,
795 ) -> InferenceResult<ImplId> {
796 let concrete_trait = impl_id
797 .concrete_trait(self.db)
798 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
799 self.conform_traits(self.impl_var(var).concrete_trait_id, concrete_trait)?;
800 if let Some(other_impl) = self.impl_assignment(var) {
801 return self.conform_impl(impl_id, other_impl);
802 }
803 if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
804 {
805 return Err(self.set_error(InferenceError::Cycle(InferenceVar::Impl(var))));
806 }
807 self.impl_assignment.insert(var, impl_id);
808 if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
809 for (trait_type_id, ty) in mappings.types {
810 let impl_ty = self
811 .db
812 .impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_type_id, self.db))
813 .map_err(|_| ErrorSet)?;
814 if let Err(err_set) = self.conform_ty(ty, impl_ty) {
815 let ty0 = self.rewrite(ty).no_err();
817 let ty1 = self.rewrite(impl_ty).no_err();
818
819 self.error =
820 Some(InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 });
821 return Err(err_set);
822 }
823 }
824 for (trait_constant, constant_id) in mappings.constants {
825 self.conform_const(
826 constant_id,
827 self.db
828 .impl_constant_concrete_implized_value(ImplConstantId::new(
829 impl_id,
830 trait_constant,
831 self.db,
832 ))
833 .map_err(|_| ErrorSet)?,
834 )?;
835 }
836 for (trait_impl, inner_impl_id) in mappings.impls {
837 self.conform_impl(
838 inner_impl_id,
839 self.db
840 .impl_impl_concrete_implized(ImplImplId::new(impl_id, trait_impl, self.db))
841 .map_err(|_| ErrorSet)?,
842 )?;
843 }
844 }
845 Ok(impl_id)
846 }
847
848 fn assign_impl(&mut self, var_id: ImplVarId, impl_id: ImplId) -> InferenceResult<ImplId> {
850 let var = var_id.lookup_intern(self.db);
851 if var.inference_id != self.inference_id {
852 return Err(self.set_error(InferenceError::ImplKindMismatch {
853 impl0: ImplLongId::ImplVar(var_id).intern(self.db),
854 impl1: impl_id,
855 }));
856 }
857 self.assign_local_impl(var.id, impl_id)
858 }
859
860 fn assign_ty(&mut self, var: TypeVar, ty: TypeId) -> InferenceResult<TypeId> {
863 if var.inference_id != self.inference_id {
864 return Err(self.set_error(InferenceError::TypeKindMismatch {
865 ty0: TypeLongId::Var(var).intern(self.db),
866 ty1: ty,
867 }));
868 }
869 assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
870 let inference_var = InferenceVar::Type(var.id);
871 if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
872 return Err(self.set_error(InferenceError::Cycle(inference_var)));
873 }
874 if let TypeLongId::Var(other) = ty.lookup_intern(self.db) {
876 if other.inference_id == self.inference_id && other.id.0 > var.id.0 {
877 let var_ty = TypeLongId::Var(var).intern(self.db);
878 self.type_assignment.insert(other.id, var_ty);
879 return Ok(var_ty);
880 }
881 }
882 self.type_assignment.insert(var.id, ty);
883 Ok(ty)
884 }
885
886 fn assign_const(&mut self, var: ConstVar, id: ConstValueId) -> InferenceResult<ConstValueId> {
889 if var.inference_id != self.inference_id {
890 return Err(self.set_error(InferenceError::ConstKindMismatch {
891 const0: ConstValue::Var(var, TypeId::missing(self.db, skip_diagnostic()))
892 .intern(self.db),
893 const1: id,
894 }));
895 }
896
897 self.const_assignment.insert(var.id, id);
898 Ok(id)
899 }
900
901 fn impl_var_solution_set(
903 &mut self,
904 var: LocalImplVarId,
905 ) -> InferenceResult<SolutionSet<ImplId>> {
906 let impl_var = self.impl_var(var).clone();
907 let concrete_trait_id = self.rewrite(impl_var.concrete_trait_id).no_err();
909 self.impl_vars[impl_var.id.0].concrete_trait_id = concrete_trait_id;
910 let impl_var_trait_item_mappings =
911 self.impl_vars_trait_item_mappings.get(&var).cloned().unwrap_or_default();
912 let solution_set = self.trait_solution_set(
913 concrete_trait_id,
914 impl_var_trait_item_mappings,
915 impl_var.lookup_context,
916 )?;
917 Ok(match solution_set {
918 SolutionSet::None => SolutionSet::None,
919 SolutionSet::Unique((canonical_impl, canonicalizer)) => {
920 SolutionSet::Unique(canonical_impl.embed(self, &canonicalizer))
921 }
922 SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
923 })
924 }
925
926 pub fn trait_solution_set(
928 &mut self,
929 concrete_trait_id: ConcreteTraitId,
930 impl_var_trait_item_mappings: ImplVarTraitItemMappings,
931 mut lookup_context: ImplLookupContext,
932 ) -> InferenceResult<SolutionSet<(CanonicalImpl, CanonicalMapping)>> {
933 let impl_var_trait_item_mappings = self.rewrite(impl_var_trait_item_mappings).no_err();
934 let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
936 enrich_lookup_context(self.db, concrete_trait_id, &mut lookup_context);
937
938 let generic_args = concrete_trait_id.generic_args(self.db);
940 match generic_args.first() {
941 Some(GenericArgumentId::Type(ty)) => {
942 if let TypeLongId::Var(_) = ty.lookup_intern(self.db) {
943 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
945 }
946 }
947 Some(GenericArgumentId::Impl(imp)) => {
948 if let ImplLongId::ImplVar(_) = imp.lookup_intern(self.db) {
950 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
951 }
952 }
953 Some(GenericArgumentId::Constant(const_value)) => {
954 if let ConstValue::Var(_, _) = const_value.lookup_intern(self.db) {
955 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
957 }
958 }
959 _ => {}
960 };
961 let (canonical_trait, canonicalizer) = CanonicalTrait::canonicalize(
962 self.db,
963 self.inference_id,
964 concrete_trait_id,
965 impl_var_trait_item_mappings,
966 );
967 let solution_set = match self.db.canonic_trait_solutions(
970 canonical_trait,
971 lookup_context,
972 (*self.data.impl_type_bounds).clone(),
973 ) {
974 Ok(solution_set) => solution_set,
975 Err(err) => return Err(self.set_error(err)),
976 };
977 match solution_set {
978 SolutionSet::None => Ok(SolutionSet::None),
979 SolutionSet::Unique(canonical_impl) => {
980 Ok(SolutionSet::Unique((canonical_impl, canonicalizer)))
981 }
982 SolutionSet::Ambiguous(ambiguity) => Ok(SolutionSet::Ambiguous(ambiguity)),
983 }
984 }
985
986 fn validate_neg_impls(
990 &mut self,
991 lookup_context: &ImplLookupContext,
992 canonical_impl: CanonicalImpl,
993 ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
994 fn validate_no_solution_set(
996 inference: &mut Inference<'_>,
997 canonical_impl: CanonicalImpl,
998 lookup_context: &ImplLookupContext,
999 negative_impls_concrete_traits: impl Iterator<Item = Maybe<ConcreteTraitId>>,
1000 ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
1001 for concrete_trait_id in negative_impls_concrete_traits {
1002 let concrete_trait_id = concrete_trait_id.map_err(|diag_added| {
1003 inference.set_error(InferenceError::Reported(diag_added))
1004 })?;
1005 for garg in concrete_trait_id.generic_args(inference.db) {
1006 let GenericArgumentId::Type(ty) = garg else {
1007 continue;
1008 };
1009 let ty = inference.rewrite(ty).no_err();
1010 if !matches!(ty.lookup_intern(inference.db), TypeLongId::Closure(_))
1019 && !ty.is_fully_concrete(inference.db)
1020 {
1021 return Ok(SolutionSet::Ambiguous(
1024 Ambiguity::NegativeImplWithUnresolvedGenericArgs {
1025 impl_id: canonical_impl.0,
1026 ty,
1027 },
1028 ));
1029 }
1030 }
1031
1032 if !matches!(
1033 inference.trait_solution_set(
1034 concrete_trait_id,
1035 ImplVarTraitItemMappings::default(),
1036 lookup_context.clone()
1037 )?,
1038 SolutionSet::None
1039 ) {
1040 return Ok(SolutionSet::None);
1042 }
1043 }
1044
1045 Ok(SolutionSet::Unique(canonical_impl))
1046 }
1047 match canonical_impl.0.lookup_intern(self.db) {
1048 ImplLongId::Concrete(concrete_impl) => {
1049 let substitution = concrete_impl
1050 .substitution(self.db)
1051 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1052 let generic_params = self
1053 .db
1054 .impl_def_generic_params(concrete_impl.impl_def_id(self.db))
1055 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1056 let concrete_traits = generic_params
1057 .iter()
1058 .filter_map(|generic_param| {
1059 try_extract_matches!(generic_param, GenericParam::NegImpl)
1060 })
1061 .map(|generic_param| {
1062 substitution
1063 .substitute(self.db, generic_param.clone())
1064 .and_then(|generic_param| generic_param.concrete_trait)
1065 });
1066 validate_no_solution_set(self, canonical_impl, lookup_context, concrete_traits)
1067 }
1068 ImplLongId::GeneratedImpl(generated_impl) => validate_no_solution_set(
1069 self,
1070 canonical_impl,
1071 lookup_context,
1072 generated_impl
1073 .lookup_intern(self.db)
1074 .generic_params
1075 .iter()
1076 .filter_map(|generic_param| {
1077 try_extract_matches!(generic_param, GenericParam::NegImpl)
1078 })
1079 .map(|generic_param| generic_param.concrete_trait),
1080 ),
1081 ImplLongId::GenericParameter(_)
1082 | ImplLongId::ImplVar(_)
1083 | ImplLongId::ImplImpl(_)
1084 | ImplLongId::SelfImpl(_) => Ok(SolutionSet::Unique(canonical_impl)),
1085 }
1086 }
1087
1088 pub fn set_error(&mut self, err: InferenceError) -> ErrorSet {
1095 if self.error_status.is_err() {
1096 return ErrorSet;
1097 }
1098 self.error_status = if let InferenceError::Reported(diag_added) = err {
1099 self.consumed_error = Some(diag_added);
1100 Err(InferenceErrorStatus::Consumed)
1101 } else {
1102 self.error = Some(err);
1103 Err(InferenceErrorStatus::Pending)
1104 };
1105 ErrorSet
1106 }
1107
1108 pub fn is_error_set(&self) -> InferenceResult<()> {
1110 if self.error_status.is_err() { Err(ErrorSet) } else { Ok(()) }
1111 }
1112
1113 pub fn consume_error_without_reporting(&mut self, err_set: ErrorSet) -> Option<InferenceError> {
1119 self.consume_error_inner(err_set, skip_diagnostic())
1120 }
1121
1122 pub fn consume_reported_error(&mut self, err_set: ErrorSet, diag_added: DiagnosticAdded) {
1129 self.consume_error_inner(err_set, diag_added);
1130 }
1131
1132 fn consume_error_inner(
1139 &mut self,
1140 _err_set: ErrorSet,
1141 diag_added: DiagnosticAdded,
1142 ) -> Option<InferenceError> {
1143 if self.error_status != Err(InferenceErrorStatus::Pending) {
1144 return None;
1145 }
1147 self.error_status = Err(InferenceErrorStatus::Consumed);
1148 self.consumed_error = Some(diag_added);
1149 mem::take(&mut self.error)
1150 }
1151
1152 pub fn report_on_pending_error(
1158 &mut self,
1159 _err_set: ErrorSet,
1160 diagnostics: &mut SemanticDiagnostics,
1161 stable_ptr: SyntaxStablePtrId,
1162 ) -> DiagnosticAdded {
1163 let Err(state_error) = self.error_status else {
1164 panic!("report_on_pending_error should be called only on error");
1165 };
1166 match state_error {
1167 InferenceErrorStatus::Consumed => self
1168 .consumed_error
1169 .expect("consumed_error is not set although error_status is Err(Consumed)"),
1170 InferenceErrorStatus::Pending => {
1171 let diag_added = match mem::take(&mut self.error)
1172 .expect("error is not set although error_status is Err(Pending)")
1173 {
1174 InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
1175 skip_diagnostic()
1180 }
1181 diag => diag.report(diagnostics, stable_ptr),
1182 };
1183
1184 self.error_status = Err(InferenceErrorStatus::Consumed);
1185 self.consumed_error = Some(diag_added);
1186 diag_added
1187 }
1188 }
1189 }
1190
1191 pub fn report_modified_if_pending(
1194 &mut self,
1195 err_set: ErrorSet,
1196 report: impl FnOnce() -> DiagnosticAdded,
1197 ) {
1198 if self.error_status == Err(InferenceErrorStatus::Pending) {
1199 self.consume_reported_error(err_set, report());
1200 }
1201 }
1202}
1203
/// Exposes the semantic database this inference holds (`self.db`) through the `HasDb` trait.
1204impl<'a> HasDb<&'a dyn SemanticGroup> for Inference<'a> {
1205 fn get_db(&self) -> &'a dyn SemanticGroup {
1206 self.db
1207 }
1208}
// Presumably these macros generate the `SemanticRewriter` impls of `Inference` for the basic
// semantic types. The types listed after `@exclude` are implemented by hand below so that they
// can resolve inference variables and impl-related indirections.
1209add_basic_rewrites!(<'a>, Inference<'a>, NoError, @exclude TypeLongId TypeId ImplLongId ImplId ConstValue);
// Expression-related rewrites; nothing is excluded.
1210add_expr_rewrites!(<'a>, Inference<'a>, NoError, @exclude);
// A single extra rewrite impl for `Ambiguity`.
1211add_rewrite!(<'a>, Inference<'a>, NoError, Ambiguity);
1212impl SemanticRewriter<TypeId, NoError> for Inference<'_> {
1213 fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
1214 if value.is_var_free(self.db) {
1215 return Ok(RewriteResult::NoChange);
1216 }
1217 value.default_rewrite(self)
1218 }
1219}
1220impl SemanticRewriter<ImplId, NoError> for Inference<'_> {
1221 fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
1222 if value.is_var_free(self.db) {
1223 return Ok(RewriteResult::NoChange);
1224 }
1225 value.default_rewrite(self)
1226 }
1227}
1228impl SemanticRewriter<TypeLongId, NoError> for Inference<'_> {
1229 fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
1230 match value {
1231 TypeLongId::Var(var) => {
1232 if let Some(type_id) = self.type_assignment.get(&var.id) {
1233 let mut long_type_id = type_id.lookup_intern(self.db);
1234 if let RewriteResult::Modified = self.internal_rewrite(&mut long_type_id)? {
1235 *self.type_assignment.get_mut(&var.id).unwrap() =
1236 long_type_id.clone().intern(self.db);
1237 }
1238 *value = long_type_id;
1239 return Ok(RewriteResult::Modified);
1240 }
1241 }
1242 TypeLongId::ImplType(impl_type_id) => {
1243 if let Some(type_id) = self.impl_type_bounds.get(&((*impl_type_id).into())) {
1244 *value = type_id.lookup_intern(self.db);
1245 self.internal_rewrite(value)?;
1246 return Ok(RewriteResult::Modified);
1247 }
1248 let impl_type_id_rewrite_result = self.internal_rewrite(impl_type_id)?;
1249 let impl_id = impl_type_id.impl_id();
1250 let trait_ty = impl_type_id.ty();
1251 return Ok(match impl_id.lookup_intern(self.db) {
1252 ImplLongId::GenericParameter(_)
1253 | ImplLongId::SelfImpl(_)
1254 | ImplLongId::ImplImpl(_) => impl_type_id_rewrite_result,
1255 ImplLongId::Concrete(_) => {
1256 if let Ok(ty) = self.db.impl_type_concrete_implized(ImplTypeId::new(
1257 impl_id, trait_ty, self.db,
1258 )) {
1259 *value = self.rewrite(ty).no_err().lookup_intern(self.db);
1260 RewriteResult::Modified
1261 } else {
1262 impl_type_id_rewrite_result
1263 }
1264 }
1265 ImplLongId::ImplVar(var) => {
1266 *value = self.rewritten_impl_type(var, trait_ty).lookup_intern(self.db);
1267 return Ok(RewriteResult::Modified);
1268 }
1269 ImplLongId::GeneratedImpl(generated) => {
1270 *value = self
1271 .rewrite(
1272 *generated
1273 .lookup_intern(self.db)
1274 .impl_items
1275 .0
1276 .get(&impl_type_id.ty())
1277 .unwrap(),
1278 )
1279 .no_err()
1280 .lookup_intern(self.db);
1281 RewriteResult::Modified
1282 }
1283 });
1284 }
1285 _ => {}
1286 }
1287 value.default_rewrite(self)
1288 }
1289}
1290impl SemanticRewriter<ConstValue, NoError> for Inference<'_> {
1291 fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
1292 match value {
1293 ConstValue::Var(var, _) => {
1294 return Ok(if let Some(const_value_id) = self.const_assignment.get(&var.id) {
1295 let mut const_value = const_value_id.lookup_intern(self.db);
1296 if let RewriteResult::Modified = self.internal_rewrite(&mut const_value)? {
1297 *self.const_assignment.get_mut(&var.id).unwrap() =
1298 const_value.clone().intern(self.db);
1299 }
1300 *value = const_value;
1301 RewriteResult::Modified
1302 } else {
1303 RewriteResult::NoChange
1304 });
1305 }
1306 ConstValue::ImplConstant(impl_constant_id) => {
1307 let impl_constant_id_rewrite_result = self.internal_rewrite(impl_constant_id)?;
1308 let impl_id = impl_constant_id.impl_id();
1309 let trait_constant = impl_constant_id.trait_constant_id();
1310 return Ok(match impl_id.lookup_intern(self.db) {
1311 ImplLongId::GenericParameter(_)
1312 | ImplLongId::SelfImpl(_)
1313 | ImplLongId::GeneratedImpl(_)
1314 | ImplLongId::ImplImpl(_) => impl_constant_id_rewrite_result,
1315 ImplLongId::Concrete(_) => {
1316 if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
1317 ImplConstantId::new(impl_id, trait_constant, self.db),
1318 ) {
1319 *value = self.rewrite(constant).no_err().lookup_intern(self.db);
1320 RewriteResult::Modified
1321 } else {
1322 impl_constant_id_rewrite_result
1323 }
1324 }
1325 ImplLongId::ImplVar(var) => {
1326 *value = self
1327 .rewritten_impl_constant(var, trait_constant)
1328 .lookup_intern(self.db);
1329 return Ok(RewriteResult::Modified);
1330 }
1331 });
1332 }
1333 _ => {}
1334 }
1335 value.default_rewrite(self)
1336 }
1337}
1338impl SemanticRewriter<ImplLongId, NoError> for Inference<'_> {
1339 fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
1340 match value {
1341 ImplLongId::ImplVar(var) => {
1342 let long_id = var.lookup_intern(self.db);
1343 let impl_var_id = long_id.id;
1345 if let Some(impl_id) = self.impl_assignment(impl_var_id) {
1346 let mut long_impl_id = impl_id.lookup_intern(self.db);
1347 if let RewriteResult::Modified = self.internal_rewrite(&mut long_impl_id)? {
1348 *self.impl_assignment.get_mut(&impl_var_id).unwrap() =
1349 long_impl_id.clone().intern(self.db);
1350 }
1351 *value = long_impl_id;
1352 return Ok(RewriteResult::Modified);
1353 }
1354 }
1355 ImplLongId::ImplImpl(impl_impl_id) => {
1356 let impl_impl_id_rewrite_result = self.internal_rewrite(impl_impl_id)?;
1357 let impl_id = impl_impl_id.impl_id();
1358 return Ok(match impl_id.lookup_intern(self.db) {
1359 ImplLongId::GenericParameter(_)
1360 | ImplLongId::SelfImpl(_)
1361 | ImplLongId::GeneratedImpl(_)
1362 | ImplLongId::ImplImpl(_) => impl_impl_id_rewrite_result,
1363 ImplLongId::Concrete(_) => {
1364 if let Ok(imp) = self.db.impl_impl_concrete_implized(*impl_impl_id) {
1365 *value = self.rewrite(imp).no_err().lookup_intern(self.db);
1366 RewriteResult::Modified
1367 } else {
1368 impl_impl_id_rewrite_result
1369 }
1370 }
1371 ImplLongId::ImplVar(var) => {
1372 if let Ok(concrete_trait_impl) =
1373 impl_impl_id.concrete_trait_impl_id(self.db)
1374 {
1375 *value = self
1376 .rewritten_impl_impl(var, concrete_trait_impl)
1377 .lookup_intern(self.db);
1378 return Ok(RewriteResult::Modified);
1379 } else {
1380 impl_impl_id_rewrite_result
1381 }
1382 }
1383 });
1384 }
1385
1386 _ => {}
1387 }
1388 if value.is_var_free(self.db) {
1389 return Ok(RewriteResult::NoChange);
1390 }
1391 value.default_rewrite(self)
1392 }
1393}
1394
/// Rewriter whose `InferenceId` rewrite replaces `from_inference_id` with `to_inference_id`;
/// all other inference ids are left as they are.
1395struct InferenceIdReplacer<'a> {
/// Database handed out through the `HasDb` impl.
1396 db: &'a dyn SemanticGroup,
/// The inference id to look for.
1397 from_inference_id: InferenceId,
/// The inference id written in its place.
1398 to_inference_id: InferenceId,
1399}
1400impl<'a> InferenceIdReplacer<'a> {
/// Creates a replacer mapping `from_inference_id` to `to_inference_id`.
1401 fn new(
1402 db: &'a dyn SemanticGroup,
1403 from_inference_id: InferenceId,
1404 to_inference_id: InferenceId,
1405 ) -> Self {
1406 Self { db, from_inference_id, to_inference_id }
1407 }
1408}
/// Exposes the replacer's semantic database (`self.db`) through the `HasDb` trait.
1409impl<'a> HasDb<&'a dyn SemanticGroup> for InferenceIdReplacer<'a> {
1410 fn get_db(&self) -> &'a dyn SemanticGroup {
1411 self.db
1412 }
1413}
// Presumably these macros generate the `SemanticRewriter` impls of `InferenceIdReplacer` for the
// basic semantic types. `InferenceId` is excluded because its rewrite, which performs the actual
// replacement, is implemented by hand below.
1414add_basic_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude InferenceId);
// Expression-related rewrites; nothing is excluded.
1415add_expr_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude);
// A single extra rewrite impl for `Ambiguity`.
1416add_rewrite!(<'a>, InferenceIdReplacer<'a>, NoError, Ambiguity);
1417impl SemanticRewriter<InferenceId, NoError> for InferenceIdReplacer<'_> {
1418 fn internal_rewrite(&mut self, value: &mut InferenceId) -> Result<RewriteResult, NoError> {
1419 if value == &self.from_inference_id {
1420 *value = self.to_inference_id;
1421 Ok(RewriteResult::Modified)
1422 } else {
1423 Ok(RewriteResult::NoChange)
1424 }
1425 }
1426}