1use std::collections::{BTreeMap, HashMap};
4use std::hash::Hash;
5use std::ops::{Deref, DerefMut};
6use std::sync::Arc;
7
8use cairo_lang_debug::DebugWithDb;
9use cairo_lang_defs::ids::{
10 ConstantId, EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericKind,
11 GenericParamId, GlobalUseId, ImplAliasId, ImplDefId, ImplFunctionId, ImplImplDefId, LocalVarId,
12 LookupItemId, MacroCallId, MemberId, NamedLanguageElementId, ParamId, StructId,
13 TraitConstantId, TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
14};
15use cairo_lang_diagnostics::{DiagnosticAdded, skip_diagnostic};
16use cairo_lang_proc_macros::{DebugWithDb, HeapSize, SemanticObject};
17use cairo_lang_syntax::node::TypedStablePtr;
18use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
19use cairo_lang_utils::deque::Deque;
20use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
21use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
22use cairo_lang_utils::{Intern, define_short_id, extract_matches};
23use salsa::Database;
24
25use self::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, NoError};
26use self::solver::{Ambiguity, SolutionSet, enrich_lookup_context};
27use crate::corelib::CorelibSemantic;
28use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
29use crate::expr::inference::canonic::ResultNoErrEx;
30use crate::expr::inference::conform::InferenceConform;
31use crate::expr::inference::solver::SemanticSolver;
32use crate::expr::objects::*;
33use crate::expr::pattern::*;
34use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
35use crate::items::functions::{
36 ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
37 GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
38 ImplGenericFunctionWithBodyId,
39};
40use crate::items::generics::{
41 GenericParamConst, GenericParamImpl, GenericParamSemantic, GenericParamType,
42};
43use crate::items::imp::{
44 GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
45 ImplLookupContextId, ImplSemantic, NegativeImplId, NegativeImplLongId,
46 UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
47};
48use crate::items::trt::{
49 ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId, ConcreteTraitTypeId,
50 ConcreteTraitTypeLongId,
51};
52use crate::substitution::{GenericSubstitution, HasDb, RewriteResult, SemanticRewriter};
53use crate::types::{
54 ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
55 ImplTypeById, ImplTypeId,
56};
57use crate::{
58 ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
59 ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
60 FunctionId, FunctionLongId, GenericArgumentId, GenericParam, LocalVariable, MatchArmSelector,
61 Member, Parameter, SemanticObject, Signature, TypeId, TypeLongId, ValueSelectorArm,
62 add_basic_rewrites, add_expr_rewrites, add_rewrite, semantic_object_for_id,
63};
64
65pub mod canonic;
66pub mod conform;
67pub mod infers;
68pub mod solver;
69
70#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, HeapSize, salsa::Update)]
73pub struct TypeVar<'db> {
74 pub inference_id: InferenceId<'db>,
75 pub id: LocalTypeVarId,
76}
77
78#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, HeapSize, salsa::Update)]
81pub struct ConstVar<'db> {
82 pub inference_id: InferenceId<'db>,
83 pub id: LocalConstVarId,
84}
85
86#[derive(
88 Copy, Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject, HeapSize, salsa::Update,
89)]
90#[debug_db(dyn Database)]
91pub enum InferenceId<'db> {
92 LookupItemDeclaration(LookupItemId<'db>),
93 LookupItemGenerics(LookupItemId<'db>),
94 LookupItemDefinition(LookupItemId<'db>),
95 ImplDefTrait(ImplDefId<'db>),
96 ImplAliasImplDef(ImplAliasId<'db>),
97 GenericParam(GenericParamId<'db>),
98 GenericImplParamTrait(GenericParamId<'db>),
99 GlobalUseStar(GlobalUseId<'db>),
100 MacroCall(MacroCallId<'db>),
101 Canonical,
102 NoContext,
104}
105
106#[derive(
109 Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject, HeapSize, salsa::Update,
110)]
111#[debug_db(dyn Database)]
112pub struct ImplVar<'db> {
113 pub inference_id: InferenceId<'db>,
114 #[dont_rewrite]
115 pub id: LocalImplVarId,
116 pub concrete_trait_id: ConcreteTraitId<'db>,
117 #[dont_rewrite]
118 pub lookup_context: ImplLookupContextId<'db>,
119}
120impl<'db> ImplVar<'db> {
121 pub fn intern(&self, db: &'db dyn Database) -> ImplVarId<'db> {
122 self.clone().intern(db)
123 }
124}
125
126#[derive(
128 Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject, HeapSize, salsa::Update,
129)]
130#[debug_db(dyn Database)]
131pub struct NegativeImplVar<'db> {
132 pub inference_id: InferenceId<'db>,
133 #[dont_rewrite]
134 pub id: LocalNegativeImplVarId,
135 pub concrete_trait_id: ConcreteTraitId<'db>,
136 #[dont_rewrite]
137 pub lookup_context: ImplLookupContextId<'db>,
138}
139
140#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
141pub struct LocalTypeVarId(pub usize);
142#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
143pub struct LocalImplVarId(pub usize);
144
145#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
146pub struct LocalNegativeImplVarId(pub usize);
147
148#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
149pub struct LocalConstVarId(pub usize);
150
151define_short_id!(ImplVarId, ImplVar<'db>);
152impl<'db> ImplVarId<'db> {
153 pub fn id(&self, db: &dyn Database) -> LocalImplVarId {
154 self.long(db).id
155 }
156 pub fn concrete_trait_id(&self, db: &'db dyn Database) -> ConcreteTraitId<'db> {
157 self.long(db).concrete_trait_id
158 }
159 pub fn lookup_context(&self, db: &'db dyn Database) -> ImplLookupContextId<'db> {
160 self.long(db).lookup_context
161 }
162}
163semantic_object_for_id!(ImplVarId, ImplVar<'a>);
164
165define_short_id!(NegativeImplVarId, NegativeImplVar<'db>);
166impl<'db> NegativeImplVarId<'db> {
167 pub fn id(&self, db: &dyn Database) -> LocalNegativeImplVarId {
168 self.long(db).id
169 }
170 pub fn concrete_trait_id(&self, db: &'db dyn Database) -> ConcreteTraitId<'db> {
171 self.long(db).concrete_trait_id
172 }
173 pub fn lookup_context(&self, db: &'db dyn Database) -> ImplLookupContextId<'db> {
174 self.long(db).lookup_context
175 }
176}
177semantic_object_for_id!(NegativeImplVarId, NegativeImplVar<'a>);
178
179#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, SemanticObject, salsa::Update)]
180pub enum InferenceVar {
181 Type(LocalTypeVarId),
182 Const(LocalConstVarId),
183 Impl(LocalImplVarId),
184 NegativeImpl(LocalNegativeImplVarId),
185}
186
187#[derive(Clone, Debug, Eq, Hash, PartialEq, DebugWithDb, salsa::Update)]
189#[debug_db(dyn Database)]
190pub enum InferenceError<'db> {
191 Reported(DiagnosticAdded),
193 Cycle(InferenceVar),
194 TypeKindMismatch {
195 ty0: TypeId<'db>,
196 ty1: TypeId<'db>,
197 },
198 ConstKindMismatch {
199 const0: ConstValueId<'db>,
200 const1: ConstValueId<'db>,
201 },
202 ImplKindMismatch {
203 impl0: ImplId<'db>,
204 impl1: ImplId<'db>,
205 },
206 NegativeImplKindMismatch {
207 impl0: NegativeImplId<'db>,
208 impl1: NegativeImplId<'db>,
209 },
210 GenericArgMismatch {
211 garg0: GenericArgumentId<'db>,
212 garg1: GenericArgumentId<'db>,
213 },
214 TraitMismatch {
215 trt0: TraitId<'db>,
216 trt1: TraitId<'db>,
217 },
218 ImplTypeMismatch {
219 impl_id: ImplId<'db>,
220 trait_type_id: TraitTypeId<'db>,
221 ty0: TypeId<'db>,
222 ty1: TypeId<'db>,
223 },
224 GenericFunctionMismatch {
225 func0: GenericFunctionId<'db>,
226 func1: GenericFunctionId<'db>,
227 },
228 ConstNotInferred,
229 NoImplsFound(ConcreteTraitId<'db>),
232 NoNegativeImplsFound(ConcreteTraitId<'db>),
233 Ambiguity(Ambiguity<'db>),
234 TypeNotInferred(TypeId<'db>),
235}
236impl<'db> InferenceError<'db> {
237 pub fn format(&self, db: &dyn Database) -> String {
238 match self {
239 InferenceError::Reported(_) => "Inference error occurred.".into(),
240 InferenceError::Cycle(_var) => "Inference cycle detected".into(),
241 InferenceError::TypeKindMismatch { ty0, ty1 } => {
242 format!("Type mismatch: `{:?}` and `{:?}`.", ty0.debug(db), ty1.debug(db))
243 }
244 InferenceError::ConstKindMismatch { const0, const1 } => {
245 format!("Const mismatch: `{:?}` and `{:?}`.", const0.debug(db), const1.debug(db))
246 }
247 InferenceError::ImplKindMismatch { impl0, impl1 } => {
248 format!("Impl mismatch: `{:?}` and `{:?}`.", impl0.debug(db), impl1.debug(db))
249 }
250 InferenceError::NegativeImplKindMismatch { impl0, impl1 } => {
251 format!(
252 "Negative impl mismatch: `{:?}` and `{:?}`.",
253 impl0.debug(db),
254 impl1.debug(db)
255 )
256 }
257 InferenceError::GenericArgMismatch { garg0, garg1 } => {
258 format!(
259 "Generic arg mismatch: `{:?}` and `{:?}`.",
260 garg0.debug(db),
261 garg1.debug(db)
262 )
263 }
264 InferenceError::TraitMismatch { trt0, trt1 } => {
265 format!("Trait mismatch: `{:?}` and `{:?}`.", trt0.debug(db), trt1.debug(db))
266 }
267 InferenceError::ConstNotInferred => "Failed to infer constant.".into(),
268 InferenceError::NoImplsFound(concrete_trait_id) => {
269 let info = db.core_info();
270 let trait_id = concrete_trait_id.trait_id(db);
271 if trait_id == info.numeric_literal_trt {
272 let generic_type = extract_matches!(
273 concrete_trait_id.generic_args(db)[0],
274 GenericArgumentId::Type
275 );
276 return format!(
277 "Mismatched types. The type `{:?}` cannot be created from a numeric \
278 literal.",
279 generic_type.debug(db)
280 );
281 } else if trait_id == info.string_literal_trt {
282 let generic_type = extract_matches!(
283 concrete_trait_id.generic_args(db)[0],
284 GenericArgumentId::Type
285 );
286 return format!(
287 "Mismatched types. The type `{:?}` cannot be created from a string \
288 literal.",
289 generic_type.debug(db)
290 );
291 }
292 format!(
293 "Trait has no implementation in context: {:?}.",
294 concrete_trait_id.debug(db)
295 )
296 }
297 InferenceError::NoNegativeImplsFound(concrete_trait_id) => {
298 format!("Trait has implementation in context: {:?}.", concrete_trait_id.debug(db))
299 }
300 InferenceError::Ambiguity(ambiguity) => ambiguity.format(db),
301 InferenceError::TypeNotInferred(ty) => {
302 format!("Type annotations needed. Failed to infer {:?}.", ty.debug(db))
303 }
304 InferenceError::GenericFunctionMismatch { func0, func1 } => {
305 format!("Function mismatch: `{}` and `{}`.", func0.format(db), func1.format(db))
306 }
307 InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 } => {
308 format!(
309 "`{}::{}` type mismatch: `{:?}` and `{:?}`.",
310 impl_id.format(db),
311 trait_type_id.name(db).long(db),
312 ty0.debug(db),
313 ty1.debug(db)
314 )
315 }
316 }
317 }
318}
319
320impl<'db> InferenceError<'db> {
321 pub fn report(
322 &self,
323 diagnostics: &mut SemanticDiagnostics<'db>,
324 stable_ptr: SyntaxStablePtrId<'db>,
325 ) -> DiagnosticAdded {
326 match self {
327 InferenceError::Reported(diagnostic_added) => *diagnostic_added,
328 _ => diagnostics
329 .report(stable_ptr, SemanticDiagnosticKind::InternalInferenceError(self.clone())),
330 }
331 }
332}
333
/// Marker error signalling that an inference step failed.
///
/// It carries no data. The details of the failure are recorded separately in the inference
/// data's error status, from which they can be reported later.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct ErrorSet;

/// Result of a fallible inference operation.
pub type InferenceResult<T> = Result<T, ErrorSet>;
342
343#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
344enum InferenceErrorStatus<'db> {
345 Pending(PendingInferenceError<'db>),
347 Consumed(DiagnosticAdded),
349}
350
351#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
353struct PendingInferenceError<'db> {
354 err: InferenceError<'db>,
356 stable_ptr: Option<SyntaxStablePtrId<'db>>,
358}
359
360#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, SemanticObject, salsa::Update)]
362pub struct ImplVarTraitItemMappings<'db> {
363 types: OrderedHashMap<TraitTypeId<'db>, TypeId<'db>>,
365 constants: OrderedHashMap<TraitConstantId<'db>, ConstValueId<'db>>,
367 impls: OrderedHashMap<TraitImplId<'db>, ImplId<'db>>,
369}
370
371impl ImplVarTraitItemMappings<'_> {
372 pub fn is_empty(&self) -> bool {
374 self.types.is_empty() && self.constants.is_empty() && self.impls.is_empty()
375 }
376}
377
378#[derive(Debug, DebugWithDb, PartialEq, Eq, salsa::Update)]
380#[debug_db(dyn Database)]
381pub struct InferenceData<'db> {
382 pub inference_id: InferenceId<'db>,
383 pub type_assignment: OrderedHashMap<LocalTypeVarId, TypeId<'db>>,
385 pub const_assignment: OrderedHashMap<LocalConstVarId, ConstValueId<'db>>,
387 pub impl_assignment: OrderedHashMap<LocalImplVarId, ImplId<'db>>,
389 pub negative_impl_assignment: OrderedHashMap<LocalNegativeImplVarId, NegativeImplId<'db>>,
391 pub impl_vars_trait_item_mappings: HashMap<LocalImplVarId, ImplVarTraitItemMappings<'db>>,
394 pub type_vars: Vec<TypeVar<'db>>,
396 pub const_vars: Vec<ConstVar<'db>>,
398 pub impl_vars: Vec<ImplVar<'db>>,
400 pub negative_impl_vars: Vec<NegativeImplVar<'db>>,
402 pub stable_ptrs: HashMap<InferenceVar, SyntaxStablePtrId<'db>>,
404 pending: Deque<LocalImplVarId>,
406 negative_pending: Deque<LocalNegativeImplVarId>,
408 refuted: Vec<LocalImplVarId>,
410 negative_refuted: Vec<LocalNegativeImplVarId>,
412 solved: Vec<LocalImplVarId>,
414 ambiguous: Vec<(LocalImplVarId, Ambiguity<'db>)>,
416 negative_ambiguous: Vec<(LocalNegativeImplVarId, Ambiguity<'db>)>,
418 pub impl_type_bounds: Arc<BTreeMap<ImplTypeById<'db>, TypeId<'db>>>,
420
421 error_status: Result<(), InferenceErrorStatus<'db>>,
423}
424impl<'db> InferenceData<'db> {
425 pub fn new(inference_id: InferenceId<'db>) -> Self {
426 Self {
427 inference_id,
428 type_assignment: OrderedHashMap::default(),
429 impl_assignment: OrderedHashMap::default(),
430 const_assignment: OrderedHashMap::default(),
431 negative_impl_assignment: OrderedHashMap::default(),
432 impl_vars_trait_item_mappings: HashMap::new(),
433 type_vars: Vec::new(),
434 impl_vars: Vec::new(),
435 const_vars: Vec::new(),
436 negative_impl_vars: Vec::new(),
437 stable_ptrs: HashMap::new(),
438 pending: Deque::new(),
439 negative_pending: Deque::new(),
440 refuted: Vec::new(),
441 negative_refuted: Vec::new(),
442 solved: Vec::new(),
443 ambiguous: Vec::new(),
444 negative_ambiguous: Vec::new(),
445 impl_type_bounds: Default::default(),
446 error_status: Ok(()),
447 }
448 }
449 pub fn inference<'r>(&'r mut self, db: &'db dyn Database) -> Inference<'db, 'r> {
450 Inference::new(db, self)
451 }
452 pub fn clone_with_inference_id(
453 &self,
454 db: &'db dyn Database,
455 inference_id: InferenceId<'db>,
456 ) -> InferenceData<'db> {
457 let mut inference_id_replacer =
458 InferenceIdReplacer::new(db, self.inference_id, inference_id);
459 Self {
460 inference_id,
461 type_assignment: self
462 .type_assignment
463 .iter()
464 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
465 .collect(),
466 const_assignment: self
467 .const_assignment
468 .iter()
469 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
470 .collect(),
471 impl_assignment: self
472 .impl_assignment
473 .iter()
474 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
475 .collect(),
476 negative_impl_assignment: self
477 .negative_impl_assignment
478 .iter()
479 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
480 .collect(),
481 impl_vars_trait_item_mappings: self
482 .impl_vars_trait_item_mappings
483 .iter()
484 .map(|(k, mappings)| {
485 (
486 *k,
487 ImplVarTraitItemMappings {
488 types: mappings
489 .types
490 .iter()
491 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
492 .collect(),
493 constants: mappings
494 .constants
495 .iter()
496 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
497 .collect(),
498 impls: mappings
499 .impls
500 .iter()
501 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
502 .collect(),
503 },
504 )
505 })
506 .collect(),
507 type_vars: inference_id_replacer.rewrite(self.type_vars.clone()).no_err(),
508 const_vars: inference_id_replacer.rewrite(self.const_vars.clone()).no_err(),
509 impl_vars: inference_id_replacer.rewrite(self.impl_vars.clone()).no_err(),
510 negative_impl_vars: inference_id_replacer
511 .rewrite(self.negative_impl_vars.clone())
512 .no_err(),
513 stable_ptrs: self.stable_ptrs.clone(),
514 pending: inference_id_replacer.rewrite(self.pending.clone()).no_err(),
515 negative_pending: inference_id_replacer.rewrite(self.negative_pending.clone()).no_err(),
516 refuted: inference_id_replacer.rewrite(self.refuted.clone()).no_err(),
517 negative_refuted: inference_id_replacer.rewrite(self.negative_refuted.clone()).no_err(),
518 solved: inference_id_replacer.rewrite(self.solved.clone()).no_err(),
519 ambiguous: inference_id_replacer.rewrite(self.ambiguous.clone()).no_err(),
520 negative_ambiguous: inference_id_replacer
521 .rewrite(self.negative_ambiguous.clone())
522 .no_err(),
523 impl_type_bounds: self.impl_type_bounds.clone(),
525
526 error_status: self.error_status.clone(),
527 }
528 }
529 pub fn temporary_clone(&self) -> InferenceData<'db> {
530 Self {
531 inference_id: self.inference_id,
532 type_assignment: self.type_assignment.clone(),
533 const_assignment: self.const_assignment.clone(),
534 impl_assignment: self.impl_assignment.clone(),
535 negative_impl_assignment: self.negative_impl_assignment.clone(),
536 impl_vars_trait_item_mappings: self.impl_vars_trait_item_mappings.clone(),
537 type_vars: self.type_vars.clone(),
538 const_vars: self.const_vars.clone(),
539 impl_vars: self.impl_vars.clone(),
540 negative_impl_vars: self.negative_impl_vars.clone(),
541 stable_ptrs: self.stable_ptrs.clone(),
542 pending: self.pending.clone(),
543 negative_pending: self.negative_pending.clone(),
544 refuted: self.refuted.clone(),
545 negative_refuted: self.negative_refuted.clone(),
546 solved: self.solved.clone(),
547 ambiguous: self.ambiguous.clone(),
548 negative_ambiguous: self.negative_ambiguous.clone(),
549 impl_type_bounds: self.impl_type_bounds.clone(),
550 error_status: self.error_status.clone(),
551 }
552 }
553}
554
555pub struct Inference<'db, 'id> {
557 db: &'db dyn Database,
558 pub data: &'id mut InferenceData<'db>,
559}
560
561impl<'db, 'id> Deref for Inference<'db, 'id> {
562 type Target = InferenceData<'db>;
563
564 fn deref(&self) -> &Self::Target {
565 self.data
566 }
567}
568impl DerefMut for Inference<'_, '_> {
569 fn deref_mut(&mut self) -> &mut Self::Target {
570 self.data
571 }
572}
573
574impl std::fmt::Debug for Inference<'_, '_> {
575 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
576 let x = self.data.debug(self.db);
577 write!(f, "{x:?}")
578 }
579}
580
581impl<'db, 'id> Inference<'db, 'id> {
582 fn new(db: &'db dyn Database, data: &'id mut InferenceData<'db>) -> Self {
583 Self { db, data }
584 }
585
586 fn impl_var(&self, var_id: LocalImplVarId) -> &ImplVar<'db> {
588 &self.impl_vars[var_id.0]
589 }
590
591 pub fn impl_assignment(&self, var_id: LocalImplVarId) -> Option<ImplId<'db>> {
593 self.impl_assignment.get(&var_id).copied()
594 }
595
596 fn negative_impl_var(&self, var_id: LocalNegativeImplVarId) -> &NegativeImplVar<'db> {
598 &self.negative_impl_vars[var_id.0]
599 }
600
601 pub fn negative_impl_assignment(
603 &self,
604 var_id: LocalNegativeImplVarId,
605 ) -> Option<NegativeImplId<'db>> {
606 self.negative_impl_assignment.get(&var_id).copied()
607 }
608
609 fn type_assignment(&self, var_id: LocalTypeVarId) -> Option<TypeId<'db>> {
611 self.type_assignment.get(&var_id).copied()
612 }
613
614 pub fn new_type_var(&mut self, stable_ptr: Option<SyntaxStablePtrId<'db>>) -> TypeId<'db> {
617 let var = self.new_type_var_raw(stable_ptr);
618
619 TypeLongId::Var(var).intern(self.db)
620 }
621
622 pub fn new_type_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId<'db>>) -> TypeVar<'db> {
625 let var =
626 TypeVar { inference_id: self.inference_id, id: LocalTypeVarId(self.type_vars.len()) };
627 if let Some(stable_ptr) = stable_ptr {
628 self.stable_ptrs.insert(InferenceVar::Type(var.id), stable_ptr);
629 }
630 self.type_vars.push(var);
631 var
632 }
633
634 pub fn set_impl_type_bounds(
637 &mut self,
638 impl_type_bounds: OrderedHashMap<ImplTypeId<'db>, TypeId<'db>>,
639 ) {
640 let impl_type_bounds_finalized = impl_type_bounds
641 .iter()
642 .filter_map(|(impl_type, ty)| {
643 let rewritten_type = self.rewrite(ty.long(self.db).clone()).no_err();
644 if !matches!(rewritten_type, TypeLongId::Var(_)) {
645 return Some(((*impl_type).into(), rewritten_type.intern(self.db)));
646 }
647 self.conform_ty(*ty, TypeLongId::ImplType(*impl_type).intern(self.db)).ok();
650 None
651 })
652 .collect();
653
654 self.data.impl_type_bounds = Arc::new(impl_type_bounds_finalized);
655 }
656
657 pub fn new_const_var(
660 &mut self,
661 stable_ptr: Option<SyntaxStablePtrId<'db>>,
662 ty: TypeId<'db>,
663 ) -> ConstValueId<'db> {
664 let var = self.new_const_var_raw(stable_ptr);
665 ConstValue::Var(var, ty).intern(self.db)
666 }
667
668 pub fn new_const_var_raw(
671 &mut self,
672 stable_ptr: Option<SyntaxStablePtrId<'db>>,
673 ) -> ConstVar<'db> {
674 let var = ConstVar {
675 inference_id: self.inference_id,
676 id: LocalConstVarId(self.const_vars.len()),
677 };
678 if let Some(stable_ptr) = stable_ptr {
679 self.stable_ptrs.insert(InferenceVar::Const(var.id), stable_ptr);
680 }
681 self.const_vars.push(var);
682 var
683 }
684
685 pub fn new_impl_var(
688 &mut self,
689 concrete_trait_id: ConcreteTraitId<'db>,
690 stable_ptr: Option<SyntaxStablePtrId<'db>>,
691 lookup_context: ImplLookupContextId<'db>,
692 ) -> ImplId<'db> {
693 let var = self.new_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
694 ImplLongId::ImplVar(self.impl_var(var).intern(self.db)).intern(self.db)
695 }
696
697 fn new_impl_var_raw(
700 &mut self,
701 lookup_context: ImplLookupContextId<'db>,
702 concrete_trait_id: ConcreteTraitId<'db>,
703 stable_ptr: Option<SyntaxStablePtrId<'db>>,
704 ) -> LocalImplVarId {
705 let id = LocalImplVarId(self.impl_vars.len());
706 if let Some(stable_ptr) = stable_ptr {
707 self.stable_ptrs.insert(InferenceVar::Impl(id), stable_ptr);
708 }
709 let var =
710 ImplVar { inference_id: self.inference_id, id, concrete_trait_id, lookup_context };
711 self.impl_vars.push(var);
712 self.pending.push_back(id);
713 id
714 }
715
716 pub fn new_negative_impl_var(
719 &mut self,
720 concrete_trait_id: ConcreteTraitId<'db>,
721 stable_ptr: Option<SyntaxStablePtrId<'db>>,
722 lookup_context: ImplLookupContextId<'db>,
723 ) -> NegativeImplId<'db> {
724 let var = self.new_negative_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
725 NegativeImplLongId::NegativeImplVar(self.negative_impl_var(var).clone().intern(self.db))
726 .intern(self.db)
727 }
728
729 fn new_negative_impl_var_raw(
732 &mut self,
733 lookup_context: ImplLookupContextId<'db>,
734 concrete_trait_id: ConcreteTraitId<'db>,
735 stable_ptr: Option<SyntaxStablePtrId<'db>>,
736 ) -> LocalNegativeImplVarId {
737 let id = LocalNegativeImplVarId(self.negative_impl_vars.len());
738 if let Some(stable_ptr) = stable_ptr {
739 self.stable_ptrs.insert(InferenceVar::NegativeImpl(id), stable_ptr);
740 }
741 let var = NegativeImplVar {
742 inference_id: self.inference_id,
743 id,
744 concrete_trait_id,
745 lookup_context,
746 };
747 self.negative_impl_vars.push(var);
748 self.negative_pending.push_back(id);
749 id
750 }
751
752 pub fn solve(&mut self) -> InferenceResult<()> {
757 let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
758 self.pending.extend(ambiguous.map(|(var, _)| var));
759 while let Some(var) = self.pending.pop_front() {
760 self.solve_single_pending(var).inspect_err(|_err_set| {
762 self.add_error_stable_ptr(InferenceVar::Impl(var));
763 })?;
764 }
765 while let Some(var) = self.negative_pending.pop_front() {
766 self.solve_single_negative_pending(var).inspect_err(|_err_set| {
768 self.add_error_stable_ptr(InferenceVar::NegativeImpl(var));
769 })?;
770 }
771 Ok(())
772 }
773
774 fn solve_single_pending(&mut self, var: LocalImplVarId) -> InferenceResult<()> {
775 if self.impl_assignment.contains_key(&var) {
776 return Ok(());
777 }
778 let solution = match self.impl_var_solution_set(var)? {
779 SolutionSet::None => {
780 self.refuted.push(var);
781 return Ok(());
782 }
783 SolutionSet::Ambiguous(ambiguity) => {
784 self.ambiguous.push((var, ambiguity));
785 return Ok(());
786 }
787 SolutionSet::Unique(solution) => solution,
788 };
789
790 self.assign_local_impl(var, solution)?;
792
793 self.solved.push(var);
795 let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
796 self.pending.extend(ambiguous.map(|(var, _)| var));
797
798 let negative_ambiguous = std::mem::take(&mut self.negative_ambiguous).into_iter();
799 self.negative_pending.extend(negative_ambiguous.map(|(var, _)| var));
800
801 Ok(())
802 }
803
804 fn solve_single_negative_pending(
806 &mut self,
807 var: LocalNegativeImplVarId,
808 ) -> InferenceResult<()> {
809 if self.negative_impl_assignment.contains_key(&var) {
810 return Ok(());
811 }
812
813 let solution = match self.negative_impl_var_solution_set(var)? {
814 SolutionSet::None => {
815 self.negative_refuted.push(var);
816 return Ok(());
817 }
818 SolutionSet::Ambiguous(ambiguity) => {
819 self.negative_ambiguous.push((var, ambiguity));
820 return Ok(());
821 }
822 SolutionSet::Unique(solution) => solution,
823 };
824
825 self.assign_local_negative_impl(var, solution)?;
827
828 Ok(())
829 }
830
831 pub fn solution_set(&mut self) -> InferenceResult<SolutionSet<'db, ()>> {
834 self.solve()?;
835 if !self.refuted.is_empty() {
836 return Ok(SolutionSet::None);
837 }
838 if !self.negative_refuted.is_empty() {
839 return Ok(SolutionSet::None);
840 }
841 if let Some((_, ambiguity)) = self.ambiguous.first() {
842 return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
843 }
844 if let Some((_, ambiguity)) = self.negative_ambiguous.first() {
845 return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
846 }
847 assert!(self.pending.is_empty(), "solution() called on an unsolved solver");
848 assert!(self.negative_pending.is_empty(), "solution() called on an unsolved solver");
849 Ok(SolutionSet::Unique(()))
850 }
851
852 pub fn finalize_without_reporting(&mut self) -> Result<(), ErrorSet> {
855 if self.error_status.is_err() {
856 return Err(ErrorSet);
857 }
858 let info = self.db.core_info();
859 let numeric_trait_id = info.numeric_literal_trt;
860 let felt_ty = info.felt252;
861
862 loop {
864 let mut changed = false;
865 self.solve()?;
866 for (var, _) in self.ambiguous.clone() {
867 let impl_var = self.impl_var(var).clone();
868 if impl_var.concrete_trait_id.trait_id(self.db) != numeric_trait_id {
869 continue;
870 }
871 let ty = extract_matches!(
873 impl_var.concrete_trait_id.generic_args(self.db)[0],
874 GenericArgumentId::Type
875 );
876 if self.rewrite(ty).no_err() == felt_ty {
877 continue;
878 }
879 self.conform_ty(ty, felt_ty).inspect_err(|_err_set| {
880 self.add_error_stable_ptr(InferenceVar::Impl(impl_var.id));
881 })?;
882 changed = true;
883 break;
884 }
885 if !changed {
886 break;
887 }
888 }
889 assert!(
890 self.pending.is_empty(),
891 "pending should all be solved by this point. Guaranteed by solve()."
892 );
893
894 let Some((var, err)) = self.first_undetermined_variable() else {
895 return Ok(());
896 };
897 Err(self.set_error_on_var(err, var))
898 }
899
900 pub fn finalize<'m>(
904 &'m mut self,
905 diagnostics: &mut SemanticDiagnostics<'db>,
906 stable_ptr: SyntaxStablePtrId<'db>,
907 ) {
908 if let Err(err_set) = self.finalize_without_reporting() {
909 let diag = self.report_on_pending_error(err_set, diagnostics, stable_ptr);
910
911 let ty_missing = TypeId::missing(self.db, diag);
912 for var in &self.data.type_vars {
913 self.data.type_assignment.entry(var.id).or_insert(ty_missing);
914 }
915 }
916 }
917
918 fn first_undetermined_variable(&mut self) -> Option<(InferenceVar, InferenceError<'db>)> {
922 if let Some(var) = self.refuted.first().copied() {
923 let impl_var = self.impl_var(var).clone();
924 let concrete_trait_id = impl_var.concrete_trait_id;
925 let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
926 return Some((
927 InferenceVar::Impl(var),
928 InferenceError::NoImplsFound(concrete_trait_id),
929 ));
930 }
931 if let Some(var) = self.negative_refuted.first().copied() {
932 let negative_impl_var = self.negative_impl_var(var).clone();
933 let concrete_trait_id = negative_impl_var.concrete_trait_id;
934 let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
935 return Some((
936 InferenceVar::NegativeImpl(var),
937 InferenceError::NoNegativeImplsFound(concrete_trait_id),
938 ));
939 }
940
941 let mut fallback_ret = None;
942 if let Some((var, ambiguity)) = self.ambiguous.first() {
943 let ret =
945 Some((InferenceVar::Impl(*var), InferenceError::Ambiguity(ambiguity.clone())));
946 if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
947 return ret;
948 } else {
949 fallback_ret = ret;
950 }
951 }
952 if let Some((var, ambiguity)) = self.negative_ambiguous.first() {
953 let ret = Some((
954 InferenceVar::NegativeImpl(*var),
955 InferenceError::Ambiguity(ambiguity.clone()),
956 ));
957 if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
958 return ret;
959 } else {
960 fallback_ret = ret;
961 }
962 }
963 for (id, var) in self.type_vars.iter().enumerate() {
964 if self.type_assignment(LocalTypeVarId(id)).is_none() {
965 let ty = TypeLongId::Var(*var).intern(self.db);
966 return Some((InferenceVar::Type(var.id), InferenceError::TypeNotInferred(ty)));
967 }
968 }
969 for (id, var) in self.const_vars.iter().enumerate() {
970 if !self.const_assignment.contains_key(&LocalConstVarId(id)) {
971 let infernence_var = InferenceVar::Const(var.id);
972 return Some((infernence_var, InferenceError::ConstNotInferred));
973 }
974 }
975 fallback_ret
976 }
977
978 fn assign_local_impl(
980 &mut self,
981 var: LocalImplVarId,
982 impl_id: ImplId<'db>,
983 ) -> InferenceResult<ImplId<'db>> {
984 let concrete_trait = impl_id
985 .concrete_trait(self.db)
986 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
987 self.conform_traits(self.impl_var(var).concrete_trait_id, concrete_trait)?;
988 if let Some(other_impl) = self.impl_assignment(var) {
989 return self.conform_impl(impl_id, other_impl);
990 }
991 if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
992 {
993 let inference_var = InferenceVar::Impl(var);
994 return Err(self.set_error_on_var(InferenceError::Cycle(inference_var), inference_var));
995 }
996 self.impl_assignment.insert(var, impl_id);
997 if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
998 for (trait_type_id, ty) in mappings.types {
999 let impl_ty = self
1000 .db
1001 .impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_type_id, self.db))
1002 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1003 if let Err(err_set) = self.conform_ty(ty, impl_ty) {
1004 let ty0 = self.rewrite(ty).no_err();
1006 let ty1 = self.rewrite(impl_ty).no_err();
1007
1008 let err = InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 };
1009 self.error_status = Err(InferenceErrorStatus::Pending(PendingInferenceError {
1010 err,
1011 stable_ptr: self.stable_ptrs.get(&InferenceVar::Impl(var)).cloned(),
1012 }));
1013 return Err(err_set);
1014 }
1015 }
1016 for (trait_constant, constant_id) in mappings.constants {
1017 let concrete_impl_constant = self
1018 .db
1019 .impl_constant_concrete_implized_value(ImplConstantId::new(
1020 impl_id,
1021 trait_constant,
1022 self.db,
1023 ))
1024 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1025 self.conform_const(constant_id, concrete_impl_constant)?;
1026 }
1027 for (trait_impl, inner_impl_id) in mappings.impls {
1028 let concrete_impl_impl = self
1029 .db
1030 .impl_impl_concrete_implized(ImplImplId::new(impl_id, trait_impl, self.db))
1031 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1032 self.conform_impl(inner_impl_id, concrete_impl_impl)?;
1033 }
1034 }
1035 Ok(impl_id)
1036 }
1037
1038 fn assign_impl(
1040 &mut self,
1041 var_id: ImplVarId<'db>,
1042 impl_id: ImplId<'db>,
1043 ) -> InferenceResult<ImplId<'db>> {
1044 let var = var_id.long(self.db);
1045 if var.inference_id != self.inference_id {
1046 return Err(self.set_error(InferenceError::ImplKindMismatch {
1047 impl0: ImplLongId::ImplVar(var_id).intern(self.db),
1048 impl1: impl_id,
1049 }));
1050 }
1051 self.assign_local_impl(var.id, impl_id)
1052 }
1053
1054 fn assign_local_negative_impl(
1056 &mut self,
1057 var: LocalNegativeImplVarId,
1058 neg_impl_id: NegativeImplId<'db>,
1059 ) -> InferenceResult<NegativeImplId<'db>> {
1060 let concrete_trait = neg_impl_id
1061 .concrete_trait(self.db)
1062 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1063 self.conform_traits(self.negative_impl_var(var).concrete_trait_id, concrete_trait)?;
1064 if let Some(other_impl) = self.negative_impl_assignment(var) {
1065 return self.conform_neg_impl(neg_impl_id, other_impl);
1066 }
1067 if !neg_impl_id.is_var_free(self.db)
1068 && self.negative_impl_contains_var(neg_impl_id, InferenceVar::NegativeImpl(var))
1069 {
1070 return Err(self.set_error(InferenceError::Cycle(InferenceVar::NegativeImpl(var))));
1071 }
1072 self.negative_impl_assignment.insert(var, neg_impl_id);
1073 Ok(neg_impl_id)
1074 }
1075
1076 fn assign_neg_impl(
1079 &mut self,
1080 var_id: NegativeImplVarId<'db>,
1081 neg_impl_id: NegativeImplId<'db>,
1082 ) -> InferenceResult<NegativeImplId<'db>> {
1083 let var = var_id.long(self.db);
1084 if var.inference_id != self.inference_id {
1085 return Err(self.set_error(InferenceError::NegativeImplKindMismatch {
1086 impl0: NegativeImplLongId::NegativeImplVar(var_id).intern(self.db),
1087 impl1: neg_impl_id,
1088 }));
1089 }
1090 self.assign_local_negative_impl(var.id, neg_impl_id)
1091 }
1092
1093 fn assign_ty(&mut self, var: TypeVar<'db>, ty: TypeId<'db>) -> InferenceResult<TypeId<'db>> {
1096 if var.inference_id != self.inference_id {
1097 return Err(self.set_error(InferenceError::TypeKindMismatch {
1098 ty0: TypeLongId::Var(var).intern(self.db),
1099 ty1: ty,
1100 }));
1101 }
1102 assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
1103 let inference_var = InferenceVar::Type(var.id);
1104 if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
1105 return Err(self.set_error_on_var(InferenceError::Cycle(inference_var), inference_var));
1106 }
1107 if let TypeLongId::Var(other) = ty.long(self.db)
1109 && other.inference_id == self.inference_id
1110 && other.id.0 > var.id.0
1111 {
1112 let var_ty = TypeLongId::Var(var).intern(self.db);
1113 self.type_assignment.insert(other.id, var_ty);
1114 return Ok(var_ty);
1115 }
1116 self.type_assignment.insert(var.id, ty);
1117 Ok(ty)
1118 }
1119
1120 fn assign_const(
1123 &mut self,
1124 var: ConstVar<'db>,
1125 id: ConstValueId<'db>,
1126 ) -> InferenceResult<ConstValueId<'db>> {
1127 if var.inference_id != self.inference_id {
1128 return Err(self.set_error(InferenceError::ConstKindMismatch {
1129 const0: ConstValue::Var(var, TypeId::missing(self.db, skip_diagnostic()))
1130 .intern(self.db),
1131 const1: id,
1132 }));
1133 }
1134
1135 self.const_assignment.insert(var.id, id);
1136 Ok(id)
1137 }
1138
1139 fn impl_var_solution_set(
1141 &mut self,
1142 var: LocalImplVarId,
1143 ) -> InferenceResult<SolutionSet<'db, ImplId<'db>>> {
1144 let impl_var = self.impl_var(var).clone();
1145 let concrete_trait_id = self.rewrite(impl_var.concrete_trait_id).no_err();
1147 self.impl_vars[impl_var.id.0].concrete_trait_id = concrete_trait_id;
1148 let impl_var_trait_item_mappings =
1149 self.impl_vars_trait_item_mappings.get(&var).cloned().unwrap_or_default();
1150 let solution_set = self.trait_solution_set(
1151 concrete_trait_id,
1152 impl_var_trait_item_mappings,
1153 impl_var.lookup_context,
1154 )?;
1155 Ok(match solution_set {
1156 SolutionSet::None => SolutionSet::None,
1157 SolutionSet::Unique((canonical_impl, canonicalizer)) => {
1158 SolutionSet::Unique(canonical_impl.embed(self, &canonicalizer))
1159 }
1160 SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
1161 })
1162 }
1163
1164 fn negative_impl_var_solution_set(
1166 &mut self,
1167 var: LocalNegativeImplVarId,
1168 ) -> InferenceResult<SolutionSet<'db, NegativeImplId<'db>>> {
1169 let negative_impl_var = self.negative_impl_var(var).clone();
1170 let concrete_trait_id = self.rewrite(negative_impl_var.concrete_trait_id).no_err();
1171
1172 let solution_set =
1173 self.validate_no_solution_set(concrete_trait_id, negative_impl_var.lookup_context)?;
1174 Ok(match solution_set {
1175 SolutionSet::Unique(concrete_trait_id) => {
1176 SolutionSet::Unique(NegativeImplLongId::Solved(concrete_trait_id).intern(self.db))
1177 }
1178 SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
1179 SolutionSet::None => SolutionSet::None,
1180 })
1181 }
1182
1183 fn validate_no_solution_set(
1185 &mut self,
1186 concrete_trait_id: ConcreteTraitId<'db>,
1187 lookup_context: ImplLookupContextId<'db>,
1188 ) -> InferenceResult<SolutionSet<'db, ConcreteTraitId<'db>>> {
1189 for negative_impl in &lookup_context.long(self.db).negative_impls {
1190 let generic_param = self
1191 .db
1192 .generic_param_semantic(*negative_impl)
1193 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1194 if let GenericParam::NegImpl(neg_impl) = generic_param
1195 && Ok(concrete_trait_id) == neg_impl.concrete_trait
1196 {
1197 return Ok(SolutionSet::Unique(concrete_trait_id));
1198 }
1199 }
1200
1201 let generic_args = concrete_trait_id.generic_args(self.db);
1202 if generic_args.iter().any(|garg| {
1206 matches!(
1207 garg,
1208 GenericArgumentId::Type(ty)
1209 if !matches!(ty.long(self.db), TypeLongId::Closure(_))
1210 && !ty.is_var_free(self.db)
1211 )
1212 }) {
1213 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1214 }
1215
1216 let mut neg_impl_generic_params = OrderedHashSet::default();
1217 let mut visited_types = OrderedHashSet::default();
1218 for garg in generic_args {
1219 if garg
1220 .extract_generic_params(self.db, &mut neg_impl_generic_params, &mut visited_types)
1221 .is_err()
1222 {
1223 return Ok(SolutionSet::Ambiguous(
1224 Ambiguity::NegativeImplWithUnsupportedExtractedArgs(*garg),
1225 ));
1226 }
1227 }
1228
1229 let solution_set = if neg_impl_generic_params.is_empty() {
1230 self.trait_solution_set(
1231 concrete_trait_id,
1232 ImplVarTraitItemMappings::default(),
1233 lookup_context,
1234 )?
1235 } else {
1236 let mut substitution: OrderedHashMap<GenericParamId<'db>, GenericArgumentId<'db>> =
1237 Default::default();
1238
1239 for param in neg_impl_generic_params {
1240 let garg = match param.kind(self.db) {
1241 GenericKind::Type => GenericArgumentId::Type(
1242 self.new_type_var(Some(param.stable_ptr(self.db).untyped())),
1243 ),
1244 GenericKind::Const | GenericKind::Impl | GenericKind::NegImpl => {
1245 return Ok(SolutionSet::Ambiguous(
1246 Ambiguity::NegativeImplWithUnsupportedGenericParam(param),
1247 ));
1248 }
1249 };
1250
1251 substitution.insert(param, garg);
1252 }
1253 let rewritten_concrete_trait_id =
1254 GenericSubstitution { param_to_arg: substitution.clone(), self_impl: None }
1255 .substitute(self.db, concrete_trait_id)
1256 .unwrap();
1257
1258 let solution_set = self.trait_solution_set(
1259 rewritten_concrete_trait_id,
1260 ImplVarTraitItemMappings::default(),
1261 lookup_context,
1262 )?;
1263
1264 let db = self.db;
1267 for (generic_param, garg) in substitution {
1268 let GenericArgumentId::Type(ty) = garg else {
1269 panic!("Expected a type variable");
1270 };
1271 let TypeLongId::Var(var) = ty.long(self.db) else {
1272 panic!("Expected a type variable");
1273 };
1274 self.type_assignment
1275 .entry(var.id)
1276 .or_insert_with(|| TypeLongId::GenericParameter(generic_param).intern(db));
1277 }
1278
1279 solution_set
1280 };
1281
1282 if !matches!(solution_set, SolutionSet::None) {
1283 return Ok(SolutionSet::None);
1284 }
1285
1286 Ok(SolutionSet::Unique(concrete_trait_id))
1287 }
1288
1289 pub fn trait_solution_set(
1291 &mut self,
1292 concrete_trait_id: ConcreteTraitId<'db>,
1293 impl_var_trait_item_mappings: ImplVarTraitItemMappings<'db>,
1294 lookup_context_id: ImplLookupContextId<'db>,
1295 ) -> InferenceResult<SolutionSet<'db, (CanonicalImpl<'db>, CanonicalMapping<'db>)>> {
1296 let impl_var_trait_item_mappings = self.rewrite(impl_var_trait_item_mappings).no_err();
1297 let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
1299 let mut lookup_context = lookup_context_id.long(self.db).clone();
1300 enrich_lookup_context(self.db, concrete_trait_id, &mut lookup_context);
1301
1302 let generic_args = concrete_trait_id.generic_args(self.db);
1304 match generic_args.first() {
1305 Some(GenericArgumentId::Type(ty)) => {
1306 if let TypeLongId::Var(_) = ty.long(self.db) {
1307 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1309 }
1310 }
1311 Some(GenericArgumentId::Impl(imp)) => {
1312 if let ImplLongId::ImplVar(_) = imp.long(self.db) {
1314 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1315 }
1316 }
1317 Some(GenericArgumentId::Constant(const_value)) => {
1318 if let ConstValue::Var(_, _) = const_value.long(self.db) {
1319 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1321 }
1322 }
1323 _ => {}
1324 };
1325 let (canonical_trait, canonicalizer) = CanonicalTrait::canonicalize(
1326 self.db,
1327 self.inference_id,
1328 concrete_trait_id,
1329 impl_var_trait_item_mappings,
1330 );
1331 let solution_set = match self.db.canonic_trait_solutions(
1334 canonical_trait,
1335 lookup_context.intern(self.db),
1336 (*self.data.impl_type_bounds).clone(),
1337 ) {
1338 Ok(solution_set) => solution_set,
1339 Err(err) => return Err(self.set_error(err)),
1340 };
1341 match solution_set {
1342 SolutionSet::None => Ok(SolutionSet::None),
1343 SolutionSet::Unique(canonical_impl) => {
1344 Ok(SolutionSet::Unique((canonical_impl, canonicalizer)))
1345 }
1346 SolutionSet::Ambiguous(ambiguity) => Ok(SolutionSet::Ambiguous(ambiguity)),
1347 }
1348 }
1349
1350 pub fn set_error(&mut self, err: InferenceError<'db>) -> ErrorSet {
1357 self.set_error_ex(err, None)
1358 }
1359
1360 pub fn set_error_ex(
1364 &mut self,
1365 err: InferenceError<'db>,
1366 stable_ptr: Option<SyntaxStablePtrId<'db>>,
1367 ) -> ErrorSet {
1368 if self.error_status.is_err() {
1369 return ErrorSet;
1370 }
1371 self.error_status = Err(if let InferenceError::Reported(diag_added) = err {
1372 InferenceErrorStatus::Consumed(diag_added)
1373 } else {
1374 InferenceErrorStatus::Pending(PendingInferenceError { err, stable_ptr })
1375 });
1376 ErrorSet
1377 }
1378
1379 pub fn set_error_on_var(&mut self, err: InferenceError<'db>, var: InferenceVar) -> ErrorSet {
1383 self.set_error_ex(err, self.stable_ptrs.get(&var).cloned())
1384 }
1385
1386 pub fn is_error_set(&self) -> InferenceResult<()> {
1388 self.error_status.as_ref().copied().map_err(|_| ErrorSet)
1389 }
1390
1391 fn add_error_stable_ptr(&mut self, var: InferenceVar) {
1393 let var_stable_ptr = self.stable_ptrs.get(&var).copied();
1394 if let Err(InferenceErrorStatus::Pending(PendingInferenceError { err: _, stable_ptr })) =
1395 &mut self.error_status
1396 && stable_ptr.is_none()
1397 {
1398 *stable_ptr = var_stable_ptr;
1399 }
1400 }
1401
1402 pub fn consume_error_without_reporting(
1408 &mut self,
1409 err_set: ErrorSet,
1410 ) -> Option<InferenceError<'db>> {
1411 Some(self.consume_error_inner(err_set, skip_diagnostic())?.err)
1412 }
1413
1414 pub fn consume_reported_error(&mut self, err_set: ErrorSet, diag_added: DiagnosticAdded) {
1421 self.consume_error_inner(err_set, diag_added);
1422 }
1423
1424 fn consume_error_inner(
1431 &mut self,
1432 _err_set: ErrorSet,
1433 diag_added: DiagnosticAdded,
1434 ) -> Option<PendingInferenceError<'db>> {
1435 match &mut self.error_status {
1436 Err(InferenceErrorStatus::Pending(error)) => {
1437 let pending_error = std::mem::replace(
1438 error,
1439 PendingInferenceError {
1440 err: InferenceError::Reported(diag_added),
1441 stable_ptr: None,
1442 },
1443 );
1444 self.error_status = Err(InferenceErrorStatus::Consumed(diag_added));
1445 Some(pending_error)
1446 }
1447 _ => None,
1449 }
1450 }
1451
1452 pub fn report_on_pending_error(
1458 &mut self,
1459 _err_set: ErrorSet,
1460 diagnostics: &mut SemanticDiagnostics<'db>,
1461 stable_ptr: SyntaxStablePtrId<'db>,
1462 ) -> DiagnosticAdded {
1463 let Err(state_error) = &self.error_status else {
1464 panic!("report_on_pending_error should be called only on error");
1465 };
1466 match state_error {
1467 InferenceErrorStatus::Consumed(diag_added) => *diag_added,
1468 InferenceErrorStatus::Pending(pending) => {
1469 let diag_added = match &pending.err {
1470 InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
1471 skip_diagnostic()
1476 }
1477 diag => diag.report(diagnostics, pending.stable_ptr.unwrap_or(stable_ptr)),
1478 };
1479 self.error_status = Err(InferenceErrorStatus::Consumed(diag_added));
1480 diag_added
1481 }
1482 }
1483 }
1484
1485 pub fn report_modified_if_pending(
1488 &mut self,
1489 err_set: ErrorSet,
1490 report: impl FnOnce() -> DiagnosticAdded,
1491 ) {
1492 if matches!(self.error_status, Err(InferenceErrorStatus::Pending { .. })) {
1493 self.consume_reported_error(err_set, report());
1494 }
1495 }
1496}
1497
// Exposes the inference's database through `HasDb`.
impl<'a, 'mt> HasDb<&'a dyn Database> for Inference<'a, 'mt> {
    /// Returns the database stored in `self.db`.
    fn get_db(&self) -> &'a dyn Database {
        self.db
    }
}
// Registers the generic semantic rewrites for `Inference`. Each type listed after `@exclude`
// has a hand-written `SemanticRewriter` impl in this file that resolves inference variables.
add_basic_rewrites!(<'a, 'mt>, Inference<'a, 'mt>, NoError, @exclude TypeLongId TypeId ImplLongId ImplId ConstValue NegativeImplLongId NegativeImplId);
add_expr_rewrites!(<'a, 'mt>, Inference<'a, 'mt>, NoError, @exclude);
add_rewrite!(<'a, 'mt>, Inference<'a, 'mt>, NoError, Ambiguity<'a>);
1506impl<'db, 'mt> SemanticRewriter<TypeId<'db>, NoError> for Inference<'db, 'mt> {
1507 fn internal_rewrite(&mut self, value: &mut TypeId<'db>) -> Result<RewriteResult, NoError> {
1508 if value.is_var_free(self.db) {
1509 return Ok(RewriteResult::NoChange);
1510 }
1511 value.default_rewrite(self)
1512 }
1513}
1514impl<'db, 'mt> SemanticRewriter<ImplId<'db>, NoError> for Inference<'db, 'mt> {
1515 fn internal_rewrite(&mut self, value: &mut ImplId<'db>) -> Result<RewriteResult, NoError> {
1516 if value.is_var_free(self.db) {
1517 return Ok(RewriteResult::NoChange);
1518 }
1519 value.default_rewrite(self)
1520 }
1521}
1522impl<'db, 'mt> SemanticRewriter<NegativeImplId<'db>, NoError> for Inference<'db, 'mt> {
1523 fn internal_rewrite(
1524 &mut self,
1525 value: &mut NegativeImplId<'db>,
1526 ) -> Result<RewriteResult, NoError> {
1527 if value.is_var_free(self.db) {
1528 return Ok(RewriteResult::NoChange);
1529 }
1530 value.default_rewrite(self)
1531 }
1532}
1533
impl<'db, 'mt> SemanticRewriter<TypeLongId<'db>, NoError> for Inference<'db, 'mt> {
    /// Resolves a long type against the current inference state.
    ///
    /// - `Var` with an assignment: replaced by its recursively rewritten assignment. The
    ///   rewritten form is cached back into `type_assignment`.
    /// - `ImplType`: handled in two stages.
    ///   - An entry in `impl_type_bounds` wins; that bound type is used and rewritten further.
    ///   - Otherwise the impl type id is rewritten, and the result depends on the kind of its
    ///     impl: the concrete impl's implized type, the type from `rewritten_impl_type` for an
    ///     impl variable, or the item type stored in a generated impl.
    /// - Anything else, including unassigned variables, falls back to `default_rewrite`.
    fn internal_rewrite(&mut self, value: &mut TypeLongId<'db>) -> Result<RewriteResult, NoError> {
        match value {
            TypeLongId::Var(var) => {
                if let Some(type_id) = self.type_assignment.get(&var.id) {
                    let mut long_type_id = type_id.long(self.db).clone();
                    // Store the fully rewritten assignment so later lookups skip the chain.
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_type_id)? {
                        *self.type_assignment.get_mut(&var.id).unwrap() =
                            long_type_id.clone().intern(self.db);
                    }
                    *value = long_type_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            TypeLongId::ImplType(impl_type_id) => {
                // A bound registered for this impl type takes precedence over everything else.
                if let Some(type_id) = self.impl_type_bounds.get(&((*impl_type_id).into())) {
                    *value = type_id.long(self.db).clone();
                    self.internal_rewrite(value)?;
                    return Ok(RewriteResult::Modified);
                }
                let impl_type_id_rewrite_result = self.internal_rewrite(impl_type_id)?;
                let impl_id = impl_type_id.impl_id();
                let trait_ty = impl_type_id.ty();
                return Ok(match impl_id.long(self.db) {
                    // For these impl kinds only the inner rewrite result is reported.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_type_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        // Use the concrete impl's type when the query succeeds.
                        if let Ok(ty) = self.db.impl_type_concrete_implized(ImplTypeId::new(
                            impl_id, trait_ty, self.db,
                        )) {
                            *value = self.rewrite(ty).no_err().long(self.db).clone();
                            RewriteResult::Modified
                        } else {
                            impl_type_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        // `rewritten_impl_type` (defined elsewhere) supplies the type for this
                        // impl variable's trait type.
                        *value = self.rewritten_impl_type(*var, trait_ty).long(self.db).clone();
                        return Ok(RewriteResult::Modified);
                    }
                    ImplLongId::GeneratedImpl(generated) => {
                        // Generated impls store their item types directly. The `unwrap`
                        // assumes the trait type is present in `impl_items`.
                        *value = self
                            .rewrite(
                                *generated
                                    .long(self.db)
                                    .impl_items
                                    .0
                                    .get(&impl_type_id.ty())
                                    .unwrap(),
                            )
                            .no_err()
                            .long(self.db)
                            .clone();
                        RewriteResult::Modified
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
impl<'db, 'mt> SemanticRewriter<ConstValue<'db>, NoError> for Inference<'db, 'mt> {
    /// Resolves a constant value against the current inference state.
    ///
    /// - `Var`: an assigned variable is replaced by its recursively rewritten assignment, which
    ///   is cached back into `const_assignment`. An unassigned one returns `NoChange` without
    ///   recursing into it.
    /// - `ImplConstant`: the impl constant id is rewritten. Then a concrete impl yields its
    ///   implized constant, and an impl variable yields the value from
    ///   `rewritten_impl_constant`.
    /// - Anything else falls back to `default_rewrite`.
    fn internal_rewrite(&mut self, value: &mut ConstValue<'db>) -> Result<RewriteResult, NoError> {
        match value {
            ConstValue::Var(var, _) => {
                return Ok(if let Some(const_value_id) = self.const_assignment.get(&var.id) {
                    let mut const_value = const_value_id.long(self.db).clone();
                    // Store the fully rewritten assignment so later lookups skip the chain.
                    if let RewriteResult::Modified = self.internal_rewrite(&mut const_value)? {
                        *self.const_assignment.get_mut(&var.id).unwrap() =
                            const_value.clone().intern(self.db);
                    }
                    *value = const_value;
                    RewriteResult::Modified
                } else {
                    RewriteResult::NoChange
                });
            }
            ConstValue::ImplConstant(impl_constant_id) => {
                let impl_constant_id_rewrite_result = self.internal_rewrite(impl_constant_id)?;
                let impl_id = impl_constant_id.impl_id();
                let trait_constant = impl_constant_id.trait_constant_id();
                return Ok(match impl_id.long(self.db) {
                    // For these impl kinds only the inner rewrite result is reported.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::GeneratedImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_constant_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        // Use the concrete impl's constant value when the query succeeds.
                        if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
                            ImplConstantId::new(impl_id, trait_constant, self.db),
                        ) {
                            *value = self.rewrite(constant).no_err().long(self.db).clone();
                            RewriteResult::Modified
                        } else {
                            impl_constant_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        // `rewritten_impl_constant` (defined elsewhere) supplies the value for
                        // this impl variable's trait constant.
                        *value = self
                            .rewritten_impl_constant(*var, trait_constant)
                            .long(self.db)
                            .clone();
                        return Ok(RewriteResult::Modified);
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
impl<'db, 'mt> SemanticRewriter<ImplLongId<'db>, NoError> for Inference<'db, 'mt> {
    /// Resolves a long impl id against the current inference state.
    ///
    /// - `ImplVar` with an assignment: replaced by its recursively rewritten assignment, which
    ///   is cached back into the impl assignment map.
    /// - `ImplImpl`: the inner impl-impl id is rewritten. Then, depending on the parent impl,
    ///   the value becomes the concrete impl's implized impl, or the impl produced by
    ///   `rewritten_impl_impl` for an impl variable.
    /// - Otherwise, including unassigned impl variables: var-free values report `NoChange`;
    ///   the rest go through `default_rewrite`.
    fn internal_rewrite(&mut self, value: &mut ImplLongId<'db>) -> Result<RewriteResult, NoError> {
        match value {
            ImplLongId::ImplVar(var) => {
                let long_id = var.long(self.db);
                let impl_var_id = long_id.id;
                if let Some(impl_id) = self.impl_assignment(impl_var_id) {
                    let mut long_impl_id = impl_id.long(self.db).clone();
                    // Store the fully rewritten assignment so later lookups skip the chain.
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_impl_id)? {
                        *self.impl_assignment.get_mut(&impl_var_id).unwrap() =
                            long_impl_id.clone().intern(self.db);
                    }
                    *value = long_impl_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            ImplLongId::ImplImpl(impl_impl_id) => {
                let impl_impl_id_rewrite_result = self.internal_rewrite(impl_impl_id)?;
                let impl_id = impl_impl_id.impl_id();
                return Ok(match impl_id.long(self.db) {
                    // For these parent impl kinds only the inner rewrite result is reported.
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::GeneratedImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_impl_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        // Use the concrete impl's implized impl when the query succeeds.
                        if let Ok(imp) = self.db.impl_impl_concrete_implized(*impl_impl_id) {
                            *value = self.rewrite(imp).no_err().long(self.db).clone();
                            RewriteResult::Modified
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        // Needs the concrete trait impl. If that query fails, only the inner
                        // rewrite result is reported.
                        if let Ok(concrete_trait_impl) =
                            impl_impl_id.concrete_trait_impl_id(self.db)
                        {
                            *value = self
                                .rewritten_impl_impl(*var, concrete_trait_impl)
                                .long(self.db)
                                .clone();
                            return Ok(RewriteResult::Modified);
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                });
            }

            _ => {}
        }
        if value.is_var_free(self.db) {
            return Ok(RewriteResult::NoChange);
        }
        value.default_rewrite(self)
    }
}
1703
1704impl<'db, 'mt> SemanticRewriter<NegativeImplLongId<'db>, NoError> for Inference<'db, 'mt> {
1705 fn internal_rewrite(
1706 &mut self,
1707 value: &mut NegativeImplLongId<'db>,
1708 ) -> Result<RewriteResult, NoError> {
1709 if let NegativeImplLongId::NegativeImplVar(var) = value {
1710 let long_id = var.long(self.db);
1711 let neg_impl_var_id = long_id.id;
1713 if let Some(impl_id) = self.negative_impl_assignment(neg_impl_var_id) {
1714 let mut long_neg_impl_id = impl_id.long(self.db).clone();
1715 if let RewriteResult::Modified = self.internal_rewrite(&mut long_neg_impl_id)? {
1716 *self.negative_impl_assignment.get_mut(&neg_impl_var_id).unwrap() =
1717 long_neg_impl_id.clone().intern(self.db);
1718 }
1719 *value = long_neg_impl_id;
1720 return Ok(RewriteResult::Modified);
1721 }
1722 }
1723
1724 if value.is_var_free(self.db) {
1725 return Ok(RewriteResult::NoChange);
1726 }
1727 value.default_rewrite(self)
1728 }
1729}
1730
1731struct InferenceIdReplacer<'a> {
1732 db: &'a dyn Database,
1733 from_inference_id: InferenceId<'a>,
1734 to_inference_id: InferenceId<'a>,
1735}
1736impl<'a> InferenceIdReplacer<'a> {
1737 fn new(
1738 db: &'a dyn Database,
1739 from_inference_id: InferenceId<'a>,
1740 to_inference_id: InferenceId<'a>,
1741 ) -> Self {
1742 Self { db, from_inference_id, to_inference_id }
1743 }
1744}
// Exposes the replacer's database through `HasDb`.
impl<'a> HasDb<&'a dyn Database> for InferenceIdReplacer<'a> {
    /// Returns the database stored in `self.db`.
    fn get_db(&self) -> &'a dyn Database {
        self.db
    }
}
// Registers the generic semantic rewrites for `InferenceIdReplacer`. `InferenceId` is excluded
// because it has the hand-written `SemanticRewriter` impl that performs the replacement.
add_basic_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude InferenceId);
add_expr_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude);
add_rewrite!(<'a>, InferenceIdReplacer<'a>, NoError, Ambiguity<'a>);
1753impl<'a> SemanticRewriter<InferenceId<'a>, NoError> for InferenceIdReplacer<'a> {
1754 fn internal_rewrite(&mut self, value: &mut InferenceId<'a>) -> Result<RewriteResult, NoError> {
1755 if value == &self.from_inference_id {
1756 *value = self.to_inference_id;
1757 Ok(RewriteResult::Modified)
1758 } else {
1759 Ok(RewriteResult::NoChange)
1760 }
1761 }
1762}