1use std::collections::{BTreeMap, HashMap, VecDeque};
4use std::hash::Hash;
5use std::mem;
6use std::ops::{Deref, DerefMut};
7use std::sync::Arc;
8
9use cairo_lang_debug::DebugWithDb;
10use cairo_lang_defs::ids::{
11 ConstantId, EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId,
12 GlobalUseId, ImplAliasId, ImplDefId, ImplFunctionId, ImplImplDefId, LanguageElementId,
13 LocalVarId, LookupItemId, MemberId, NamedLanguageElementId, ParamId, StructId, TraitConstantId,
14 TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
15};
16use cairo_lang_diagnostics::{DiagnosticAdded, Maybe, skip_diagnostic};
17use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
18use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
19use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
20use cairo_lang_utils::{
21 Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches,
22};
23
24use self::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, NoError};
25use self::solver::{Ambiguity, SolutionSet, enrich_lookup_context};
26use crate::db::SemanticGroup;
27use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
28use crate::expr::inference::canonic::ResultNoErrEx;
29use crate::expr::inference::conform::InferenceConform;
30use crate::expr::objects::*;
31use crate::expr::pattern::*;
32use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
33use crate::items::functions::{
34 ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
35 GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
36 ImplGenericFunctionWithBodyId,
37};
38use crate::items::generics::{GenericParamConst, GenericParamImpl, GenericParamType};
39use crate::items::imp::{
40 GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
41 ImplLookupContext, UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
42};
43use crate::items::trt::{
44 ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId, ConcreteTraitTypeId,
45 ConcreteTraitTypeLongId,
46};
47use crate::substitution::{HasDb, RewriteResult, SemanticRewriter};
48use crate::types::{
49 ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
50 ImplTypeById, ImplTypeId,
51};
52use crate::{
53 ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
54 ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
55 FunctionId, FunctionLongId, GenericArgumentId, GenericParam, LocalVariable, MatchArmSelector,
56 Member, Parameter, SemanticObject, Signature, TypeId, TypeLongId, ValueSelectorArm,
57 add_basic_rewrites, add_expr_rewrites, add_rewrite, semantic_object_for_id,
58};
59
60pub mod canonic;
61pub mod conform;
62pub mod infers;
63pub mod solver;
64
65#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
68pub struct TypeVar {
69 pub inference_id: InferenceId,
70 pub id: LocalTypeVarId,
71}
72
73#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
76pub struct ConstVar {
77 pub inference_id: InferenceId,
78 pub id: LocalConstVarId,
79}
80
81#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
83#[debug_db(dyn SemanticGroup + 'static)]
84pub enum InferenceId {
85 LookupItemDeclaration(LookupItemId),
86 LookupItemGenerics(LookupItemId),
87 LookupItemDefinition(LookupItemId),
88 ImplDefTrait(ImplDefId),
89 ImplAliasImplDef(ImplAliasId),
90 GenericParam(GenericParamId),
91 GenericImplParamTrait(GenericParamId),
92 GlobalUseStar(GlobalUseId),
93 Canonical,
94 NoContext,
96}
97
98#[derive(Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
101#[debug_db(dyn SemanticGroup + 'static)]
102pub struct ImplVar {
103 pub inference_id: InferenceId,
104 #[dont_rewrite]
105 pub id: LocalImplVarId,
106 pub concrete_trait_id: ConcreteTraitId,
107 #[dont_rewrite]
108 pub lookup_context: ImplLookupContext,
109}
110impl ImplVar {
111 pub fn intern(&self, db: &dyn SemanticGroup) -> ImplVarId {
112 self.clone().intern(db)
113 }
114}
115
116#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
117pub struct LocalTypeVarId(pub usize);
118#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
119pub struct LocalImplVarId(pub usize);
120
121#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
122pub struct LocalConstVarId(pub usize);
123
124define_short_id!(ImplVarId, ImplVar, SemanticGroup, lookup_intern_impl_var, intern_impl_var);
125impl ImplVarId {
126 pub fn id(&self, db: &dyn SemanticGroup) -> LocalImplVarId {
127 self.lookup_intern(db).id
128 }
129 pub fn concrete_trait_id(&self, db: &dyn SemanticGroup) -> ConcreteTraitId {
130 self.lookup_intern(db).concrete_trait_id
131 }
132 pub fn lookup_context(&self, db: &dyn SemanticGroup) -> ImplLookupContext {
133 self.lookup_intern(db).lookup_context
134 }
135}
136semantic_object_for_id!(ImplVarId, lookup_intern_impl_var, intern_impl_var, ImplVar);
137
138#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, SemanticObject)]
139pub enum InferenceVar {
140 Type(LocalTypeVarId),
141 Const(LocalConstVarId),
142 Impl(LocalImplVarId),
143}
144
145#[derive(Clone, Debug, Eq, Hash, PartialEq, DebugWithDb)]
147#[debug_db(dyn SemanticGroup + 'static)]
148pub enum InferenceError {
149 Reported(DiagnosticAdded),
151 Cycle(InferenceVar),
152 TypeKindMismatch {
153 ty0: TypeId,
154 ty1: TypeId,
155 },
156 ConstKindMismatch {
157 const0: ConstValueId,
158 const1: ConstValueId,
159 },
160 ImplKindMismatch {
161 impl0: ImplId,
162 impl1: ImplId,
163 },
164 GenericArgMismatch {
165 garg0: GenericArgumentId,
166 garg1: GenericArgumentId,
167 },
168 TraitMismatch {
169 trt0: TraitId,
170 trt1: TraitId,
171 },
172 ImplTypeMismatch {
173 impl_id: ImplId,
174 trait_type_id: TraitTypeId,
175 ty0: TypeId,
176 ty1: TypeId,
177 },
178 GenericFunctionMismatch {
179 func0: GenericFunctionId,
180 func1: GenericFunctionId,
181 },
182 ConstNotInferred,
183 NoImplsFound(ConcreteTraitId),
186 Ambiguity(Ambiguity),
187 TypeNotInferred(TypeId),
188}
189impl InferenceError {
190 pub fn format(&self, db: &(dyn SemanticGroup + 'static)) -> String {
191 match self {
192 InferenceError::Reported(_) => "Inference error occurred.".into(),
193 InferenceError::Cycle(_var) => "Inference cycle detected".into(),
194 InferenceError::TypeKindMismatch { ty0, ty1 } => {
195 format!("Type mismatch: `{:?}` and `{:?}`.", ty0.debug(db), ty1.debug(db))
196 }
197 InferenceError::ConstKindMismatch { const0, const1 } => {
198 format!("Const mismatch: `{:?}` and `{:?}`.", const0.debug(db), const1.debug(db))
199 }
200 InferenceError::ImplKindMismatch { impl0, impl1 } => {
201 format!("Impl mismatch: `{:?}` and `{:?}`.", impl0.debug(db), impl1.debug(db))
202 }
203 InferenceError::GenericArgMismatch { garg0, garg1 } => {
204 format!(
205 "Generic arg mismatch: `{:?}` and `{:?}`.",
206 garg0.debug(db),
207 garg1.debug(db)
208 )
209 }
210 InferenceError::TraitMismatch { trt0, trt1 } => {
211 format!("Trait mismatch: `{:?}` and `{:?}`.", trt0.debug(db), trt1.debug(db))
212 }
213 InferenceError::ConstNotInferred => "Failed to infer constant.".into(),
214 InferenceError::NoImplsFound(concrete_trait_id) => {
215 let info = db.core_info();
216 let trait_id = concrete_trait_id.trait_id(db);
217 if trait_id == info.numeric_literal_trt {
218 let generic_type = extract_matches!(
219 concrete_trait_id.generic_args(db)[0],
220 GenericArgumentId::Type
221 );
222 return format!(
223 "Mismatched types. The type `{:?}` cannot be created from a numeric \
224 literal.",
225 generic_type.debug(db)
226 );
227 } else if trait_id == info.string_literal_trt {
228 let generic_type = extract_matches!(
229 concrete_trait_id.generic_args(db)[0],
230 GenericArgumentId::Type
231 );
232 return format!(
233 "Mismatched types. The type `{:?}` cannot be created from a string \
234 literal.",
235 generic_type.debug(db)
236 );
237 }
238 format!(
239 "Trait has no implementation in context: {:?}.",
240 concrete_trait_id.debug(db)
241 )
242 }
243 InferenceError::Ambiguity(ambiguity) => ambiguity.format(db),
244 InferenceError::TypeNotInferred(ty) => {
245 format!("Type annotations needed. Failed to infer {:?}.", ty.debug(db))
246 }
247 InferenceError::GenericFunctionMismatch { func0, func1 } => {
248 format!("Function mismatch: `{}` and `{}`.", func0.format(db), func1.format(db))
249 }
250 InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 } => {
251 format!(
252 "`{}::{}` type mismatch: `{:?}` and `{:?}`.",
253 impl_id.format(db),
254 trait_type_id.name(db),
255 ty0.debug(db),
256 ty1.debug(db)
257 )
258 }
259 }
260 }
261}
262
263impl InferenceError {
264 pub fn report(
265 &self,
266 diagnostics: &mut SemanticDiagnostics,
267 stable_ptr: SyntaxStablePtrId,
268 ) -> DiagnosticAdded {
269 match self {
270 InferenceError::Reported(diagnostic_added) => *diagnostic_added,
271 _ => diagnostics
272 .report(stable_ptr, SemanticDiagnosticKind::InternalInferenceError(self.clone())),
273 }
274 }
275}
276
/// Zero-sized marker returned by failing inference operations. It carries no details itself; the
/// error is recorded on the inference (see `Inference::set_error`).
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct ErrorSet;

/// Result of an inference operation whose error side is the [`ErrorSet`] marker.
pub type InferenceResult<T> = Result<T, ErrorSet>;
285
/// State of the error recorded on an inference.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InferenceErrorStatus {
    /// An error is stored and has not been consumed yet.
    Pending,
    /// The error was consumed (or was already reported when it was set).
    Consumed,
}
291
292#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, SemanticObject)]
294pub struct ImplVarTraitItemMappings {
295 types: OrderedHashMap<TraitTypeId, TypeId>,
297 constants: OrderedHashMap<TraitConstantId, ConstValueId>,
299 impls: OrderedHashMap<TraitImplId, ImplId>,
301}
302
303#[derive(Debug, DebugWithDb, PartialEq, Eq)]
305#[debug_db(dyn SemanticGroup + 'static)]
306pub struct InferenceData {
307 pub inference_id: InferenceId,
308 pub type_assignment: OrderedHashMap<LocalTypeVarId, TypeId>,
310 pub const_assignment: OrderedHashMap<LocalConstVarId, ConstValueId>,
312 pub impl_assignment: OrderedHashMap<LocalImplVarId, ImplId>,
314 pub impl_vars_trait_item_mappings: HashMap<LocalImplVarId, ImplVarTraitItemMappings>,
317 pub type_vars: Vec<TypeVar>,
319 pub const_vars: Vec<ConstVar>,
321 pub impl_vars: Vec<ImplVar>,
323 pub stable_ptrs: HashMap<InferenceVar, SyntaxStablePtrId>,
325 pending: VecDeque<LocalImplVarId>,
327 refuted: Vec<LocalImplVarId>,
329 solved: Vec<LocalImplVarId>,
331 ambiguous: Vec<(LocalImplVarId, Ambiguity)>,
333 pub impl_type_bounds: Arc<BTreeMap<ImplTypeById, TypeId>>,
335
336 pub error_status: Result<(), InferenceErrorStatus>,
339 error: Option<InferenceError>,
341 consumed_error: Option<DiagnosticAdded>,
343}
344impl InferenceData {
345 pub fn new(inference_id: InferenceId) -> Self {
346 Self {
347 inference_id,
348 type_assignment: OrderedHashMap::default(),
349 impl_assignment: OrderedHashMap::default(),
350 const_assignment: OrderedHashMap::default(),
351 impl_vars_trait_item_mappings: HashMap::new(),
352 type_vars: Vec::new(),
353 impl_vars: Vec::new(),
354 const_vars: Vec::new(),
355 stable_ptrs: HashMap::new(),
356 pending: VecDeque::new(),
357 refuted: Vec::new(),
358 solved: Vec::new(),
359 ambiguous: Vec::new(),
360 impl_type_bounds: Default::default(),
361 error_status: Ok(()),
362 error: None,
363 consumed_error: None,
364 }
365 }
366 pub fn inference<'db, 'b: 'db>(&'db mut self, db: &'b dyn SemanticGroup) -> Inference<'db> {
367 Inference::new(db, self)
368 }
369 pub fn clone_with_inference_id(
370 &self,
371 db: &dyn SemanticGroup,
372 inference_id: InferenceId,
373 ) -> InferenceData {
374 let mut inference_id_replacer =
375 InferenceIdReplacer::new(db, self.inference_id, inference_id);
376 Self {
377 inference_id,
378 type_assignment: self
379 .type_assignment
380 .iter()
381 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
382 .collect(),
383 const_assignment: self
384 .const_assignment
385 .iter()
386 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
387 .collect(),
388 impl_assignment: self
389 .impl_assignment
390 .iter()
391 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
392 .collect(),
393 impl_vars_trait_item_mappings: self
394 .impl_vars_trait_item_mappings
395 .iter()
396 .map(|(k, mappings)| {
397 (
398 *k,
399 ImplVarTraitItemMappings {
400 types: mappings
401 .types
402 .iter()
403 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
404 .collect(),
405 constants: mappings
406 .constants
407 .iter()
408 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
409 .collect(),
410 impls: mappings
411 .impls
412 .iter()
413 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
414 .collect(),
415 },
416 )
417 })
418 .collect(),
419 type_vars: inference_id_replacer.rewrite(self.type_vars.clone()).no_err(),
420 const_vars: inference_id_replacer.rewrite(self.const_vars.clone()).no_err(),
421 impl_vars: inference_id_replacer.rewrite(self.impl_vars.clone()).no_err(),
422 stable_ptrs: self.stable_ptrs.clone(),
423 pending: inference_id_replacer.rewrite(self.pending.clone()).no_err(),
424 refuted: inference_id_replacer.rewrite(self.refuted.clone()).no_err(),
425 solved: inference_id_replacer.rewrite(self.solved.clone()).no_err(),
426 ambiguous: inference_id_replacer.rewrite(self.ambiguous.clone()).no_err(),
427 impl_type_bounds: self.impl_type_bounds.clone(),
429
430 error_status: self.error_status,
431 error: self.error.clone(),
432 consumed_error: self.consumed_error,
433 }
434 }
435 pub fn temporary_clone(&self) -> InferenceData {
436 Self {
437 inference_id: self.inference_id,
438 type_assignment: self.type_assignment.clone(),
439 const_assignment: self.const_assignment.clone(),
440 impl_assignment: self.impl_assignment.clone(),
441 impl_vars_trait_item_mappings: self.impl_vars_trait_item_mappings.clone(),
442 type_vars: self.type_vars.clone(),
443 const_vars: self.const_vars.clone(),
444 impl_vars: self.impl_vars.clone(),
445 stable_ptrs: self.stable_ptrs.clone(),
446 pending: self.pending.clone(),
447 refuted: self.refuted.clone(),
448 solved: self.solved.clone(),
449 ambiguous: self.ambiguous.clone(),
450 impl_type_bounds: self.impl_type_bounds.clone(),
451 error_status: self.error_status,
452 error: self.error.clone(),
453 consumed_error: self.consumed_error,
454 }
455 }
456}
457
458pub struct Inference<'db> {
460 db: &'db dyn SemanticGroup,
461 pub data: &'db mut InferenceData,
462}
463
464impl Deref for Inference<'_> {
465 type Target = InferenceData;
466
467 fn deref(&self) -> &Self::Target {
468 self.data
469 }
470}
471impl DerefMut for Inference<'_> {
472 fn deref_mut(&mut self) -> &mut Self::Target {
473 self.data
474 }
475}
476
477impl std::fmt::Debug for Inference<'_> {
478 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
479 let x = self.data.debug(self.db.elongate());
480 write!(f, "{x:?}")
481 }
482}
483
484impl<'db> Inference<'db> {
485 fn new(db: &'db dyn SemanticGroup, data: &'db mut InferenceData) -> Self {
486 Self { db, data }
487 }
488
489 fn impl_var(&self, var_id: LocalImplVarId) -> &ImplVar {
491 &self.impl_vars[var_id.0]
492 }
493
494 pub fn impl_assignment(&self, var_id: LocalImplVarId) -> Option<ImplId> {
496 self.impl_assignment.get(&var_id).copied()
497 }
498
499 fn type_assignment(&self, var_id: LocalTypeVarId) -> Option<TypeId> {
501 self.type_assignment.get(&var_id).copied()
502 }
503
504 pub fn new_type_var(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeId {
507 let var = self.new_type_var_raw(stable_ptr);
508
509 TypeLongId::Var(var).intern(self.db)
510 }
511
512 pub fn new_type_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeVar {
515 let var =
516 TypeVar { inference_id: self.inference_id, id: LocalTypeVarId(self.type_vars.len()) };
517 if let Some(stable_ptr) = stable_ptr {
518 self.stable_ptrs.insert(InferenceVar::Type(var.id), stable_ptr);
519 }
520 self.type_vars.push(var);
521 var
522 }
523
524 pub fn set_impl_type_bounds(&mut self, impl_type_bounds: OrderedHashMap<ImplTypeId, TypeId>) {
527 let impl_type_bounds_finalized = impl_type_bounds
528 .iter()
529 .filter_map(|(impl_type, ty)| {
530 let rewritten_type = self.rewrite(ty.lookup_intern(self.db)).no_err();
531 if !matches!(rewritten_type, TypeLongId::Var(_)) {
532 return Some(((*impl_type).into(), rewritten_type.intern(self.db)));
533 }
534 self.conform_ty(*ty, TypeLongId::ImplType(*impl_type).intern(self.db)).ok();
537 None
538 })
539 .collect();
540
541 self.data.impl_type_bounds = Arc::new(impl_type_bounds_finalized);
542 }
543
544 pub fn new_const_var(
547 &mut self,
548 stable_ptr: Option<SyntaxStablePtrId>,
549 ty: TypeId,
550 ) -> ConstValueId {
551 let var = self.new_const_var_raw(stable_ptr);
552 ConstValue::Var(var, ty).intern(self.db)
553 }
554
555 pub fn new_const_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> ConstVar {
558 let var = ConstVar {
559 inference_id: self.inference_id,
560 id: LocalConstVarId(self.const_vars.len()),
561 };
562 if let Some(stable_ptr) = stable_ptr {
563 self.stable_ptrs.insert(InferenceVar::Const(var.id), stable_ptr);
564 }
565 self.const_vars.push(var);
566 var
567 }
568
569 pub fn new_impl_var(
572 &mut self,
573 concrete_trait_id: ConcreteTraitId,
574 stable_ptr: Option<SyntaxStablePtrId>,
575 lookup_context: ImplLookupContext,
576 ) -> ImplId {
577 let var = self.new_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
578 ImplLongId::ImplVar(self.impl_var(var).intern(self.db)).intern(self.db)
579 }
580
581 fn new_impl_var_raw(
584 &mut self,
585 lookup_context: ImplLookupContext,
586 concrete_trait_id: ConcreteTraitId,
587 stable_ptr: Option<SyntaxStablePtrId>,
588 ) -> LocalImplVarId {
589 let mut lookup_context = lookup_context;
590 lookup_context.insert_module(concrete_trait_id.trait_id(self.db).module_file_id(self.db).0);
591
592 let id = LocalImplVarId(self.impl_vars.len());
593 if let Some(stable_ptr) = stable_ptr {
594 self.stable_ptrs.insert(InferenceVar::Impl(id), stable_ptr);
595 }
596 let var =
597 ImplVar { inference_id: self.inference_id, id, concrete_trait_id, lookup_context };
598 self.impl_vars.push(var);
599 self.pending.push_back(id);
600 id
601 }
602
603 pub fn solve(&mut self) -> InferenceResult<()> {
608 self.solve_ex().map_err(|(err_set, _)| err_set)
609 }
610
611 fn solve_ex(&mut self) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
613 let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
614 self.pending.extend(ambiguous.map(|(var, _)| var));
615 while let Some(var) = self.pending.pop_front() {
616 self.solve_single_pending(var).map_err(|err_set| {
618 (err_set, self.stable_ptrs.get(&InferenceVar::Impl(var)).copied())
619 })?;
620 }
621 Ok(())
622 }
623
624 fn solve_single_pending(&mut self, var: LocalImplVarId) -> InferenceResult<()> {
625 if self.impl_assignment.contains_key(&var) {
626 return Ok(());
627 }
628 let solution = match self.impl_var_solution_set(var)? {
629 SolutionSet::None => {
630 self.refuted.push(var);
631 return Ok(());
632 }
633 SolutionSet::Ambiguous(ambiguity) => {
634 self.ambiguous.push((var, ambiguity));
635 return Ok(());
636 }
637 SolutionSet::Unique(solution) => solution,
638 };
639
640 self.assign_local_impl(var, solution)?;
642
643 self.solved.push(var);
645 let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
646 self.pending.extend(ambiguous.map(|(var, _)| var));
647
648 Ok(())
649 }
650
651 pub fn solution_set(&mut self) -> InferenceResult<SolutionSet<()>> {
654 self.solve()?;
655 if !self.refuted.is_empty() {
656 return Ok(SolutionSet::None);
657 }
658 if let Some((_, ambiguity)) = self.ambiguous.first() {
659 return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
660 }
661 assert!(self.pending.is_empty(), "solution() called on an unsolved solver");
662 Ok(SolutionSet::Unique(()))
663 }
664
665 pub fn finalize_without_reporting(
668 &mut self,
669 ) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
670 if self.error_status.is_err() {
671 return Err((ErrorSet, None));
673 }
674 let info = self.db.core_info();
675 let numeric_trait_id = info.numeric_literal_trt;
676 let felt_ty = info.felt252;
677
678 loop {
680 let mut changed = false;
681 self.solve_ex()?;
682 for (var, _) in self.ambiguous.clone() {
683 let impl_var = self.impl_var(var).clone();
684 if impl_var.concrete_trait_id.trait_id(self.db) != numeric_trait_id {
685 continue;
686 }
687 let ty = extract_matches!(
689 impl_var.concrete_trait_id.generic_args(self.db)[0],
690 GenericArgumentId::Type
691 );
692 if self.rewrite(ty).no_err() == felt_ty {
693 continue;
694 }
695 self.conform_ty(ty, felt_ty).map_err(|err_set| {
696 (err_set, self.stable_ptrs.get(&InferenceVar::Impl(impl_var.id)).copied())
697 })?;
698 changed = true;
699 break;
700 }
701 if !changed {
702 break;
703 }
704 }
705 assert!(
706 self.pending.is_empty(),
707 "pending should all be solved by this point. Guaranteed by solve()."
708 );
709
710 let Some((var, err)) = self.first_undetermined_variable() else {
711 return Ok(());
712 };
713 Err((self.set_error(err), self.stable_ptrs.get(&var).copied()))
714 }
715
716 pub fn finalize(
720 &mut self,
721 diagnostics: &mut SemanticDiagnostics,
722 stable_ptr: SyntaxStablePtrId,
723 ) {
724 if let Err((err_set, err_stable_ptr)) = self.finalize_without_reporting() {
725 let diag = self.report_on_pending_error(
726 err_set,
727 diagnostics,
728 err_stable_ptr.unwrap_or(stable_ptr),
729 );
730
731 let ty_missing = TypeId::missing(self.db, diag);
732 for var in &self.data.type_vars {
733 self.data.type_assignment.entry(var.id).or_insert(ty_missing);
734 }
735 }
736 }
737
738 fn first_undetermined_variable(&mut self) -> Option<(InferenceVar, InferenceError)> {
742 if let Some(var) = self.refuted.first().copied() {
743 let impl_var = self.impl_var(var).clone();
744 let concrete_trait_id = impl_var.concrete_trait_id;
745 let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
746 return Some((
747 InferenceVar::Impl(var),
748 InferenceError::NoImplsFound(concrete_trait_id),
749 ));
750 }
751 let mut fallback_ret = None;
752 if let Some((var, ambiguity)) = self.ambiguous.first() {
753 let ret =
755 Some((InferenceVar::Impl(*var), InferenceError::Ambiguity(ambiguity.clone())));
756 if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
757 return ret;
758 } else {
759 fallback_ret = ret;
760 }
761 }
762 for (id, var) in self.type_vars.iter().enumerate() {
763 if self.type_assignment(LocalTypeVarId(id)).is_none() {
764 let ty = TypeLongId::Var(*var).intern(self.db);
765 return Some((InferenceVar::Type(var.id), InferenceError::TypeNotInferred(ty)));
766 }
767 }
768 for (id, var) in self.const_vars.iter().enumerate() {
769 if !self.const_assignment.contains_key(&LocalConstVarId(id)) {
770 let infernence_var = InferenceVar::Const(var.id);
771 return Some((infernence_var, InferenceError::ConstNotInferred));
772 }
773 }
774 fallback_ret
775 }
776
777 fn assign_local_impl(
779 &mut self,
780 var: LocalImplVarId,
781 impl_id: ImplId,
782 ) -> InferenceResult<ImplId> {
783 let concrete_trait = impl_id
784 .concrete_trait(self.db)
785 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
786 self.conform_traits(self.impl_var(var).concrete_trait_id, concrete_trait)?;
787 if let Some(other_impl) = self.impl_assignment(var) {
788 return self.conform_impl(impl_id, other_impl);
789 }
790 if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
791 {
792 return Err(self.set_error(InferenceError::Cycle(InferenceVar::Impl(var))));
793 }
794 self.impl_assignment.insert(var, impl_id);
795 if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
796 for (trait_type_id, ty) in mappings.types {
797 let impl_ty = self
798 .db
799 .impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_type_id, self.db))
800 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
801 if let Err(err_set) = self.conform_ty(ty, impl_ty) {
802 let ty0 = self.rewrite(ty).no_err();
804 let ty1 = self.rewrite(impl_ty).no_err();
805
806 self.error =
807 Some(InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 });
808 return Err(err_set);
809 }
810 }
811 for (trait_constant, constant_id) in mappings.constants {
812 let concrete_impl_constant = self
813 .db
814 .impl_constant_concrete_implized_value(ImplConstantId::new(
815 impl_id,
816 trait_constant,
817 self.db,
818 ))
819 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
820 self.conform_const(constant_id, concrete_impl_constant)?;
821 }
822 for (trait_impl, inner_impl_id) in mappings.impls {
823 let concrete_impl_impl = self
824 .db
825 .impl_impl_concrete_implized(ImplImplId::new(impl_id, trait_impl, self.db))
826 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
827 self.conform_impl(inner_impl_id, concrete_impl_impl)?;
828 }
829 }
830 Ok(impl_id)
831 }
832
833 fn assign_impl(&mut self, var_id: ImplVarId, impl_id: ImplId) -> InferenceResult<ImplId> {
835 let var = var_id.lookup_intern(self.db);
836 if var.inference_id != self.inference_id {
837 return Err(self.set_error(InferenceError::ImplKindMismatch {
838 impl0: ImplLongId::ImplVar(var_id).intern(self.db),
839 impl1: impl_id,
840 }));
841 }
842 self.assign_local_impl(var.id, impl_id)
843 }
844
845 fn assign_ty(&mut self, var: TypeVar, ty: TypeId) -> InferenceResult<TypeId> {
848 if var.inference_id != self.inference_id {
849 return Err(self.set_error(InferenceError::TypeKindMismatch {
850 ty0: TypeLongId::Var(var).intern(self.db),
851 ty1: ty,
852 }));
853 }
854 assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
855 let inference_var = InferenceVar::Type(var.id);
856 if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
857 return Err(self.set_error(InferenceError::Cycle(inference_var)));
858 }
859 if let TypeLongId::Var(other) = ty.lookup_intern(self.db) {
861 if other.inference_id == self.inference_id && other.id.0 > var.id.0 {
862 let var_ty = TypeLongId::Var(var).intern(self.db);
863 self.type_assignment.insert(other.id, var_ty);
864 return Ok(var_ty);
865 }
866 }
867 self.type_assignment.insert(var.id, ty);
868 Ok(ty)
869 }
870
871 fn assign_const(&mut self, var: ConstVar, id: ConstValueId) -> InferenceResult<ConstValueId> {
874 if var.inference_id != self.inference_id {
875 return Err(self.set_error(InferenceError::ConstKindMismatch {
876 const0: ConstValue::Var(var, TypeId::missing(self.db, skip_diagnostic()))
877 .intern(self.db),
878 const1: id,
879 }));
880 }
881
882 self.const_assignment.insert(var.id, id);
883 Ok(id)
884 }
885
886 fn impl_var_solution_set(
888 &mut self,
889 var: LocalImplVarId,
890 ) -> InferenceResult<SolutionSet<ImplId>> {
891 let impl_var = self.impl_var(var).clone();
892 let concrete_trait_id = self.rewrite(impl_var.concrete_trait_id).no_err();
894 self.impl_vars[impl_var.id.0].concrete_trait_id = concrete_trait_id;
895 let impl_var_trait_item_mappings =
896 self.impl_vars_trait_item_mappings.get(&var).cloned().unwrap_or_default();
897 let solution_set = self.trait_solution_set(
898 concrete_trait_id,
899 impl_var_trait_item_mappings,
900 impl_var.lookup_context,
901 )?;
902 Ok(match solution_set {
903 SolutionSet::None => SolutionSet::None,
904 SolutionSet::Unique((canonical_impl, canonicalizer)) => {
905 SolutionSet::Unique(canonical_impl.embed(self, &canonicalizer))
906 }
907 SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
908 })
909 }
910
911 pub fn trait_solution_set(
913 &mut self,
914 concrete_trait_id: ConcreteTraitId,
915 impl_var_trait_item_mappings: ImplVarTraitItemMappings,
916 mut lookup_context: ImplLookupContext,
917 ) -> InferenceResult<SolutionSet<(CanonicalImpl, CanonicalMapping)>> {
918 let impl_var_trait_item_mappings = self.rewrite(impl_var_trait_item_mappings).no_err();
919 let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
921 enrich_lookup_context(self.db, concrete_trait_id, &mut lookup_context);
922
923 let generic_args = concrete_trait_id.generic_args(self.db);
925 match generic_args.first() {
926 Some(GenericArgumentId::Type(ty)) => {
927 if let TypeLongId::Var(_) = ty.lookup_intern(self.db) {
928 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
930 }
931 }
932 Some(GenericArgumentId::Impl(imp)) => {
933 if let ImplLongId::ImplVar(_) = imp.lookup_intern(self.db) {
935 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
936 }
937 }
938 Some(GenericArgumentId::Constant(const_value)) => {
939 if let ConstValue::Var(_, _) = const_value.lookup_intern(self.db) {
940 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
942 }
943 }
944 _ => {}
945 };
946 let (canonical_trait, canonicalizer) = CanonicalTrait::canonicalize(
947 self.db,
948 self.inference_id,
949 concrete_trait_id,
950 impl_var_trait_item_mappings,
951 );
952 let solution_set = match self.db.canonic_trait_solutions(
955 canonical_trait,
956 lookup_context,
957 (*self.data.impl_type_bounds).clone(),
958 ) {
959 Ok(solution_set) => solution_set,
960 Err(err) => return Err(self.set_error(err)),
961 };
962 match solution_set {
963 SolutionSet::None => Ok(SolutionSet::None),
964 SolutionSet::Unique(canonical_impl) => {
965 Ok(SolutionSet::Unique((canonical_impl, canonicalizer)))
966 }
967 SolutionSet::Ambiguous(ambiguity) => Ok(SolutionSet::Ambiguous(ambiguity)),
968 }
969 }
970
971 fn validate_neg_impls(
975 &mut self,
976 lookup_context: &ImplLookupContext,
977 canonical_impl: CanonicalImpl,
978 ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
979 fn validate_no_solution_set(
981 inference: &mut Inference<'_>,
982 canonical_impl: CanonicalImpl,
983 lookup_context: &ImplLookupContext,
984 negative_impls_concrete_traits: impl Iterator<Item = Maybe<ConcreteTraitId>>,
985 ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
986 for concrete_trait_id in negative_impls_concrete_traits {
987 let concrete_trait_id = concrete_trait_id.map_err(|diag_added| {
988 inference.set_error(InferenceError::Reported(diag_added))
989 })?;
990 for garg in concrete_trait_id.generic_args(inference.db) {
991 let GenericArgumentId::Type(ty) = garg else {
992 continue;
993 };
994 let ty = inference.rewrite(ty).no_err();
995 if !matches!(ty.lookup_intern(inference.db), TypeLongId::Closure(_))
1004 && !ty.is_fully_concrete(inference.db)
1005 {
1006 return Ok(SolutionSet::Ambiguous(
1009 Ambiguity::NegativeImplWithUnresolvedGenericArgs {
1010 impl_id: canonical_impl.0,
1011 ty,
1012 },
1013 ));
1014 }
1015 }
1016
1017 if !matches!(
1018 inference.trait_solution_set(
1019 concrete_trait_id,
1020 ImplVarTraitItemMappings::default(),
1021 lookup_context.clone()
1022 )?,
1023 SolutionSet::None
1024 ) {
1025 return Ok(SolutionSet::None);
1027 }
1028 }
1029
1030 Ok(SolutionSet::Unique(canonical_impl))
1031 }
1032 match canonical_impl.0.lookup_intern(self.db) {
1033 ImplLongId::Concrete(concrete_impl) => {
1034 let substitution = concrete_impl
1035 .substitution(self.db)
1036 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1037 let generic_params = self
1038 .db
1039 .impl_def_generic_params(concrete_impl.impl_def_id(self.db))
1040 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1041 let concrete_traits = generic_params
1042 .iter()
1043 .filter_map(|generic_param| {
1044 try_extract_matches!(generic_param, GenericParam::NegImpl)
1045 })
1046 .map(|generic_param| {
1047 substitution
1048 .substitute(self.db, generic_param.clone())
1049 .and_then(|generic_param| generic_param.concrete_trait)
1050 });
1051 validate_no_solution_set(self, canonical_impl, lookup_context, concrete_traits)
1052 }
1053 ImplLongId::GeneratedImpl(generated_impl) => validate_no_solution_set(
1054 self,
1055 canonical_impl,
1056 lookup_context,
1057 generated_impl
1058 .lookup_intern(self.db)
1059 .generic_params
1060 .iter()
1061 .filter_map(|generic_param| {
1062 try_extract_matches!(generic_param, GenericParam::NegImpl)
1063 })
1064 .map(|generic_param| generic_param.concrete_trait),
1065 ),
1066 ImplLongId::GenericParameter(_)
1067 | ImplLongId::ImplVar(_)
1068 | ImplLongId::ImplImpl(_)
1069 | ImplLongId::SelfImpl(_) => Ok(SolutionSet::Unique(canonical_impl)),
1070 }
1071 }
1072
1073 pub fn set_error(&mut self, err: InferenceError) -> ErrorSet {
1080 if self.error_status.is_err() {
1081 return ErrorSet;
1082 }
1083 self.error_status = if let InferenceError::Reported(diag_added) = err {
1084 self.consumed_error = Some(diag_added);
1085 Err(InferenceErrorStatus::Consumed)
1086 } else {
1087 self.error = Some(err);
1088 Err(InferenceErrorStatus::Pending)
1089 };
1090 ErrorSet
1091 }
1092
1093 pub fn is_error_set(&self) -> InferenceResult<()> {
1095 if self.error_status.is_err() { Err(ErrorSet) } else { Ok(()) }
1096 }
1097
1098 pub fn consume_error_without_reporting(&mut self, err_set: ErrorSet) -> Option<InferenceError> {
1104 self.consume_error_inner(err_set, skip_diagnostic())
1105 }
1106
1107 pub fn consume_reported_error(&mut self, err_set: ErrorSet, diag_added: DiagnosticAdded) {
1114 self.consume_error_inner(err_set, diag_added);
1115 }
1116
1117 fn consume_error_inner(
1124 &mut self,
1125 _err_set: ErrorSet,
1126 diag_added: DiagnosticAdded,
1127 ) -> Option<InferenceError> {
1128 if self.error_status != Err(InferenceErrorStatus::Pending) {
1129 return None;
1130 }
1132 self.error_status = Err(InferenceErrorStatus::Consumed);
1133 self.consumed_error = Some(diag_added);
1134 self.error.take()
1135 }
1136
1137 pub fn report_on_pending_error(
1143 &mut self,
1144 _err_set: ErrorSet,
1145 diagnostics: &mut SemanticDiagnostics,
1146 stable_ptr: SyntaxStablePtrId,
1147 ) -> DiagnosticAdded {
1148 let Err(state_error) = self.error_status else {
1149 panic!("report_on_pending_error should be called only on error");
1150 };
1151 match state_error {
1152 InferenceErrorStatus::Consumed => self
1153 .consumed_error
1154 .expect("consumed_error is not set although error_status is Err(Consumed)"),
1155 InferenceErrorStatus::Pending => {
1156 let diag_added = match mem::take(&mut self.error)
1157 .expect("error is not set although error_status is Err(Pending)")
1158 {
1159 InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
1160 skip_diagnostic()
1165 }
1166 diag => diag.report(diagnostics, stable_ptr),
1167 };
1168
1169 self.error_status = Err(InferenceErrorStatus::Consumed);
1170 self.consumed_error = Some(diag_added);
1171 diag_added
1172 }
1173 }
1174 }
1175
1176 pub fn report_modified_if_pending(
1179 &mut self,
1180 err_set: ErrorSet,
1181 report: impl FnOnce() -> DiagnosticAdded,
1182 ) {
1183 if self.error_status == Err(InferenceErrorStatus::Pending) {
1184 self.consume_reported_error(err_set, report());
1185 }
1186 }
1187}
1188
1189impl<'a> HasDb<&'a dyn SemanticGroup> for Inference<'a> {
1190 fn get_db(&self) -> &'a dyn SemanticGroup {
1191 self.db
1192 }
1193}
// Rewriter implementations for `Inference` produced by the rewrite macros.
// The types listed after `@exclude` (`TypeLongId`, `TypeId`, `ImplLongId`, `ImplId`,
// `ConstValue`) have hand-written `SemanticRewriter` impls right below.
// NOTE(review): the exact set of generated impls is defined by the macros, which are declared
// elsewhere — confirm there.
add_basic_rewrites!(<'a>, Inference<'a>, NoError, @exclude TypeLongId TypeId ImplLongId ImplId ConstValue);
add_expr_rewrites!(<'a>, Inference<'a>, NoError, @exclude);
add_rewrite!(<'a>, Inference<'a>, NoError, Ambiguity);
1197impl SemanticRewriter<TypeId, NoError> for Inference<'_> {
1198 fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
1199 if value.is_var_free(self.db) {
1200 return Ok(RewriteResult::NoChange);
1201 }
1202 value.default_rewrite(self)
1203 }
1204}
1205impl SemanticRewriter<ImplId, NoError> for Inference<'_> {
1206 fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
1207 if value.is_var_free(self.db) {
1208 return Ok(RewriteResult::NoChange);
1209 }
1210 value.default_rewrite(self)
1211 }
1212}
/// Rewrites a `TypeLongId` using the state accumulated by the inference.
///
/// Type variables with an assignment and impl types that can be resolved are replaced in
/// place; every other type goes through `default_rewrite`.
impl SemanticRewriter<TypeLongId, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
        match value {
            TypeLongId::Var(var) => {
                // An assigned type variable is replaced by its (recursively rewritten)
                // assignment. An unassigned one falls through to `default_rewrite` below.
                if let Some(type_id) = self.type_assignment.get(&var.id) {
                    let mut long_type_id = type_id.lookup_intern(self.db);
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_type_id)? {
                        // Store the rewritten type back into the assignment map.
                        *self.type_assignment.get_mut(&var.id).unwrap() =
                            long_type_id.clone().intern(self.db);
                    }
                    *value = long_type_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            TypeLongId::ImplType(impl_type_id) => {
                // A type recorded in `impl_type_bounds` for this impl type takes precedence:
                // substitute it and rewrite the substituted type.
                if let Some(type_id) = self.impl_type_bounds.get(&((*impl_type_id).into())) {
                    *value = type_id.lookup_intern(self.db);
                    self.internal_rewrite(value)?;
                    return Ok(RewriteResult::Modified);
                }
                // Otherwise rewrite the impl type id itself, then try to resolve the type
                // according to the kind of impl it refers to.
                let impl_type_id_rewrite_result = self.internal_rewrite(impl_type_id)?;
                let impl_id = impl_type_id.impl_id();
                let trait_ty = impl_type_id.ty();
                return Ok(match impl_id.lookup_intern(self.db) {
                    // Nothing further is resolved for these impl kinds here; report only
                    // whether rewriting the id changed anything.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_type_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        // Replace with the type returned by the db query, rewritten; on a
                        // query error keep the result of rewriting the id.
                        if let Ok(ty) = self.db.impl_type_concrete_implized(ImplTypeId::new(
                            impl_id, trait_ty, self.db,
                        )) {
                            *value = self.rewrite(ty).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_type_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        // Impl variable: delegate to `rewritten_impl_type` (defined elsewhere).
                        *value = self.rewritten_impl_type(var, trait_ty).lookup_intern(self.db);
                        return Ok(RewriteResult::Modified);
                    }
                    ImplLongId::GeneratedImpl(generated) => {
                        // Look the type up in the generated impl's items and rewrite it. The
                        // `unwrap` panics if the items hold no entry for this trait type.
                        *value = self
                            .rewrite(
                                *generated
                                    .lookup_intern(self.db)
                                    .impl_items
                                    .0
                                    .get(&impl_type_id.ty())
                                    .unwrap(),
                            )
                            .no_err()
                            .lookup_intern(self.db);
                        RewriteResult::Modified
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
/// Rewrites a `ConstValue` using the state accumulated by the inference.
///
/// Const variables with an assignment and impl constants that can be resolved are replaced in
/// place; every other value goes through `default_rewrite`.
impl SemanticRewriter<ConstValue, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
        match value {
            ConstValue::Var(var, _) => {
                // An assigned const variable is replaced by its (recursively rewritten)
                // assignment. An unassigned one is reported as `NoChange` right away, without
                // going through `default_rewrite`.
                return Ok(if let Some(const_value_id) = self.const_assignment.get(&var.id) {
                    let mut const_value = const_value_id.lookup_intern(self.db);
                    if let RewriteResult::Modified = self.internal_rewrite(&mut const_value)? {
                        // Store the rewritten value back into the assignment map.
                        *self.const_assignment.get_mut(&var.id).unwrap() =
                            const_value.clone().intern(self.db);
                    }
                    *value = const_value;
                    RewriteResult::Modified
                } else {
                    RewriteResult::NoChange
                });
            }
            ConstValue::ImplConstant(impl_constant_id) => {
                // Rewrite the impl constant id itself, then try to resolve the constant
                // according to the kind of impl it refers to.
                let impl_constant_id_rewrite_result = self.internal_rewrite(impl_constant_id)?;
                let impl_id = impl_constant_id.impl_id();
                let trait_constant = impl_constant_id.trait_constant_id();
                return Ok(match impl_id.lookup_intern(self.db) {
                    // Nothing further is resolved for these impl kinds here; report only
                    // whether rewriting the id changed anything.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::GeneratedImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_constant_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        // Replace with the value returned by the db query, rewritten; on a
                        // query error keep the result of rewriting the id.
                        if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
                            ImplConstantId::new(impl_id, trait_constant, self.db),
                        ) {
                            *value = self.rewrite(constant).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_constant_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        // Impl variable: delegate to `rewritten_impl_constant` (defined
                        // elsewhere).
                        *value = self
                            .rewritten_impl_constant(var, trait_constant)
                            .lookup_intern(self.db);
                        return Ok(RewriteResult::Modified);
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
/// Rewrites an `ImplLongId` using the state accumulated by the inference.
///
/// Impl variables with an assignment and impl-impls that can be resolved are replaced in
/// place; every other impl goes through `default_rewrite`, unless it is var-free.
impl SemanticRewriter<ImplLongId, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
        match value {
            ImplLongId::ImplVar(var) => {
                let long_id = var.lookup_intern(self.db);
                let impl_var_id = long_id.id;
                // An assigned impl variable is replaced by its (recursively rewritten)
                // assignment. An unassigned one falls through to the code after the `match`.
                if let Some(impl_id) = self.impl_assignment(impl_var_id) {
                    let mut long_impl_id = impl_id.lookup_intern(self.db);
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_impl_id)? {
                        // Store the rewritten impl back into the assignment map.
                        *self.impl_assignment.get_mut(&impl_var_id).unwrap() =
                            long_impl_id.clone().intern(self.db);
                    }
                    *value = long_impl_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            ImplLongId::ImplImpl(impl_impl_id) => {
                // Rewrite the impl-impl id itself, then try to resolve it according to the
                // kind of impl it refers to.
                let impl_impl_id_rewrite_result = self.internal_rewrite(impl_impl_id)?;
                let impl_id = impl_impl_id.impl_id();
                return Ok(match impl_id.lookup_intern(self.db) {
                    // Nothing further is resolved for these impl kinds here; report only
                    // whether rewriting the id changed anything.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::GeneratedImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_impl_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        // Replace with the impl returned by the db query, rewritten; on a
                        // query error keep the result of rewriting the id.
                        if let Ok(imp) = self.db.impl_impl_concrete_implized(*impl_impl_id) {
                            *value = self.rewrite(imp).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        // Impl variable: delegate to `rewritten_impl_impl` (defined elsewhere)
                        // when the concrete trait impl id is available; otherwise keep the
                        // result of rewriting the id.
                        if let Ok(concrete_trait_impl) =
                            impl_impl_id.concrete_trait_impl_id(self.db)
                        {
                            *value = self
                                .rewritten_impl_impl(var, concrete_trait_impl)
                                .lookup_intern(self.db);
                            return Ok(RewriteResult::Modified);
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                });
            }

            _ => {}
        }
        // Var-free impls are reported as unchanged without running the default rewrite.
        if value.is_var_free(self.db) {
            return Ok(RewriteResult::NoChange);
        }
        value.default_rewrite(self)
    }
}
1379
/// A rewriter that replaces every occurrence of one `InferenceId` with another.
struct InferenceIdReplacer<'a> {
    /// The semantic database, exposed through `HasDb`.
    db: &'a dyn SemanticGroup,
    /// The inference id to look for.
    from_inference_id: InferenceId,
    /// The inference id written in place of `from_inference_id`.
    to_inference_id: InferenceId,
}
1385impl<'a> InferenceIdReplacer<'a> {
1386 fn new(
1387 db: &'a dyn SemanticGroup,
1388 from_inference_id: InferenceId,
1389 to_inference_id: InferenceId,
1390 ) -> Self {
1391 Self { db, from_inference_id, to_inference_id }
1392 }
1393}
1394impl<'a> HasDb<&'a dyn SemanticGroup> for InferenceIdReplacer<'a> {
1395 fn get_db(&self) -> &'a dyn SemanticGroup {
1396 self.db
1397 }
1398}
// Rewriter implementations for `InferenceIdReplacer` produced by the rewrite macros.
// `InferenceId` is listed after `@exclude` and has a hand-written `SemanticRewriter` impl
// right below.
// NOTE(review): the exact set of generated impls is defined by the macros, which are declared
// elsewhere — confirm there.
add_basic_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude InferenceId);
add_expr_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude);
add_rewrite!(<'a>, InferenceIdReplacer<'a>, NoError, Ambiguity);
1402impl SemanticRewriter<InferenceId, NoError> for InferenceIdReplacer<'_> {
1403 fn internal_rewrite(&mut self, value: &mut InferenceId) -> Result<RewriteResult, NoError> {
1404 if value == &self.from_inference_id {
1405 *value = self.to_inference_id;
1406 Ok(RewriteResult::Modified)
1407 } else {
1408 Ok(RewriteResult::NoChange)
1409 }
1410 }
1411}