1use std::collections::{HashMap, VecDeque};
4use std::hash::Hash;
5use std::mem;
6use std::ops::{Deref, DerefMut};
7
8use cairo_lang_debug::DebugWithDb;
9use cairo_lang_defs::ids::{
10 ConstantId, EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId,
11 GlobalUseId, ImplAliasId, ImplDefId, ImplFunctionId, ImplImplDefId, LanguageElementId,
12 LocalVarId, LookupItemId, MemberId, ParamId, StructId, TraitConstantId, TraitFunctionId,
13 TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
14};
15use cairo_lang_diagnostics::{DiagnosticAdded, Maybe, skip_diagnostic};
16use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
17use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
18use cairo_lang_utils::ordered_hash_map::{Entry, OrderedHashMap};
19use cairo_lang_utils::{
20 Intern, LookupIntern, define_short_id, extract_matches, try_extract_matches,
21};
22
23use self::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, NoError};
24use self::solver::{Ambiguity, SolutionSet, enrich_lookup_context};
25use crate::corelib::{CoreTraitContext, core_felt252_ty, get_core_trait, numeric_literal_trait};
26use crate::db::SemanticGroup;
27use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
28use crate::expr::inference::canonic::ResultNoErrEx;
29use crate::expr::inference::conform::InferenceConform;
30use crate::expr::objects::*;
31use crate::expr::pattern::*;
32use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
33use crate::items::functions::{
34 ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
35 GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
36 ImplGenericFunctionWithBodyId,
37};
38use crate::items::generics::{GenericParamConst, GenericParamImpl, GenericParamType};
39use crate::items::imp::{
40 GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
41 ImplLookupContext, UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
42};
43use crate::items::trt::{ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId};
44use crate::substitution::{HasDb, RewriteResult, SemanticRewriter, SubstitutionRewriter};
45use crate::types::{
46 ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
47 ImplTypeId,
48};
49use crate::{
50 ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
51 ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
52 FunctionId, FunctionLongId, GenericArgumentId, GenericParam, LocalVariable, MatchArmSelector,
53 Member, Parameter, SemanticObject, Signature, TypeId, TypeLongId, ValueSelectorArm,
54 add_basic_rewrites, add_expr_rewrites, add_rewrite, semantic_object_for_id,
55};
56
57pub mod canonic;
58pub mod conform;
59pub mod infers;
60pub mod solver;
61
/// A type inference variable.
///
/// `id` is the variable's index in the owning inference's `type_vars` list, and
/// `inference_id` identifies that owning inference.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct TypeVar {
    /// The inference that created this variable.
    pub inference_id: InferenceId,
    /// Index of this variable among the owning inference's type variables.
    pub id: LocalTypeVarId,
}
69
/// A const-value inference variable.
///
/// `id` is the variable's index in the owning inference's `const_vars` list, and
/// `inference_id` identifies that owning inference.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ConstVar {
    /// The inference that created this variable.
    pub inference_id: InferenceId,
    /// Index of this variable among the owning inference's const variables.
    pub id: LocalConstVarId,
}
77
/// Identifies an inference, i.e. the owner of a set of inference variables.
///
/// Every variable records the id of the inference that created it. `assign_ty`, `assign_const`
/// and `assign_impl` reject variables whose id differs from the current inference's id.
/// NOTE(review): judging by their names, most variants tie the inference to the item being
/// analyzed; `Canonical` and `NoContext` carry no item. Confirm against the call sites.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
#[debug_db(dyn SemanticGroup + 'static)]
pub enum InferenceId {
    LookupItemDeclaration(LookupItemId),
    LookupItemGenerics(LookupItemId),
    LookupItemDefinition(LookupItemId),
    ImplDefTrait(ImplDefId),
    ImplAliasImplDef(ImplAliasId),
    GenericParam(GenericParamId),
    GenericImplParamTrait(GenericParamId),
    GlobalUseStar(GlobalUseId),
    Canonical,
    NoContext,
}
94
/// An impl inference variable: a placeholder for an impl of `concrete_trait_id` that the
/// solver still has to find.
#[derive(Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject)]
#[debug_db(dyn SemanticGroup + 'static)]
pub struct ImplVar {
    /// The inference that created this variable.
    pub inference_id: InferenceId,
    /// Index of this variable among the owning inference's impl variables (not rewritten).
    #[dont_rewrite]
    pub id: LocalImplVarId,
    /// The trait an impl must be found for.
    pub concrete_trait_id: ConcreteTraitId,
    /// Where to look for candidate impls (not rewritten). `new_impl_var_raw` always adds the
    /// trait's own module to it.
    #[dont_rewrite]
    pub lookup_context: ImplLookupContext,
}
107impl ImplVar {
108 pub fn intern(&self, db: &dyn SemanticGroup) -> ImplVarId {
109 self.clone().intern(db)
110 }
111}
112
/// Index of a type variable within its inference's `type_vars` list.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
pub struct LocalTypeVarId(pub usize);
/// Index of an impl variable within its inference's `impl_vars` list.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
pub struct LocalImplVarId(pub usize);

/// Index of a const variable within its inference's `const_vars` list.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject)]
pub struct LocalConstVarId(pub usize);
120
121define_short_id!(ImplVarId, ImplVar, SemanticGroup, lookup_intern_impl_var, intern_impl_var);
122impl ImplVarId {
123 pub fn id(&self, db: &dyn SemanticGroup) -> LocalImplVarId {
124 self.lookup_intern(db).id
125 }
126 pub fn concrete_trait_id(&self, db: &dyn SemanticGroup) -> ConcreteTraitId {
127 self.lookup_intern(db).concrete_trait_id
128 }
129 pub fn lookup_context(&self, db: &dyn SemanticGroup) -> ImplLookupContext {
130 self.lookup_intern(db).lookup_context
131 }
132}
133semantic_object_for_id!(ImplVarId, lookup_intern_impl_var, intern_impl_var, ImplVar);
134
/// A local inference variable of any kind. Used, for example, as the key of `stable_ptrs` and
/// in `InferenceError::Cycle`.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, SemanticObject)]
pub enum InferenceVar {
    Type(LocalTypeVarId),
    Const(LocalConstVarId),
    Impl(LocalImplVarId),
}
141
/// An error raised while inferring types, consts and impls.
#[derive(Clone, Debug, Eq, Hash, PartialEq, DebugWithDb)]
#[debug_db(dyn SemanticGroup + 'static)]
pub enum InferenceError {
    /// An error already emitted as a diagnostic; `report` returns it without re-reporting.
    Reported(DiagnosticAdded),
    /// Assigning to the variable would make its value contain the variable itself.
    Cycle(InferenceVar),
    /// Two types could not be reconciled (e.g. `assign_ty` on a variable of another inference).
    TypeKindMismatch {
        ty0: TypeId,
        ty1: TypeId,
    },
    /// Two const values could not be reconciled.
    ConstKindMismatch {
        const0: ConstValueId,
        const1: ConstValueId,
    },
    /// Two impls could not be reconciled.
    ImplKindMismatch {
        impl0: ImplId,
        impl1: ImplId,
    },
    /// Two generic arguments could not be reconciled.
    GenericArgMismatch {
        garg0: GenericArgumentId,
        garg1: GenericArgumentId,
    },
    /// Two traits could not be reconciled.
    TraitMismatch {
        trt0: TraitId,
        trt1: TraitId,
    },
    /// Two generic functions could not be reconciled.
    GenericFunctionMismatch {
        func0: GenericFunctionId,
        func1: GenericFunctionId,
    },
    /// Const generic inference is not supported.
    ConstInferenceNotSupported,

    /// No impl of the concrete trait was found (a refuted impl variable).
    NoImplsFound(ConcreteTraitId),
    /// Impl resolution was left ambiguous.
    Ambiguity(Ambiguity),
    /// A type variable was still unassigned when the inference was finalized.
    TypeNotInferred(TypeId),
}
181impl InferenceError {
182 pub fn format(&self, db: &(dyn SemanticGroup + 'static)) -> String {
183 match self {
184 InferenceError::Reported(_) => "Inference error occurred.".into(),
185 InferenceError::Cycle(_var) => "Inference cycle detected".into(),
186 InferenceError::TypeKindMismatch { ty0, ty1 } => {
187 format!("Type mismatch: `{:?}` and `{:?}`.", ty0.debug(db), ty1.debug(db))
188 }
189 InferenceError::ConstKindMismatch { const0, const1 } => {
190 format!("Const mismatch: `{:?}` and `{:?}`.", const0.debug(db), const1.debug(db))
191 }
192 InferenceError::ImplKindMismatch { impl0, impl1 } => {
193 format!("Impl mismatch: `{:?}` and `{:?}`.", impl0.debug(db), impl1.debug(db))
194 }
195 InferenceError::GenericArgMismatch { garg0, garg1 } => {
196 format!(
197 "Generic arg mismatch: `{:?}` and `{:?}`.",
198 garg0.debug(db),
199 garg1.debug(db)
200 )
201 }
202 InferenceError::TraitMismatch { trt0, trt1 } => {
203 format!("Trait mismatch: `{:?}` and `{:?}`.", trt0.debug(db), trt1.debug(db))
204 }
205 InferenceError::ConstInferenceNotSupported => {
206 "Const generic inference not yet supported.".into()
207 }
208 InferenceError::NoImplsFound(concrete_trait_id) => {
209 let trait_id = concrete_trait_id.trait_id(db);
210 if trait_id == numeric_literal_trait(db) {
211 let generic_type = extract_matches!(
212 concrete_trait_id.generic_args(db)[0],
213 GenericArgumentId::Type
214 );
215 return format!(
216 "Mismatched types. The type `{:?}` cannot be created from a numeric \
217 literal.",
218 generic_type.debug(db)
219 );
220 } else if trait_id
221 == get_core_trait(db, CoreTraitContext::TopLevel, "StringLiteral".into())
222 {
223 let generic_type = extract_matches!(
224 concrete_trait_id.generic_args(db)[0],
225 GenericArgumentId::Type
226 );
227 return format!(
228 "Mismatched types. The type `{:?}` cannot be created from a string \
229 literal.",
230 generic_type.debug(db)
231 );
232 }
233 format!(
234 "Trait has no implementation in context: {:?}.",
235 concrete_trait_id.debug(db)
236 )
237 }
238 InferenceError::Ambiguity(ambiguity) => ambiguity.format(db),
239 InferenceError::TypeNotInferred(ty) => {
240 format!("Type annotations needed. Failed to infer {:?}.", ty.debug(db))
241 }
242 InferenceError::GenericFunctionMismatch { func0, func1 } => {
243 format!("Function mismatch: `{}` and `{}`.", func0.format(db), func1.format(db))
244 }
245 }
246 }
247}
248
249impl InferenceError {
250 pub fn report(
251 &self,
252 diagnostics: &mut SemanticDiagnostics,
253 stable_ptr: SyntaxStablePtrId,
254 ) -> DiagnosticAdded {
255 match self {
256 InferenceError::Reported(diagnostic_added) => *diagnostic_added,
257 _ => diagnostics
258 .report(stable_ptr, SemanticDiagnosticKind::InternalInferenceError(self.clone())),
259 }
260 }
261}
262
/// Marker returned by `Inference::set_error` once an error has been recorded.
///
/// It carries no data: the error itself is stored in `InferenceData`. It is expected to be
/// consumed or reported later (e.g. `consume_error_without_reporting`,
/// `consume_reported_error`, `report_on_pending_error`).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ErrorSet;

/// Result of an inference operation. By convention, `Err(ErrorSet)` means an error has been
/// recorded in the inference state.
pub type InferenceResult<T> = Result<T, ErrorSet>;
271
/// State of an inference error once one has been set (see `InferenceData::error_status`).
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum InferenceErrorStatus {
    /// The error is stored and has not been consumed or reported yet.
    Pending,
    /// The error has been consumed; its diagnostic is kept in `consumed_error`.
    Consumed,
}
277
/// Values recorded for the trait items of a not-yet-assigned impl variable.
///
/// When the variable is assigned, `assign_local_impl` conforms each recorded value with the
/// corresponding item of the assigned impl.
#[derive(Debug, Default, PartialEq, Eq, Clone, SemanticObject)]
pub struct ImplVarTraitItemMappings {
    /// Trait types mapped to the types they must equal.
    types: OrderedHashMap<TraitTypeId, TypeId>,
    /// Trait constants mapped to the values they must equal.
    constants: OrderedHashMap<TraitConstantId, ConstValueId>,
    /// Trait impls mapped to the impls they must equal.
    impls: OrderedHashMap<TraitImplId, ImplId>,
}
288impl Hash for ImplVarTraitItemMappings {
289 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
290 self.types.iter().for_each(|(trait_type_id, type_id)| {
291 trait_type_id.hash(state);
292 type_id.hash(state);
293 });
294 self.constants.iter().for_each(|(trait_const_id, const_id)| {
295 trait_const_id.hash(state);
296 const_id.hash(state);
297 });
298 self.impls.iter().for_each(|(trait_impl_id, impl_id)| {
299 trait_impl_id.hash(state);
300 impl_id.hash(state);
301 });
302 }
303}
304
/// The state of a single inference: its variables, their assignments, the impl-solving queues
/// and the error state.
#[derive(Debug, DebugWithDb, PartialEq, Eq)]
#[debug_db(dyn SemanticGroup + 'static)]
pub struct InferenceData {
    /// Identifies this inference; variables created here carry this id.
    pub inference_id: InferenceId,
    /// Current assignment of type variables.
    pub type_assignment: OrderedHashMap<LocalTypeVarId, TypeId>,
    /// Current assignment of const variables.
    pub const_assignment: OrderedHashMap<LocalConstVarId, ConstValueId>,
    /// Current assignment of impl variables.
    pub impl_assignment: OrderedHashMap<LocalImplVarId, ImplId>,
    /// Trait-item values recorded for impl variables. On assignment, `assign_local_impl` removes
    /// the entry and conforms each value with the assigned impl's items.
    pub impl_vars_trait_item_mappings: HashMap<LocalImplVarId, ImplVarTraitItemMappings>,
    /// All type variables; a variable's local id is its index here.
    pub type_vars: Vec<TypeVar>,
    /// All const variables; a variable's local id is its index here.
    pub const_vars: Vec<ConstVar>,
    /// All impl variables; a variable's local id is its index here.
    pub impl_vars: Vec<ImplVar>,
    /// Syntax locations of variables, used to place error diagnostics.
    pub stable_ptrs: HashMap<InferenceVar, SyntaxStablePtrId>,
    /// Impl variables waiting to be solved, processed front to back.
    pending: VecDeque<LocalImplVarId>,
    /// Impl variables for which no impl was found.
    refuted: Vec<LocalImplVarId>,
    /// Impl variables that were assigned a unique solution.
    solved: Vec<LocalImplVarId>,
    /// Impl variables whose solution set was ambiguous, with the ambiguity.
    ambiguous: Vec<(LocalImplVarId, Ambiguity)>,
    /// The type variable created for each impl type by `impl_type_assignment`.
    pub impl_type_bounds: OrderedHashMap<ImplTypeId, TypeId>,

    /// `Ok(())` while no error is set; otherwise whether the error is pending or consumed.
    pub error_status: Result<(), InferenceErrorStatus>,
    /// The error stored while `error_status` is `Err(Pending)`.
    error: Option<InferenceError>,
    /// The diagnostic attached to the error once `error_status` is `Err(Consumed)`.
    consumed_error: Option<DiagnosticAdded>,
}
346impl InferenceData {
347 pub fn new(inference_id: InferenceId) -> Self {
348 Self {
349 inference_id,
350 type_assignment: OrderedHashMap::default(),
351 impl_assignment: OrderedHashMap::default(),
352 const_assignment: OrderedHashMap::default(),
353 impl_vars_trait_item_mappings: HashMap::new(),
354 type_vars: Vec::new(),
355 impl_vars: Vec::new(),
356 const_vars: Vec::new(),
357 stable_ptrs: HashMap::new(),
358 pending: VecDeque::new(),
359 refuted: Vec::new(),
360 solved: Vec::new(),
361 ambiguous: Vec::new(),
362 impl_type_bounds: OrderedHashMap::default(),
363 error_status: Ok(()),
364 error: None,
365 consumed_error: None,
366 }
367 }
368 pub fn inference<'db, 'b: 'db>(&'db mut self, db: &'b dyn SemanticGroup) -> Inference<'db> {
369 Inference::new(db, self)
370 }
371 pub fn clone_with_inference_id(
372 &self,
373 db: &dyn SemanticGroup,
374 inference_id: InferenceId,
375 ) -> InferenceData {
376 let mut inference_id_replacer =
377 InferenceIdReplacer::new(db, self.inference_id, inference_id);
378 Self {
379 inference_id,
380 type_assignment: self
381 .type_assignment
382 .iter()
383 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
384 .collect(),
385 const_assignment: self
386 .const_assignment
387 .iter()
388 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
389 .collect(),
390 impl_assignment: self
391 .impl_assignment
392 .iter()
393 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
394 .collect(),
395 impl_vars_trait_item_mappings: self
396 .impl_vars_trait_item_mappings
397 .iter()
398 .map(|(k, mappings)| {
399 (*k, ImplVarTraitItemMappings {
400 types: mappings
401 .types
402 .iter()
403 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
404 .collect(),
405 constants: mappings
406 .constants
407 .iter()
408 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
409 .collect(),
410 impls: mappings
411 .impls
412 .iter()
413 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
414 .collect(),
415 })
416 })
417 .collect(),
418 type_vars: inference_id_replacer.rewrite(self.type_vars.clone()).no_err(),
419 const_vars: inference_id_replacer.rewrite(self.const_vars.clone()).no_err(),
420 impl_vars: inference_id_replacer.rewrite(self.impl_vars.clone()).no_err(),
421 stable_ptrs: self.stable_ptrs.clone(),
422 pending: inference_id_replacer.rewrite(self.pending.clone()).no_err(),
423 refuted: inference_id_replacer.rewrite(self.refuted.clone()).no_err(),
424 solved: inference_id_replacer.rewrite(self.solved.clone()).no_err(),
425 ambiguous: inference_id_replacer.rewrite(self.ambiguous.clone()).no_err(),
426 impl_type_bounds: self
427 .impl_type_bounds
428 .iter()
429 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
430 .collect(),
431 error_status: self.error_status,
432 error: self.error.clone(),
433 consumed_error: self.consumed_error,
434 }
435 }
436 pub fn temporary_clone(&self) -> InferenceData {
437 Self {
438 inference_id: self.inference_id,
439 type_assignment: self.type_assignment.clone(),
440 const_assignment: self.const_assignment.clone(),
441 impl_assignment: self.impl_assignment.clone(),
442 impl_vars_trait_item_mappings: self.impl_vars_trait_item_mappings.clone(),
443 type_vars: self.type_vars.clone(),
444 const_vars: self.const_vars.clone(),
445 impl_vars: self.impl_vars.clone(),
446 stable_ptrs: self.stable_ptrs.clone(),
447 pending: self.pending.clone(),
448 refuted: self.refuted.clone(),
449 solved: self.solved.clone(),
450 ambiguous: self.ambiguous.clone(),
451 impl_type_bounds: self.impl_type_bounds.clone(),
452 error_status: self.error_status,
453 error: self.error.clone(),
454 consumed_error: self.consumed_error,
455 }
456 }
457}
458
/// A mutable view over an [`InferenceData`], together with the database used for queries.
///
/// Dereferences to the wrapped data, so its fields can be accessed directly.
pub struct Inference<'db> {
    db: &'db dyn SemanticGroup,
    pub data: &'db mut InferenceData,
}
464
465impl Deref for Inference<'_> {
466 type Target = InferenceData;
467
468 fn deref(&self) -> &Self::Target {
469 self.data
470 }
471}
472impl DerefMut for Inference<'_> {
473 fn deref_mut(&mut self) -> &mut Self::Target {
474 self.data
475 }
476}
477
478impl std::fmt::Debug for Inference<'_> {
479 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
480 let x = self.data.debug(self.db.elongate());
481 write!(f, "{x:?}")
482 }
483}
484
485impl<'db> Inference<'db> {
486 fn new(db: &'db dyn SemanticGroup, data: &'db mut InferenceData) -> Self {
487 Self { db, data }
488 }
489
490 fn impl_var(&self, var_id: LocalImplVarId) -> &ImplVar {
492 &self.impl_vars[var_id.0]
493 }
494
495 pub fn impl_assignment(&self, var_id: LocalImplVarId) -> Option<ImplId> {
497 self.impl_assignment.get(&var_id).copied()
498 }
499
500 fn type_assignment(&self, var_id: LocalTypeVarId) -> Option<TypeId> {
502 self.type_assignment.get(&var_id).copied()
503 }
504
505 pub fn new_type_var(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeId {
508 let var = self.new_type_var_raw(stable_ptr);
509
510 TypeLongId::Var(var).intern(self.db)
511 }
512
513 pub fn new_type_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> TypeVar {
516 let var =
517 TypeVar { inference_id: self.inference_id, id: LocalTypeVarId(self.type_vars.len()) };
518 if let Some(stable_ptr) = stable_ptr {
519 self.stable_ptrs.insert(InferenceVar::Type(var.id), stable_ptr);
520 }
521 self.type_vars.push(var);
522 var
523 }
524
525 pub fn impl_type_assignment(&mut self, impl_type: ImplTypeId) -> TypeId {
528 match self.data.impl_type_bounds.entry(impl_type) {
529 Entry::Occupied(entry) => *entry.get(),
530 Entry::Vacant(entry) => {
531 let inference_id = self.data.inference_id;
532 let id = LocalTypeVarId(self.data.type_vars.len());
533 let var = TypeVar { inference_id, id };
534 let ty = TypeLongId::Var(var).intern(self.db);
535 entry.insert(ty);
536 self.type_vars.push(var);
537 ty
538 }
539 }
540 }
541
542 pub fn finalize_impl_type_bounds(&mut self) {
544 let mut impl_type_bounds = std::mem::take(&mut self.data.impl_type_bounds);
545 impl_type_bounds.retain(|impl_type, ty| {
546 if !matches!(self.rewrite(ty.lookup_intern(self.db)).no_err(), TypeLongId::Var(_)) {
547 return true;
548 }
549
550 self.conform_ty(*ty, TypeLongId::ImplType(*impl_type).intern(self.db)).ok();
551 false
552 });
553 self.data.impl_type_bounds = impl_type_bounds;
554 }
555
556 pub fn new_const_var(
559 &mut self,
560 stable_ptr: Option<SyntaxStablePtrId>,
561 ty: TypeId,
562 ) -> ConstValueId {
563 let var = self.new_const_var_raw(stable_ptr);
564 ConstValue::Var(var, ty).intern(self.db)
565 }
566
567 pub fn new_const_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId>) -> ConstVar {
570 let var = ConstVar {
571 inference_id: self.inference_id,
572 id: LocalConstVarId(self.const_vars.len()),
573 };
574 if let Some(stable_ptr) = stable_ptr {
575 self.stable_ptrs.insert(InferenceVar::Const(var.id), stable_ptr);
576 }
577 self.const_vars.push(var);
578 var
579 }
580
581 pub fn new_impl_var(
584 &mut self,
585 concrete_trait_id: ConcreteTraitId,
586 stable_ptr: Option<SyntaxStablePtrId>,
587 lookup_context: ImplLookupContext,
588 ) -> ImplId {
589 let var = self.new_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
590 ImplLongId::ImplVar(self.impl_var(var).intern(self.db)).intern(self.db)
591 }
592
593 fn new_impl_var_raw(
596 &mut self,
597 lookup_context: ImplLookupContext,
598 concrete_trait_id: ConcreteTraitId,
599 stable_ptr: Option<SyntaxStablePtrId>,
600 ) -> LocalImplVarId {
601 let mut lookup_context = lookup_context;
602 lookup_context
603 .insert_module(concrete_trait_id.trait_id(self.db).module_file_id(self.db.upcast()).0);
604
605 let id = LocalImplVarId(self.impl_vars.len());
606 if let Some(stable_ptr) = stable_ptr {
607 self.stable_ptrs.insert(InferenceVar::Impl(id), stable_ptr);
608 }
609 let var =
610 ImplVar { inference_id: self.inference_id, id, concrete_trait_id, lookup_context };
611 self.impl_vars.push(var);
612 self.pending.push_back(id);
613 id
614 }
615
616 pub fn solve(&mut self) -> InferenceResult<()> {
621 self.solve_ex().map_err(|(err_set, _)| err_set)
622 }
623
624 fn solve_ex(&mut self) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
626 let mut ambiguous = std::mem::take(&mut self.ambiguous);
627 self.pending.extend(ambiguous.drain(..).map(|(var, _)| var));
628 while let Some(var) = self.pending.pop_front() {
629 self.solve_single_pending(var).map_err(|err_set| {
631 (err_set, self.stable_ptrs.get(&InferenceVar::Impl(var)).copied())
632 })?;
633 }
634 Ok(())
635 }
636
637 fn solve_single_pending(&mut self, var: LocalImplVarId) -> InferenceResult<()> {
638 if self.impl_assignment.contains_key(&var) {
639 return Ok(());
640 }
641 let solution = match self.impl_var_solution_set(var)? {
642 SolutionSet::None => {
643 self.refuted.push(var);
644 return Ok(());
645 }
646 SolutionSet::Ambiguous(ambiguity) => {
647 self.ambiguous.push((var, ambiguity));
648 return Ok(());
649 }
650 SolutionSet::Unique(solution) => solution,
651 };
652
653 self.assign_local_impl(var, solution)?;
655
656 self.solved.push(var);
658 let mut ambiguous = std::mem::take(&mut self.ambiguous);
659 self.pending.extend(ambiguous.drain(..).map(|(var, _)| var));
660
661 Ok(())
662 }
663
664 pub fn solution_set(&mut self) -> InferenceResult<SolutionSet<()>> {
667 self.solve()?;
668 if !self.refuted.is_empty() {
669 return Ok(SolutionSet::None);
670 }
671 if let Some((_, ambiguity)) = self.ambiguous.first() {
672 return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
673 }
674 assert!(self.pending.is_empty(), "solution() called on an unsolved solver");
675 Ok(SolutionSet::Unique(()))
676 }
677
    /// Runs the solver to a fixed point, then checks that every variable is determined.
    ///
    /// Ambiguous numeric-literal impl variables are defaulted to `felt252` one at a time:
    /// while some ambiguous variable's trait is the numeric-literal trait and its type argument
    /// does not already rewrite to `felt252`, that type is conformed to `felt252` and the solver
    /// runs again. Afterwards, the first unassigned type variable, refuted impl variable, or
    /// ambiguous impl variable (in that order) becomes the error.
    ///
    /// Returns `Err((ErrorSet, None))` right away if an error is already set. Otherwise the
    /// optional stable pointer locates the offending variable when one was recorded. Nothing
    /// is reported to diagnostics here.
    pub fn finalize_without_reporting(
        &mut self,
    ) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId>)> {
        if self.error_status.is_err() {
            // An error was already recorded earlier; there is no variable location to attach.
            return Err((ErrorSet, None));
        }

        let numeric_trait_id = numeric_literal_trait(self.db);
        let felt_ty = core_felt252_ty(self.db);

        // Repeat until a full pass over the ambiguous variables changes nothing.
        loop {
            let mut changed = false;
            self.solve_ex()?;
            for (var, _) in self.ambiguous.clone() {
                let impl_var = self.impl_var(var).clone();
                if impl_var.concrete_trait_id.trait_id(self.db) != numeric_trait_id {
                    continue;
                }
                // The numeric-literal trait's first generic argument is the type being created
                // from the literal (cf. the `NoImplsFound` message in `InferenceError::format`).
                let ty = extract_matches!(
                    impl_var.concrete_trait_id.generic_args(self.db)[0],
                    GenericArgumentId::Type
                );
                if self.rewrite(ty).no_err() == felt_ty {
                    continue;
                }
                self.conform_ty(ty, felt_ty).map_err(|err_set| {
                    (err_set, self.stable_ptrs.get(&InferenceVar::Impl(impl_var.id)).copied())
                })?;
                changed = true;
                // Re-run the solver before examining the remaining ambiguous variables.
                break;
            }
            if !changed {
                break;
            }
        }
        assert!(
            self.pending.is_empty(),
            "pending should all be solved by this point. Guaranteed by solve()."
        );

        let Some((var, err)) = self.first_undetermined_variable() else {
            return Ok(());
        };
        Err((self.set_error(err), self.stable_ptrs.get(&var).copied()))
    }
728
729 pub fn finalize(
733 &mut self,
734 diagnostics: &mut SemanticDiagnostics,
735 stable_ptr: SyntaxStablePtrId,
736 ) {
737 if let Err((err_set, err_stable_ptr)) = self.finalize_without_reporting() {
738 let diag = self.report_on_pending_error(
739 err_set,
740 diagnostics,
741 err_stable_ptr.unwrap_or(stable_ptr),
742 );
743
744 let ty_missing = TypeId::missing(self.db, diag);
745 for var in &self.data.type_vars {
746 self.data.type_assignment.entry(var.id).or_insert(ty_missing);
747 }
748 }
749 }
750
751 fn first_undetermined_variable(&mut self) -> Option<(InferenceVar, InferenceError)> {
755 for (id, var) in self.type_vars.iter().enumerate() {
756 if self.type_assignment(LocalTypeVarId(id)).is_none() {
757 let ty = TypeLongId::Var(*var).intern(self.db);
758 return Some((InferenceVar::Type(var.id), InferenceError::TypeNotInferred(ty)));
759 }
760 }
761 if let Some(var) = self.refuted.first().copied() {
762 let impl_var = self.impl_var(var).clone();
763 let concrete_trait_id = impl_var.concrete_trait_id;
764 let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
765 return Some((
766 InferenceVar::Impl(var),
767 InferenceError::NoImplsFound(concrete_trait_id),
768 ));
769 }
770 if let Some((var, ambiguity)) = self.ambiguous.first() {
771 let var = *var;
772 return Some((InferenceVar::Impl(var), InferenceError::Ambiguity(ambiguity.clone())));
774 }
775 None
776 }
777
778 fn assign_local_impl(
780 &mut self,
781 var: LocalImplVarId,
782 impl_id: ImplId,
783 ) -> InferenceResult<ImplId> {
784 let concrete_trait = impl_id
785 .concrete_trait(self.db)
786 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
787 self.conform_traits(self.impl_var(var).concrete_trait_id, concrete_trait)?;
788 if let Some(other_impl) = self.impl_assignment(var) {
789 return self.conform_impl(impl_id, other_impl);
790 }
791 if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
792 {
793 return Err(self.set_error(InferenceError::Cycle(InferenceVar::Impl(var))));
794 }
795 self.impl_assignment.insert(var, impl_id);
796 if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
797 for (trait_ty, ty) in mappings.types {
798 self.conform_ty(
799 ty,
800 self.db
801 .impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_ty, self.db))
802 .map_err(|_| ErrorSet)?,
803 )?;
804 }
805 for (trait_constant, constant_id) in mappings.constants {
806 self.conform_const(
807 constant_id,
808 self.db
809 .impl_constant_concrete_implized_value(ImplConstantId::new(
810 impl_id,
811 trait_constant,
812 self.db,
813 ))
814 .map_err(|_| ErrorSet)?,
815 )?;
816 }
817 for (trait_impl, inner_impl_id) in mappings.impls {
818 self.conform_impl(
819 inner_impl_id,
820 self.db
821 .impl_impl_concrete_implized(ImplImplId::new(impl_id, trait_impl, self.db))
822 .map_err(|_| ErrorSet)?,
823 )?;
824 }
825 }
826 Ok(impl_id)
827 }
828
829 fn assign_impl(&mut self, var_id: ImplVarId, impl_id: ImplId) -> InferenceResult<ImplId> {
831 let var = var_id.lookup_intern(self.db);
832 if var.inference_id != self.inference_id {
833 return Err(self.set_error(InferenceError::ImplKindMismatch {
834 impl0: ImplLongId::ImplVar(var_id).intern(self.db),
835 impl1: impl_id,
836 }));
837 }
838 self.assign_local_impl(var.id, impl_id)
839 }
840
841 fn assign_ty(&mut self, var: TypeVar, ty: TypeId) -> InferenceResult<TypeId> {
844 if var.inference_id != self.inference_id {
845 return Err(self.set_error(InferenceError::TypeKindMismatch {
846 ty0: TypeLongId::Var(var).intern(self.db),
847 ty1: ty,
848 }));
849 }
850 assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
851 let inference_var = InferenceVar::Type(var.id);
852 if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
853 return Err(self.set_error(InferenceError::Cycle(inference_var)));
854 }
855 if let TypeLongId::Var(other) = ty.lookup_intern(self.db) {
857 if other.inference_id == self.inference_id && other.id.0 > var.id.0 {
858 let var_ty = TypeLongId::Var(var).intern(self.db);
859 self.type_assignment.insert(other.id, var_ty);
860 return Ok(var_ty);
861 }
862 }
863 self.type_assignment.insert(var.id, ty);
864 Ok(ty)
865 }
866
867 fn assign_const(&mut self, var: ConstVar, id: ConstValueId) -> InferenceResult<ConstValueId> {
870 if var.inference_id != self.inference_id {
871 return Err(self.set_error(InferenceError::ConstKindMismatch {
872 const0: ConstValue::Var(var, TypeId::missing(self.db, skip_diagnostic()))
873 .intern(self.db),
874 const1: id,
875 }));
876 }
877
878 self.const_assignment.insert(var.id, id);
879 Ok(id)
880 }
881
882 fn impl_var_solution_set(
884 &mut self,
885 var: LocalImplVarId,
886 ) -> InferenceResult<SolutionSet<ImplId>> {
887 let impl_var = self.impl_var(var).clone();
888 let concrete_trait_id = self.rewrite(impl_var.concrete_trait_id).no_err();
890 self.impl_vars[impl_var.id.0].concrete_trait_id = concrete_trait_id;
891 let impl_var_trait_item_mappings =
892 self.impl_vars_trait_item_mappings.get(&var).cloned().unwrap_or_default();
893 let solution_set = self.trait_solution_set(
894 concrete_trait_id,
895 impl_var_trait_item_mappings,
896 impl_var.lookup_context,
897 )?;
898 Ok(match solution_set {
899 SolutionSet::None => SolutionSet::None,
900 SolutionSet::Unique((canonical_impl, canonicalizer)) => {
901 SolutionSet::Unique(canonical_impl.embed(self, &canonicalizer))
902 }
903 SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
904 })
905 }
906
907 pub fn trait_solution_set(
909 &mut self,
910 concrete_trait_id: ConcreteTraitId,
911 impl_var_trait_item_mappings: ImplVarTraitItemMappings,
912 mut lookup_context: ImplLookupContext,
913 ) -> InferenceResult<SolutionSet<(CanonicalImpl, CanonicalMapping)>> {
914 let impl_var_trait_item_mappings = self.rewrite(impl_var_trait_item_mappings).no_err();
915 let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
917 enrich_lookup_context(self.db, concrete_trait_id, &mut lookup_context);
918
919 let generic_args = concrete_trait_id.generic_args(self.db);
921 match generic_args.first() {
922 Some(GenericArgumentId::Type(ty)) => {
923 if let TypeLongId::Var(_) = ty.lookup_intern(self.db) {
924 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
926 }
927 }
928 Some(GenericArgumentId::Impl(imp)) => {
929 if let ImplLongId::ImplVar(_) = imp.lookup_intern(self.db) {
931 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
932 }
933 }
934 Some(GenericArgumentId::Constant(const_value)) => {
935 if let ConstValue::Var(_, _) = const_value.lookup_intern(self.db) {
936 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
938 }
939 }
940 _ => {}
941 };
942
943 let (canonical_trait, canonicalizer) = CanonicalTrait::canonicalize(
944 self.db,
945 self.inference_id,
946 concrete_trait_id,
947 impl_var_trait_item_mappings,
948 );
949 let solution_set = match self.db.canonic_trait_solutions(canonical_trait, lookup_context) {
950 Ok(solution_set) => solution_set,
951 Err(err) => return Err(self.set_error(err)),
952 };
953 match solution_set {
954 SolutionSet::None => Ok(SolutionSet::None),
955 SolutionSet::Unique(canonical_impl) => {
956 Ok(SolutionSet::Unique((canonical_impl, canonicalizer)))
957 }
958 SolutionSet::Ambiguous(ambiguity) => Ok(SolutionSet::Ambiguous(ambiguity)),
959 }
960 }
961
    /// Checks the negative-impl generic parameters of a candidate `canonical_impl`.
    ///
    /// Only concrete and generated impls are checked; other impl kinds are returned as
    /// `Unique(canonical_impl)` unchanged. For each negative impl's concrete trait:
    /// - if a type generic argument (after rewriting) is neither a closure type nor fully
    ///   concrete, the result is `Ambiguous(NegativeImplWithUnresolvedGenericArgs)`;
    /// - if the negated trait's solution set is anything but `None` (unique or ambiguous),
    ///   the candidate is rejected with `SolutionSet::None`.
    ///
    /// Otherwise the candidate is accepted as `Unique(canonical_impl)`.
    fn validate_neg_impls(
        &mut self,
        lookup_context: &ImplLookupContext,
        canonical_impl: CanonicalImpl,
    ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
        /// Accepts `canonical_impl` only if none of the negated traits has a solution.
        fn validate_no_solution_set(
            inference: &mut Inference<'_>,
            canonical_impl: CanonicalImpl,
            lookup_context: &ImplLookupContext,
            negative_impls_concrete_traits: impl Iterator<Item = Maybe<ConcreteTraitId>>,
        ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
            for concrete_trait_id in negative_impls_concrete_traits {
                let concrete_trait_id = concrete_trait_id.map_err(|diag_added| {
                    inference.set_error(InferenceError::Reported(diag_added))
                })?;
                for garg in concrete_trait_id.generic_args(inference.db) {
                    let GenericArgumentId::Type(ty) = garg else {
                        continue;
                    };
                    let ty = inference.rewrite(ty).no_err();
                    // Closure types are exempt from the "fully concrete" requirement.
                    // NOTE(review): the reason is not visible in this code; confirm.
                    if !matches!(ty.lookup_intern(inference.db), TypeLongId::Closure(_))
                        && !ty.is_fully_concrete(inference.db)
                    {
                        return Ok(SolutionSet::Ambiguous(
                            Ambiguity::NegativeImplWithUnresolvedGenericArgs {
                                impl_id: canonical_impl.0,
                                ty,
                            },
                        ));
                    }
                }

                // Any solution for the negated trait (unique or ambiguous) rejects the candidate.
                if !matches!(
                    inference.trait_solution_set(
                        concrete_trait_id,
                        ImplVarTraitItemMappings::default(),
                        lookup_context.clone()
                    )?,
                    SolutionSet::None
                ) {
                    return Ok(SolutionSet::None);
                }
            }

            Ok(SolutionSet::Unique(canonical_impl))
        }
        match canonical_impl.0.lookup_intern(self.db) {
            ImplLongId::Concrete(concrete_impl) => {
                // Negative-impl params of the impl def are substituted with this concrete
                // impl's generic arguments (lazily, as `validate_no_solution_set` iterates).
                let mut rewriter = SubstitutionRewriter {
                    db: self.db,
                    substitution: &concrete_impl.substitution(self.db).map_err(|diag_added| {
                        self.set_error(InferenceError::Reported(diag_added))
                    })?,
                };
                let generic_params = self
                    .db
                    .impl_def_generic_params(concrete_impl.impl_def_id(self.db))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                let concrete_traits = generic_params
                    .iter()
                    .filter_map(|generic_param| {
                        try_extract_matches!(generic_param, GenericParam::NegImpl)
                    })
                    .map(|generic_param| {
                        rewriter
                            .rewrite(generic_param.clone())
                            .and_then(|generic_param| generic_param.concrete_trait)
                    });
                validate_no_solution_set(self, canonical_impl, lookup_context, concrete_traits)
            }
            ImplLongId::GeneratedImpl(generated_impl) => validate_no_solution_set(
                self,
                canonical_impl,
                lookup_context,
                generated_impl
                    .lookup_intern(self.db)
                    .generic_params
                    .iter()
                    .filter_map(|generic_param| {
                        try_extract_matches!(generic_param, GenericParam::NegImpl)
                    })
                    .map(|generic_param| generic_param.concrete_trait),
            ),
            ImplLongId::GenericParameter(_)
            | ImplLongId::ImplVar(_)
            | ImplLongId::ImplImpl(_)
            | ImplLongId::TraitImpl(_) => Ok(SolutionSet::Unique(canonical_impl)),
        }
    }
1066
1067 pub fn set_error(&mut self, err: InferenceError) -> ErrorSet {
1074 if self.error_status.is_err() {
1075 return ErrorSet;
1076 }
1077 self.error_status = if let InferenceError::Reported(diag_added) = err {
1078 self.consumed_error = Some(diag_added);
1079 Err(InferenceErrorStatus::Consumed)
1080 } else {
1081 self.error = Some(err);
1082 Err(InferenceErrorStatus::Pending)
1083 };
1084 ErrorSet
1085 }
1086
1087 pub fn is_error_set(&self) -> InferenceResult<()> {
1089 if self.error_status.is_err() { Err(ErrorSet) } else { Ok(()) }
1090 }
1091
1092 pub fn consume_error_without_reporting(&mut self, err_set: ErrorSet) -> Option<InferenceError> {
1098 self.consume_error_inner(err_set, skip_diagnostic())
1099 }
1100
1101 pub fn consume_reported_error(&mut self, err_set: ErrorSet, diag_added: DiagnosticAdded) {
1108 self.consume_error_inner(err_set, diag_added);
1109 }
1110
1111 fn consume_error_inner(
1118 &mut self,
1119 _err_set: ErrorSet,
1120 diag_added: DiagnosticAdded,
1121 ) -> Option<InferenceError> {
1122 if self.error_status != Err(InferenceErrorStatus::Pending) {
1123 return None;
1124 }
1126 self.error_status = Err(InferenceErrorStatus::Consumed);
1127 self.consumed_error = Some(diag_added);
1128 mem::take(&mut self.error)
1129 }
1130
1131 pub fn report_on_pending_error(
1137 &mut self,
1138 _err_set: ErrorSet,
1139 diagnostics: &mut SemanticDiagnostics,
1140 stable_ptr: SyntaxStablePtrId,
1141 ) -> DiagnosticAdded {
1142 let Err(state_error) = self.error_status else {
1143 panic!("report_on_pending_error should be called only on error");
1144 };
1145 match state_error {
1146 InferenceErrorStatus::Consumed => self
1147 .consumed_error
1148 .expect("consumed_error is not set although error_status is Err(Consumed)"),
1149 InferenceErrorStatus::Pending => {
1150 let diag_added = match mem::take(&mut self.error)
1151 .expect("error is not set although error_status is Err(Pending)")
1152 {
1153 InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
1154 skip_diagnostic()
1159 }
1160 diag => diag.report(diagnostics, stable_ptr),
1161 };
1162
1163 self.error_status = Err(InferenceErrorStatus::Consumed);
1164 self.consumed_error = Some(diag_added);
1165 diag_added
1166 }
1167 }
1168 }
1169
1170 pub fn report_modified_if_pending(
1173 &mut self,
1174 err_set: ErrorSet,
1175 report: impl FnOnce() -> DiagnosticAdded,
1176 ) {
1177 if self.error_status == Err(InferenceErrorStatus::Pending) {
1178 self.consume_reported_error(err_set, report());
1179 }
1180 }
1181}
1182
1183impl<'a> HasDb<&'a dyn SemanticGroup> for Inference<'a> {
1184 fn get_db(&self) -> &'a dyn SemanticGroup {
1185 self.db
1186 }
1187}
// Macro-generated `SemanticRewriter` support for `Inference`. The types listed after
// `@exclude` (`TypeLongId`, `TypeId`, `ImplLongId`, `ImplId`, `ConstValue`) are skipped
// here because they have hand-written impls right below. Those impls substitute inference
// assignments and short-circuit var-free values.
add_basic_rewrites!(<'a>, Inference<'a>, NoError, @exclude TypeLongId TypeId ImplLongId ImplId ConstValue);
add_expr_rewrites!(<'a>, Inference<'a>, NoError, @exclude);
add_rewrite!(<'a>, Inference<'a>, NoError, Ambiguity);
1191impl SemanticRewriter<TypeId, NoError> for Inference<'_> {
1192 fn internal_rewrite(&mut self, value: &mut TypeId) -> Result<RewriteResult, NoError> {
1193 if value.is_var_free(self.db) {
1194 return Ok(RewriteResult::NoChange);
1195 }
1196 value.default_rewrite(self)
1197 }
1198}
1199impl SemanticRewriter<ImplId, NoError> for Inference<'_> {
1200 fn internal_rewrite(&mut self, value: &mut ImplId) -> Result<RewriteResult, NoError> {
1201 if value.is_var_free(self.db) {
1202 return Ok(RewriteResult::NoChange);
1203 }
1204 value.default_rewrite(self)
1205 }
1206}
/// Rewrites a long type id by substituting what inference currently knows about it.
/// - `Var`: if the var has an entry in `type_assignment`, it is replaced by that type. The
///   assigned type is rewritten first, and if that changed it, the rewritten form is
///   written back to `type_assignment`.
/// - `ImplType`: a bound in `impl_type_bounds` takes precedence. Otherwise the impl type id is
///   rewritten and then resolved according to the kind of its impl.
/// Other variants, and unassigned vars, fall through to `default_rewrite`.
impl SemanticRewriter<TypeLongId, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut TypeLongId) -> Result<RewriteResult, NoError> {
        match value {
            TypeLongId::Var(var) => {
                if let Some(type_id) = self.type_assignment.get(&var.id) {
                    let mut long_type_id = type_id.lookup_intern(self.db);
                    // Store the rewritten assignment back, presumably so later lookups of this
                    // var do not have to re-resolve the same chain.
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_type_id)? {
                        *self.type_assignment.get_mut(&var.id).unwrap() =
                            long_type_id.clone().intern(self.db);
                    }
                    *value = long_type_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            TypeLongId::ImplType(impl_type_id) => {
                // A known bound for this impl type wins; the bound itself is rewritten too.
                if let Some(type_id) = self.impl_type_bounds.get(impl_type_id) {
                    *value = type_id.lookup_intern(self.db);
                    self.internal_rewrite(value)?;
                    return Ok(RewriteResult::Modified);
                }
                // Rewrite the impl type id first, so `impl_id` below is the rewritten impl.
                let impl_type_id_rewrite_result = self.internal_rewrite(impl_type_id)?;
                let impl_id = impl_type_id.impl_id();
                let trait_ty = impl_type_id.ty();
                return Ok(match impl_id.lookup_intern(self.db) {
                    // No further resolution for these impl kinds.
                    ImplLongId::GenericParameter(_) | ImplLongId::TraitImpl(_) => {
                        impl_type_id_rewrite_result
                    }
                    ImplLongId::ImplImpl(impl_impl) => {
                        // Asserted invariant: the parent impl is var free at this point.
                        assert!(impl_impl.impl_id().is_var_free(self.db));
                        impl_type_id_rewrite_result
                    }
                    // Concrete impl: substitute the type it defines when that can be computed;
                    // on error keep the rewritten impl type id.
                    ImplLongId::Concrete(_) => {
                        if let Ok(ty) = self.db.impl_type_concrete_implized(ImplTypeId::new(
                            impl_id, trait_ty, self.db,
                        )) {
                            *value = self.rewrite(ty).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_type_id_rewrite_result
                        }
                    }
                    // Impl variable: use the type `rewritten_impl_type` gives for this var and
                    // trait type.
                    ImplLongId::ImplVar(var) => {
                        *value = self.rewritten_impl_type(var, trait_ty).lookup_intern(self.db);
                        return Ok(RewriteResult::Modified);
                    }
                    // Generated impl: read the type from its `impl_items`. The `unwrap` panics
                    // if no entry exists for this trait type.
                    ImplLongId::GeneratedImpl(generated) => {
                        *value = self
                            .rewrite(
                                *generated
                                    .lookup_intern(self.db)
                                    .impl_items
                                    .0
                                    .get(&impl_type_id.ty())
                                    .unwrap(),
                            )
                            .no_err()
                            .lookup_intern(self.db);
                        RewriteResult::Modified
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
/// Rewrites a `ConstValue` by substituting what inference currently knows about it.
/// - `Var`: if the var has an entry in `const_assignment`, it is replaced by that value. The
///   value is rewritten first, and if that changed it, the rewritten form is written back.
///   An unassigned var returns `NoChange` immediately.
/// - `ImplConstant`: the impl constant id is rewritten, then resolved according to the kind
///   of its impl.
/// Other variants are rewritten with `default_rewrite`.
impl SemanticRewriter<ConstValue, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut ConstValue) -> Result<RewriteResult, NoError> {
        match value {
            ConstValue::Var(var, _) => {
                return Ok(if let Some(const_value_id) = self.const_assignment.get(&var.id) {
                    let mut const_value = const_value_id.lookup_intern(self.db);
                    // Store the rewritten assignment back into `const_assignment`.
                    if let RewriteResult::Modified = self.internal_rewrite(&mut const_value)? {
                        *self.const_assignment.get_mut(&var.id).unwrap() =
                            const_value.clone().intern(self.db);
                    }
                    *value = const_value;
                    RewriteResult::Modified
                } else {
                    RewriteResult::NoChange
                });
            }
            ConstValue::ImplConstant(impl_constant_id) => {
                // Rewrite the id first, so `impl_id` below refers to the rewritten impl.
                let impl_constant_id_rewrite_result = self.internal_rewrite(impl_constant_id)?;
                let impl_id = impl_constant_id.impl_id();
                let trait_constant = impl_constant_id.trait_constant_id();
                return Ok(match impl_id.lookup_intern(self.db) {
                    // No further resolution for these impl kinds.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::TraitImpl(_)
                    | ImplLongId::GeneratedImpl(_) => impl_constant_id_rewrite_result,
                    ImplLongId::ImplImpl(impl_impl) => {
                        // Asserted invariant: the parent impl is var free at this point.
                        assert!(impl_impl.impl_id().is_var_free(self.db));
                        impl_constant_id_rewrite_result
                    }
                    // Concrete impl: substitute the constant's value when it can be computed;
                    // on error keep the rewritten impl constant id.
                    ImplLongId::Concrete(_) => {
                        if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
                            ImplConstantId::new(impl_id, trait_constant, self.db),
                        ) {
                            *value = self.rewrite(constant).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_constant_id_rewrite_result
                        }
                    }
                    // Impl variable: use the value `rewritten_impl_constant` gives for this var
                    // and trait constant.
                    ImplLongId::ImplVar(var) => {
                        *value = self
                            .rewritten_impl_constant(var, trait_constant)
                            .lookup_intern(self.db);
                        return Ok(RewriteResult::Modified);
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
/// Rewrites an `ImplLongId` by substituting what inference currently knows about it.
/// - `ImplVar`: if `impl_assignment` yields an impl for the var, the var is replaced by it.
///   The impl is rewritten first, and if that changed it, the rewritten form is written back.
/// - `ImplImpl`: the impl-impl id is rewritten, then resolved according to the kind of its
///   parent impl.
/// Anything else (including an unassigned var) is returned unchanged if var free, and
/// rewritten with `default_rewrite` otherwise.
impl SemanticRewriter<ImplLongId, NoError> for Inference<'_> {
    fn internal_rewrite(&mut self, value: &mut ImplLongId) -> Result<RewriteResult, NoError> {
        match value {
            ImplLongId::ImplVar(var) => {
                let long_id = var.lookup_intern(self.db);
                let impl_var_id = long_id.id;
                if let Some(impl_id) = self.impl_assignment(impl_var_id) {
                    let mut long_impl_id = impl_id.lookup_intern(self.db);
                    // Store the rewritten assignment back into `impl_assignment`.
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_impl_id)? {
                        *self.impl_assignment.get_mut(&impl_var_id).unwrap() =
                            long_impl_id.clone().intern(self.db);
                    }
                    *value = long_impl_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            ImplLongId::ImplImpl(impl_impl_id) => {
                // Rewrite the id first, so `impl_id` below refers to the rewritten parent impl.
                let impl_impl_id_rewrite_result = self.internal_rewrite(impl_impl_id)?;
                let impl_id = impl_impl_id.impl_id();
                return Ok(match impl_id.lookup_intern(self.db) {
                    // No further resolution for these impl kinds.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::TraitImpl(_)
                    | ImplLongId::GeneratedImpl(_) => impl_impl_id_rewrite_result,
                    ImplLongId::ImplImpl(impl_impl) => {
                        // Asserted invariant: the grandparent impl is var free at this point.
                        assert!(impl_impl.impl_id().is_var_free(self.db));
                        impl_impl_id_rewrite_result
                    }
                    // Concrete parent: substitute the impl it defines when that can be
                    // computed; on error keep the rewritten impl-impl id.
                    ImplLongId::Concrete(_) => {
                        if let Ok(ty) = self.db.impl_impl_concrete_implized(*impl_impl_id) {
                            *value = self.rewrite(ty).no_err().lookup_intern(self.db);
                            RewriteResult::Modified
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                    // Parent is an impl variable: if the concrete trait impl is known, use the
                    // impl `rewritten_impl_impl` gives for it.
                    ImplLongId::ImplVar(var) => {
                        if let Ok(concrete_trait_impl) =
                            impl_impl_id.concrete_trait_impl_id(self.db)
                        {
                            *value = self
                                .rewritten_impl_impl(var, concrete_trait_impl)
                                .lookup_intern(self.db);
                            return Ok(RewriteResult::Modified);
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                });
            }

            _ => {}
        }
        if value.is_var_free(self.db) {
            return Ok(RewriteResult::NoChange);
        }
        value.default_rewrite(self)
    }
}
1389
/// Rewriter that replaces every occurrence of `from_inference_id` with `to_inference_id` and
/// leaves all other inference ids unchanged.
struct InferenceIdReplacer<'a> {
    /// Database returned by this rewriter's `get_db`.
    db: &'a dyn SemanticGroup,
    /// The inference id to replace.
    from_inference_id: InferenceId,
    /// The inference id written in its place.
    to_inference_id: InferenceId,
}
1395impl<'a> InferenceIdReplacer<'a> {
1396 fn new(
1397 db: &'a dyn SemanticGroup,
1398 from_inference_id: InferenceId,
1399 to_inference_id: InferenceId,
1400 ) -> Self {
1401 Self { db, from_inference_id, to_inference_id }
1402 }
1403}
1404impl<'a> HasDb<&'a dyn SemanticGroup> for InferenceIdReplacer<'a> {
1405 fn get_db(&self) -> &'a dyn SemanticGroup {
1406 self.db
1407 }
1408}
// Macro-generated `SemanticRewriter` support for `InferenceIdReplacer`, using the same helper
// macros as `Inference`. `InferenceId` is excluded because its rewrite is hand-written below,
// and that impl is where the actual id replacement happens.
add_basic_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude InferenceId);
add_expr_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude);
add_rewrite!(<'a>, InferenceIdReplacer<'a>, NoError, Ambiguity);
1412impl SemanticRewriter<InferenceId, NoError> for InferenceIdReplacer<'_> {
1413 fn internal_rewrite(&mut self, value: &mut InferenceId) -> Result<RewriteResult, NoError> {
1414 if value == &self.from_inference_id {
1415 *value = self.to_inference_id;
1416 Ok(RewriteResult::Modified)
1417 } else {
1418 Ok(RewriteResult::NoChange)
1419 }
1420 }
1421}