1#![expect(clippy::disallowed_types)]
4
5use std::collections::{BTreeMap, HashMap};
6use std::hash::Hash;
7use std::ops::{Deref, DerefMut};
8use std::sync::Arc;
9
10use cairo_lang_debug::DebugWithDb;
11use cairo_lang_defs::ids::{
12 ConstantId, EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericKind,
13 GenericParamId, GlobalUseId, ImplAliasId, ImplDefId, ImplFunctionId, ImplImplDefId, LocalVarId,
14 LookupItemId, MacroCallId, MemberId, NamedLanguageElementId, ParamId, StructId,
15 TraitConstantId, TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VarId, VariantId,
16};
17use cairo_lang_diagnostics::{DiagnosticAdded, skip_diagnostic};
18use cairo_lang_proc_macros::{DebugWithDb, HeapSize, SemanticObject};
19use cairo_lang_syntax::node::TypedStablePtr;
20use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
21use cairo_lang_utils::deque::Deque;
22use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
23use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
24use cairo_lang_utils::{Intern, define_short_id, extract_matches};
25use salsa::Database;
26
27use self::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, NoError};
28use self::solver::{Ambiguity, SolutionSet, enrich_lookup_context};
29use crate::corelib::{CorelibSemantic, is_valid_literal_type};
30use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
31use crate::expr::inference::canonic::ResultNoErrEx;
32use crate::expr::inference::conform::InferenceConform;
33use crate::expr::inference::solver::SemanticSolver;
34use crate::expr::objects::*;
35use crate::expr::pattern::*;
36use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
37use crate::items::functions::{
38 ConcreteFunctionWithBody, ConcreteFunctionWithBodyId, GenericFunctionId,
39 GenericFunctionWithBodyId, ImplFunctionBodyId, ImplGenericFunctionId,
40 ImplGenericFunctionWithBodyId,
41};
42use crate::items::generics::{
43 GenericParamConst, GenericParamImpl, GenericParamSemantic, GenericParamType,
44};
45use crate::items::imp::{
46 GeneratedImplId, GeneratedImplItems, GeneratedImplLongId, ImplId, ImplImplId, ImplLongId,
47 ImplLookupContext, ImplLookupContextId, ImplSemantic, NegativeImplId, NegativeImplLongId,
48 UninferredGeneratedImplId, UninferredGeneratedImplLongId, UninferredImpl,
49};
50use crate::items::trt::{
51 ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId, ConcreteTraitTypeId,
52 ConcreteTraitTypeLongId,
53};
54use crate::substitution::{GenericSubstitution, HasDb, RewriteResult, SemanticRewriter};
55use crate::types::{
56 ClosureTypeLongId, ConcreteEnumLongId, ConcreteExternTypeLongId, ConcreteStructLongId,
57 ImplTypeById, ImplTypeId, get_impl_at_context,
58};
59use crate::{
60 ConcreteEnumId, ConcreteExternTypeId, ConcreteFunction, ConcreteImplId, ConcreteImplLongId,
61 ConcreteStructId, ConcreteTraitId, ConcreteTraitLongId, ConcreteTypeId, ConcreteVariant,
62 FunctionId, FunctionLongId, GenericArgumentId, GenericParam, LocalVariable, MatchArmSelector,
63 Member, Parameter, SemanticObject, Signature, TypeId, TypeLongId, ValueSelectorArm,
64 add_basic_rewrites, add_expr_rewrites, add_rewrite, semantic_object_for_id,
65};
66
67pub mod canonic;
68pub mod conform;
69pub mod infers;
70pub mod solver;
71
/// An inference variable standing for a type that is not known yet.
///
/// A variable is identified by the inference that created it together with its index in that
/// inference's `type_vars` list.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, HeapSize, salsa::Update)]
pub struct TypeVar<'db> {
    /// The inference that owns this variable.
    pub inference_id: InferenceId<'db>,
    /// Index of this variable within its owning inference.
    pub id: LocalTypeVarId,
}
79
/// An inference variable standing for a constant value that is not known yet.
///
/// A variable is identified by the inference that created it together with its index in that
/// inference's `const_vars` list.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, HeapSize, salsa::Update)]
pub struct ConstVar<'db> {
    /// The inference that owns this variable.
    pub inference_id: InferenceId<'db>,
    /// Index of this variable within its owning inference.
    pub id: LocalConstVarId,
}
87
/// Identifies the inference a variable belongs to.
///
/// Every type, const, impl and negative-impl variable records the id of the inference that
/// created it. For example, `Inference::assign_impl` rejects an impl variable whose
/// `inference_id` differs from the current inference's.
#[derive(
    Copy, Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject, HeapSize, salsa::Update,
)]
#[debug_db(dyn Database)]
pub enum InferenceId<'db> {
    LookupItemDeclaration(LookupItemId<'db>),
    LookupItemGenerics(LookupItemId<'db>),
    LookupItemDefinition(LookupItemId<'db>),
    ImplDefTrait(ImplDefId<'db>),
    ImplAliasImplDef(ImplAliasId<'db>),
    GenericParam(GenericParamId<'db>),
    GenericImplParamTrait(GenericParamId<'db>),
    GlobalUseStar(GlobalUseId<'db>),
    MacroCall(MacroCallId<'db>),
    Canonical,
    NoContext,
}
107
/// An inference variable standing for an impl of `concrete_trait_id` that is not known yet.
#[derive(
    Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject, HeapSize, salsa::Update,
)]
#[debug_db(dyn Database)]
pub struct ImplVar<'db> {
    /// The inference that owns this variable.
    pub inference_id: InferenceId<'db>,
    // Excluded from semantic rewriting.
    #[dont_rewrite]
    /// Index of this variable within its owning inference's `impl_vars`.
    pub id: LocalImplVarId,
    /// The concrete trait an impl must be found for.
    pub concrete_trait_id: ConcreteTraitId<'db>,
    // Excluded from semantic rewriting.
    #[dont_rewrite]
    /// The lookup context recorded when the variable was created.
    pub lookup_context: ImplLookupContextId<'db>,
}
122impl<'db> ImplVar<'db> {
123 pub fn intern(&self, db: &'db dyn Database) -> ImplVarId<'db> {
124 self.clone().intern(db)
125 }
126}
127
/// An inference variable standing for a negative impl of `concrete_trait_id` that is not known
/// yet.
#[derive(
    Clone, Debug, PartialEq, Eq, Hash, DebugWithDb, SemanticObject, HeapSize, salsa::Update,
)]
#[debug_db(dyn Database)]
pub struct NegativeImplVar<'db> {
    /// The inference that owns this variable.
    pub inference_id: InferenceId<'db>,
    // Excluded from semantic rewriting.
    #[dont_rewrite]
    /// Index of this variable within its owning inference's `negative_impl_vars`.
    pub id: LocalNegativeImplVarId,
    /// The concrete trait the negative impl refers to.
    pub concrete_trait_id: ConcreteTraitId<'db>,
    // Excluded from semantic rewriting.
    #[dont_rewrite]
    /// The lookup context recorded when the variable was created.
    pub lookup_context: ImplLookupContextId<'db>,
}
141
/// Index of a type variable within its inference's `type_vars`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
pub struct LocalTypeVarId(pub usize);
/// Index of an impl variable within its inference's `impl_vars`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
pub struct LocalImplVarId(pub usize);

/// Index of a negative impl variable within its inference's `negative_impl_vars`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
pub struct LocalNegativeImplVarId(pub usize);

/// Index of a const variable within its inference's `const_vars`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, SemanticObject, HeapSize, salsa::Update)]
pub struct LocalConstVarId(pub usize);
152
153define_short_id!(ImplVarId, ImplVar<'db>);
154impl<'db> ImplVarId<'db> {
155 pub fn id(&self, db: &dyn Database) -> LocalImplVarId {
156 self.long(db).id
157 }
158 pub fn concrete_trait_id(&self, db: &'db dyn Database) -> ConcreteTraitId<'db> {
159 self.long(db).concrete_trait_id
160 }
161 pub fn lookup_context(&self, db: &'db dyn Database) -> ImplLookupContextId<'db> {
162 self.long(db).lookup_context
163 }
164}
165semantic_object_for_id!(ImplVarId, ImplVar<'a>);
166
167define_short_id!(NegativeImplVarId, NegativeImplVar<'db>);
168impl<'db> NegativeImplVarId<'db> {
169 pub fn id(&self, db: &dyn Database) -> LocalNegativeImplVarId {
170 self.long(db).id
171 }
172 pub fn concrete_trait_id(&self, db: &'db dyn Database) -> ConcreteTraitId<'db> {
173 self.long(db).concrete_trait_id
174 }
175 pub fn lookup_context(&self, db: &'db dyn Database) -> ImplLookupContextId<'db> {
176 self.long(db).lookup_context
177 }
178}
179semantic_object_for_id!(NegativeImplVarId, NegativeImplVar<'a>);
180
/// A local reference to any kind of inference variable.
///
/// It is used as a key for per-variable data such as `InferenceData::stable_ptrs`, and to
/// attribute errors to variables.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, SemanticObject, salsa::Update)]
pub enum InferenceVar {
    Type(LocalTypeVarId),
    Const(LocalConstVarId),
    Impl(LocalImplVarId),
    NegativeImpl(LocalNegativeImplVarId),
}
188
/// An error that occurred during inference.
#[derive(Clone, Debug, Eq, Hash, PartialEq, DebugWithDb, salsa::Update)]
#[debug_db(dyn Database)]
pub enum InferenceError<'db> {
    /// The error was already reported as a diagnostic; `report` returns it unchanged.
    Reported(DiagnosticAdded),
    /// Assigning the variable would make it contain itself.
    Cycle(InferenceVar),
    /// Two types could not be unified.
    TypeKindMismatch {
        ty0: TypeId<'db>,
        ty1: TypeId<'db>,
    },
    /// Two constant values could not be unified.
    ConstKindMismatch {
        const0: ConstValueId<'db>,
        const1: ConstValueId<'db>,
    },
    /// Two impls could not be unified.
    ImplKindMismatch {
        impl0: ImplId<'db>,
        impl1: ImplId<'db>,
    },
    /// Two negative impls could not be unified.
    NegativeImplKindMismatch {
        impl0: NegativeImplId<'db>,
        impl1: NegativeImplId<'db>,
    },
    /// Two generic arguments could not be unified.
    GenericArgMismatch {
        garg0: GenericArgumentId<'db>,
        garg1: GenericArgumentId<'db>,
    },
    /// Two different traits were expected to be the same.
    TraitMismatch {
        trt0: TraitId<'db>,
        trt1: TraitId<'db>,
    },
    /// A type recorded for an impl's associated type disagrees with the impl's actual type.
    ImplTypeMismatch {
        impl_id: ImplId<'db>,
        trait_type_id: TraitTypeId<'db>,
        ty0: TypeId<'db>,
        ty1: TypeId<'db>,
    },
    /// Two generic functions could not be unified.
    GenericFunctionMismatch {
        func0: GenericFunctionId<'db>,
        func1: GenericFunctionId<'db>,
    },
    /// A const variable was left without an assignment.
    ConstNotInferred,
    /// No impl of the concrete trait was found in context.
    NoImplsFound(ConcreteTraitId<'db>),
    /// A negative impl could not be satisfied because the trait has an implementation.
    NoNegativeImplsFound(ConcreteTraitId<'db>),
    /// The solver found more than one candidate, or will not infer one.
    Ambiguity(Ambiguity<'db>),
    /// A type variable was left without an assignment.
    TypeNotInferred(TypeId<'db>),
    /// A numeric literal was bound to a type that cannot be created from a numeric literal.
    InvalidLiteralType(TypeId<'db>),
}
240impl<'db> InferenceError<'db> {
241 pub fn format(&self, db: &dyn Database) -> String {
242 match self {
243 InferenceError::Reported(_) => "Inference error occurred.".into(),
244 InferenceError::Cycle(_var) => "Inference cycle detected".into(),
245 InferenceError::TypeKindMismatch { ty0, ty1 } => {
246 format!("Type mismatch: `{:?}` and `{:?}`.", ty0.debug(db), ty1.debug(db))
247 }
248 InferenceError::ConstKindMismatch { const0, const1 } => {
249 format!("Const mismatch: `{:?}` and `{:?}`.", const0.debug(db), const1.debug(db))
250 }
251 InferenceError::ImplKindMismatch { impl0, impl1 } => {
252 format!("Impl mismatch: `{:?}` and `{:?}`.", impl0.debug(db), impl1.debug(db))
253 }
254 InferenceError::NegativeImplKindMismatch { impl0, impl1 } => {
255 format!(
256 "Negative impl mismatch: `{:?}` and `{:?}`.",
257 impl0.debug(db),
258 impl1.debug(db)
259 )
260 }
261 InferenceError::GenericArgMismatch { garg0, garg1 } => {
262 format!(
263 "Generic arg mismatch: `{:?}` and `{:?}`.",
264 garg0.debug(db),
265 garg1.debug(db)
266 )
267 }
268 InferenceError::TraitMismatch { trt0, trt1 } => {
269 format!("Trait mismatch: `{:?}` and `{:?}`.", trt0.debug(db), trt1.debug(db))
270 }
271 InferenceError::ConstNotInferred => "Failed to infer constant.".into(),
272 InferenceError::NoImplsFound(concrete_trait_id) => {
273 if concrete_trait_id.trait_id(db) == db.core_info().string_literal_trt {
274 let generic_type = extract_matches!(
275 concrete_trait_id.generic_args(db)[0],
276 GenericArgumentId::Type
277 );
278 return format!(
279 "Mismatched types. The type `{:?}` cannot be created from a string \
280 literal.",
281 generic_type.debug(db)
282 );
283 }
284 format!(
285 "Trait has no implementation in context: {:?}.",
286 concrete_trait_id.debug(db)
287 )
288 }
289 InferenceError::NoNegativeImplsFound(concrete_trait_id) => {
290 format!("Trait has implementation in context: {:?}.", concrete_trait_id.debug(db))
291 }
292 InferenceError::Ambiguity(ambiguity) => ambiguity.format(db),
293 InferenceError::TypeNotInferred(ty) => {
294 format!("Type annotations needed. Failed to infer {:?}.", ty.debug(db))
295 }
296 InferenceError::GenericFunctionMismatch { func0, func1 } => {
297 format!("Function mismatch: `{}` and `{}`.", func0.format(db), func1.format(db))
298 }
299 InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 } => {
300 format!(
301 "`{}::{}` type mismatch: `{:?}` and `{:?}`.",
302 impl_id.format(db),
303 trait_type_id.name(db).long(db),
304 ty0.debug(db),
305 ty1.debug(db)
306 )
307 }
308 InferenceError::InvalidLiteralType(ty) => format!(
309 "Mismatched types. The type `{:?}` cannot be created from a numeric literal.",
310 ty.debug(db),
311 ),
312 }
313 }
314}
315
316impl<'db> InferenceError<'db> {
317 pub fn report(
318 &self,
319 diagnostics: &mut SemanticDiagnostics<'db>,
320 stable_ptr: SyntaxStablePtrId<'db>,
321 ) -> DiagnosticAdded {
322 match self {
323 InferenceError::Reported(diagnostic_added) => *diagnostic_added,
324 _ => diagnostics
325 .report(stable_ptr, SemanticDiagnosticKind::InternalInferenceError(self.clone())),
326 }
327 }
328}
329
/// Marker returned when an inference error has been recorded.
///
/// It carries no details. The error itself is kept in the inference's error status; for
/// example, `finalize_without_reporting` returns `Err(ErrorSet)` whenever that status is
/// already an error.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct ErrorSet;

/// Result type of fallible inference operations.
pub type InferenceResult<T> = Result<T, ErrorSet>;
338
/// The error stored in an inference's `error_status`.
#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
enum InferenceErrorStatus<'db> {
    /// An error recorded during inference, together with its location if known.
    /// NOTE(review): presumably this means "not reported as a diagnostic yet" — confirm.
    Pending(PendingInferenceError<'db>),
    /// The error has been turned into the given diagnostic.
    Consumed(DiagnosticAdded),
}
346
/// An inference error paired with the location it should be attributed to, if known.
#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
struct PendingInferenceError<'db> {
    /// The error itself.
    err: InferenceError<'db>,
    /// Location to attach the error to; `None` when no location is known.
    stable_ptr: Option<SyntaxStablePtrId<'db>>,
}
355
/// Values recorded for trait items of an impl variable.
///
/// When the variable is assigned (see `Inference::assign_local_impl`), these mappings are removed
/// and each entry is conformed against the chosen impl's actual item.
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, SemanticObject, salsa::Update)]
pub struct ImplVarTraitItemMappings<'db> {
    /// Associated types and the types recorded for them.
    types: OrderedHashMap<TraitTypeId<'db>, TypeId<'db>>,
    /// Associated constants and the values recorded for them.
    constants: OrderedHashMap<TraitConstantId<'db>, ConstValueId<'db>>,
    /// Associated impls and the impls recorded for them.
    impls: OrderedHashMap<TraitImplId<'db>, ImplId<'db>>,
}
366
367impl ImplVarTraitItemMappings<'_> {
368 pub fn is_empty(&self) -> bool {
370 self.types.is_empty() && self.constants.is_empty() && self.impls.is_empty()
371 }
372}
373
/// The complete state of one inference: its variables, their assignments, the solver queues and
/// the recorded error status.
#[derive(Debug, DebugWithDb, PartialEq, Eq, salsa::Update)]
#[debug_db(dyn Database)]
pub struct InferenceData<'db> {
    /// The id stamped on every variable created by this inference.
    pub inference_id: InferenceId<'db>,
    /// Assigned types, keyed by type variable.
    pub type_assignment: OrderedHashMap<LocalTypeVarId, TypeId<'db>>,
    /// Assigned constant values, keyed by const variable.
    pub const_assignment: OrderedHashMap<LocalConstVarId, ConstValueId<'db>>,
    /// Assigned impls, keyed by impl variable.
    pub impl_assignment: OrderedHashMap<LocalImplVarId, ImplId<'db>>,
    /// Assigned negative impls, keyed by negative impl variable.
    pub negative_impl_assignment: OrderedHashMap<LocalNegativeImplVarId, NegativeImplId<'db>>,
    /// Trait item values recorded for impl variables. An entry is consumed and conformed
    /// against the impl when the variable is assigned.
    pub impl_vars_trait_item_mappings: HashMap<LocalImplVarId, ImplVarTraitItemMappings<'db>>,
    /// All type variables, indexed by `LocalTypeVarId`.
    pub type_vars: Vec<TypeVar<'db>>,
    /// Type variables created for numeric literals. During finalization the unassigned ones are
    /// conformed to `felt252`, and the bound ones are checked to be valid literal types.
    pub integer_literal_vars: Vec<LocalTypeVarId>,
    /// All const variables, indexed by `LocalConstVarId`.
    pub const_vars: Vec<ConstVar<'db>>,
    /// All impl variables, indexed by `LocalImplVarId`.
    pub impl_vars: Vec<ImplVar<'db>>,
    /// All negative impl variables, indexed by `LocalNegativeImplVarId`.
    pub negative_impl_vars: Vec<NegativeImplVar<'db>>,
    /// Source locations of variables, used to attribute errors to them.
    pub stable_ptrs: HashMap<InferenceVar, SyntaxStablePtrId<'db>>,
    /// Impl variables waiting to be solved.
    pending: Deque<LocalImplVarId>,
    /// Negative impl variables waiting to be solved.
    negative_pending: Deque<LocalNegativeImplVarId>,
    /// Impl variables for which no solution exists.
    refuted: Vec<LocalImplVarId>,
    /// Negative impl variables for which no solution exists.
    negative_refuted: Vec<LocalNegativeImplVarId>,
    /// Impl variables that were assigned a unique solution.
    solved: Vec<LocalImplVarId>,
    /// Impl variables whose solution is ambiguous, with the ambiguity. These are re-queued by
    /// `solve` and after every successful impl assignment.
    ambiguous: Vec<(LocalImplVarId, Ambiguity<'db>)>,
    /// Negative impl variables whose solution is ambiguous, with the ambiguity.
    negative_ambiguous: Vec<(LocalNegativeImplVarId, Ambiguity<'db>)>,
    /// Concrete types known for impl types, as set by `set_impl_type_bounds`.
    pub impl_type_bounds: Arc<BTreeMap<ImplTypeById<'db>, TypeId<'db>>>,

    /// `Ok(())` until an error is recorded; afterwards holds that error's status.
    error_status: Result<(), InferenceErrorStatus<'db>>,
}
424impl<'db> InferenceData<'db> {
425 pub fn new(inference_id: InferenceId<'db>) -> Self {
426 Self {
427 inference_id,
428 type_assignment: OrderedHashMap::default(),
429 impl_assignment: OrderedHashMap::default(),
430 const_assignment: OrderedHashMap::default(),
431 negative_impl_assignment: OrderedHashMap::default(),
432 impl_vars_trait_item_mappings: HashMap::new(),
433 type_vars: Vec::new(),
434 integer_literal_vars: Vec::new(),
435 impl_vars: Vec::new(),
436 const_vars: Vec::new(),
437 negative_impl_vars: Vec::new(),
438 stable_ptrs: HashMap::new(),
439 pending: Deque::new(),
440 negative_pending: Deque::new(),
441 refuted: Vec::new(),
442 negative_refuted: Vec::new(),
443 solved: Vec::new(),
444 ambiguous: Vec::new(),
445 negative_ambiguous: Vec::new(),
446 impl_type_bounds: Default::default(),
447 error_status: Ok(()),
448 }
449 }
450 pub fn inference<'r>(&'r mut self, db: &'db dyn Database) -> Inference<'db, 'r> {
451 Inference::new(db, self)
452 }
453 pub fn clone_with_inference_id(
454 &self,
455 db: &'db dyn Database,
456 inference_id: InferenceId<'db>,
457 ) -> InferenceData<'db> {
458 let mut inference_id_replacer =
459 InferenceIdReplacer::new(db, self.inference_id, inference_id);
460 Self {
461 inference_id,
462 type_assignment: self
463 .type_assignment
464 .iter()
465 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
466 .collect(),
467 const_assignment: self
468 .const_assignment
469 .iter()
470 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
471 .collect(),
472 impl_assignment: self
473 .impl_assignment
474 .iter()
475 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
476 .collect(),
477 negative_impl_assignment: self
478 .negative_impl_assignment
479 .iter()
480 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
481 .collect(),
482 impl_vars_trait_item_mappings: self
483 .impl_vars_trait_item_mappings
484 .iter()
485 .map(|(k, mappings)| {
486 (
487 *k,
488 ImplVarTraitItemMappings {
489 types: mappings
490 .types
491 .iter()
492 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
493 .collect(),
494 constants: mappings
495 .constants
496 .iter()
497 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
498 .collect(),
499 impls: mappings
500 .impls
501 .iter()
502 .map(|(k, v)| (*k, inference_id_replacer.rewrite(*v).no_err()))
503 .collect(),
504 },
505 )
506 })
507 .collect(),
508 type_vars: inference_id_replacer.rewrite(self.type_vars.clone()).no_err(),
509 integer_literal_vars: self.integer_literal_vars.clone(),
510 const_vars: inference_id_replacer.rewrite(self.const_vars.clone()).no_err(),
511 impl_vars: inference_id_replacer.rewrite(self.impl_vars.clone()).no_err(),
512 negative_impl_vars: inference_id_replacer
513 .rewrite(self.negative_impl_vars.clone())
514 .no_err(),
515 stable_ptrs: self.stable_ptrs.clone(),
516 pending: inference_id_replacer.rewrite(self.pending.clone()).no_err(),
517 negative_pending: inference_id_replacer.rewrite(self.negative_pending.clone()).no_err(),
518 refuted: inference_id_replacer.rewrite(self.refuted.clone()).no_err(),
519 negative_refuted: inference_id_replacer.rewrite(self.negative_refuted.clone()).no_err(),
520 solved: inference_id_replacer.rewrite(self.solved.clone()).no_err(),
521 ambiguous: inference_id_replacer.rewrite(self.ambiguous.clone()).no_err(),
522 negative_ambiguous: inference_id_replacer
523 .rewrite(self.negative_ambiguous.clone())
524 .no_err(),
525 impl_type_bounds: self.impl_type_bounds.clone(),
527
528 error_status: self.error_status.clone(),
529 }
530 }
531 pub fn temporary_clone(&self) -> InferenceData<'db> {
532 Self {
533 inference_id: self.inference_id,
534 type_assignment: self.type_assignment.clone(),
535 const_assignment: self.const_assignment.clone(),
536 impl_assignment: self.impl_assignment.clone(),
537 negative_impl_assignment: self.negative_impl_assignment.clone(),
538 impl_vars_trait_item_mappings: self.impl_vars_trait_item_mappings.clone(),
539 type_vars: self.type_vars.clone(),
540 integer_literal_vars: self.integer_literal_vars.clone(),
541 const_vars: self.const_vars.clone(),
542 impl_vars: self.impl_vars.clone(),
543 negative_impl_vars: self.negative_impl_vars.clone(),
544 stable_ptrs: self.stable_ptrs.clone(),
545 pending: self.pending.clone(),
546 negative_pending: self.negative_pending.clone(),
547 refuted: self.refuted.clone(),
548 negative_refuted: self.negative_refuted.clone(),
549 solved: self.solved.clone(),
550 ambiguous: self.ambiguous.clone(),
551 negative_ambiguous: self.negative_ambiguous.clone(),
552 impl_type_bounds: self.impl_type_bounds.clone(),
553 error_status: self.error_status.clone(),
554 }
555 }
556}
557
/// A handle for running inference: the database plus a mutable borrow of the inference state.
pub struct Inference<'db, 'id> {
    /// Database used to intern and look up semantic objects.
    db: &'db dyn Database,
    /// The state being inferred.
    pub data: &'id mut InferenceData<'db>,
}
563
564impl<'db, 'id> Deref for Inference<'db, 'id> {
565 type Target = InferenceData<'db>;
566
567 fn deref(&self) -> &Self::Target {
568 self.data
569 }
570}
571impl DerefMut for Inference<'_, '_> {
572 fn deref_mut(&mut self) -> &mut Self::Target {
573 self.data
574 }
575}
576
577impl std::fmt::Debug for Inference<'_, '_> {
578 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
579 let x = self.data.debug(self.db);
580 write!(f, "{x:?}")
581 }
582}
583
584impl<'db, 'id> Inference<'db, 'id> {
585 fn new(db: &'db dyn Database, data: &'id mut InferenceData<'db>) -> Self {
586 Self { db, data }
587 }
588
589 fn impl_var(&self, var_id: LocalImplVarId) -> &ImplVar<'db> {
591 &self.impl_vars[var_id.0]
592 }
593
594 pub fn impl_assignment(&self, var_id: LocalImplVarId) -> Option<ImplId<'db>> {
596 self.impl_assignment.get(&var_id).copied()
597 }
598
599 fn negative_impl_var(&self, var_id: LocalNegativeImplVarId) -> &NegativeImplVar<'db> {
601 &self.negative_impl_vars[var_id.0]
602 }
603
604 pub fn negative_impl_assignment(
606 &self,
607 var_id: LocalNegativeImplVarId,
608 ) -> Option<NegativeImplId<'db>> {
609 self.negative_impl_assignment.get(&var_id).copied()
610 }
611
612 fn type_assignment(&self, var_id: LocalTypeVarId) -> Option<TypeId<'db>> {
614 self.type_assignment.get(&var_id).copied()
615 }
616
617 pub fn new_type_var(&mut self, stable_ptr: Option<SyntaxStablePtrId<'db>>) -> TypeId<'db> {
620 let var = self.new_type_var_raw(stable_ptr);
621
622 TypeLongId::Var(var).intern(self.db)
623 }
624
625 pub fn new_integer_literal_type(
630 &mut self,
631 stable_ptr: Option<SyntaxStablePtrId<'db>>,
632 ) -> TypeId<'db> {
633 let var = self.new_type_var_raw(stable_ptr);
634 self.data.integer_literal_vars.push(var.id);
635 TypeLongId::NumericLiteral(var).intern(self.db)
636 }
637
638 pub fn new_type_var_raw(&mut self, stable_ptr: Option<SyntaxStablePtrId<'db>>) -> TypeVar<'db> {
641 let var =
642 TypeVar { inference_id: self.inference_id, id: LocalTypeVarId(self.type_vars.len()) };
643 if let Some(stable_ptr) = stable_ptr {
644 self.stable_ptrs.insert(InferenceVar::Type(var.id), stable_ptr);
645 }
646 self.type_vars.push(var);
647 var
648 }
649
650 pub fn set_impl_type_bounds(
653 &mut self,
654 impl_type_bounds: OrderedHashMap<ImplTypeId<'db>, TypeId<'db>>,
655 ) {
656 let impl_type_bounds_finalized = impl_type_bounds
657 .iter()
658 .filter_map(|(impl_type, ty)| {
659 let rewritten_type = self.rewrite(ty.long(self.db).clone()).no_err();
660 if !matches!(rewritten_type, TypeLongId::Var(_) | TypeLongId::NumericLiteral(_)) {
661 return Some(((*impl_type).into(), rewritten_type.intern(self.db)));
662 }
663 self.conform_ty(*ty, TypeLongId::ImplType(*impl_type).intern(self.db)).ok();
666 None
667 })
668 .collect();
669
670 self.data.impl_type_bounds = Arc::new(impl_type_bounds_finalized);
671 }
672
673 pub fn new_const_var(
676 &mut self,
677 stable_ptr: Option<SyntaxStablePtrId<'db>>,
678 ty: TypeId<'db>,
679 ) -> ConstValueId<'db> {
680 let var = self.new_const_var_raw(stable_ptr);
681 ConstValue::Var(var, ty).intern(self.db)
682 }
683
684 pub fn new_const_var_raw(
687 &mut self,
688 stable_ptr: Option<SyntaxStablePtrId<'db>>,
689 ) -> ConstVar<'db> {
690 let var = ConstVar {
691 inference_id: self.inference_id,
692 id: LocalConstVarId(self.const_vars.len()),
693 };
694 if let Some(stable_ptr) = stable_ptr {
695 self.stable_ptrs.insert(InferenceVar::Const(var.id), stable_ptr);
696 }
697 self.const_vars.push(var);
698 var
699 }
700
701 pub fn new_impl_var(
704 &mut self,
705 concrete_trait_id: ConcreteTraitId<'db>,
706 stable_ptr: Option<SyntaxStablePtrId<'db>>,
707 lookup_context: ImplLookupContextId<'db>,
708 ) -> ImplId<'db> {
709 let var = self.new_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
710 ImplLongId::ImplVar(self.impl_var(var).intern(self.db)).intern(self.db)
711 }
712
713 fn new_impl_var_raw(
716 &mut self,
717 lookup_context: ImplLookupContextId<'db>,
718 concrete_trait_id: ConcreteTraitId<'db>,
719 stable_ptr: Option<SyntaxStablePtrId<'db>>,
720 ) -> LocalImplVarId {
721 let id = LocalImplVarId(self.impl_vars.len());
722 if let Some(stable_ptr) = stable_ptr {
723 self.stable_ptrs.insert(InferenceVar::Impl(id), stable_ptr);
724 }
725 let var =
726 ImplVar { inference_id: self.inference_id, id, concrete_trait_id, lookup_context };
727 self.impl_vars.push(var);
728 self.pending.push_back(id);
729 id
730 }
731
732 pub fn new_negative_impl_var(
735 &mut self,
736 concrete_trait_id: ConcreteTraitId<'db>,
737 stable_ptr: Option<SyntaxStablePtrId<'db>>,
738 lookup_context: ImplLookupContextId<'db>,
739 ) -> NegativeImplId<'db> {
740 let var = self.new_negative_impl_var_raw(lookup_context, concrete_trait_id, stable_ptr);
741 NegativeImplLongId::NegativeImplVar(self.negative_impl_var(var).clone().intern(self.db))
742 .intern(self.db)
743 }
744
745 fn new_negative_impl_var_raw(
748 &mut self,
749 lookup_context: ImplLookupContextId<'db>,
750 concrete_trait_id: ConcreteTraitId<'db>,
751 stable_ptr: Option<SyntaxStablePtrId<'db>>,
752 ) -> LocalNegativeImplVarId {
753 let id = LocalNegativeImplVarId(self.negative_impl_vars.len());
754 if let Some(stable_ptr) = stable_ptr {
755 self.stable_ptrs.insert(InferenceVar::NegativeImpl(id), stable_ptr);
756 }
757 let var = NegativeImplVar {
758 inference_id: self.inference_id,
759 id,
760 concrete_trait_id,
761 lookup_context,
762 };
763 self.negative_impl_vars.push(var);
764 self.negative_pending.push_back(id);
765 id
766 }
767
768 pub fn solve(&mut self) -> InferenceResult<()> {
773 let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
774 self.pending.extend(ambiguous.map(|(var, _)| var));
775 while let Some(var) = self.pending.pop_front() {
776 self.solve_single_pending(var).inspect_err(|_err_set| {
778 self.add_error_stable_ptr(InferenceVar::Impl(var));
779 })?;
780 }
781 while let Some(var) = self.negative_pending.pop_front() {
782 self.solve_single_negative_pending(var).inspect_err(|_err_set| {
784 self.add_error_stable_ptr(InferenceVar::NegativeImpl(var));
785 })?;
786 }
787 Ok(())
788 }
789
790 fn solve_single_pending(&mut self, var: LocalImplVarId) -> InferenceResult<()> {
791 if self.impl_assignment.contains_key(&var) {
792 return Ok(());
793 }
794 let solution = match self.impl_var_solution_set(var)? {
795 SolutionSet::None => {
796 self.refuted.push(var);
797 return Ok(());
798 }
799 SolutionSet::Ambiguous(ambiguity) => {
800 self.ambiguous.push((var, ambiguity));
801 return Ok(());
802 }
803 SolutionSet::Unique(solution) => solution,
804 };
805
806 self.assign_local_impl(var, solution)?;
808
809 self.solved.push(var);
811 let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
812 self.pending.extend(ambiguous.map(|(var, _)| var));
813
814 let negative_ambiguous = std::mem::take(&mut self.negative_ambiguous).into_iter();
815 self.negative_pending.extend(negative_ambiguous.map(|(var, _)| var));
816
817 Ok(())
818 }
819
820 fn solve_single_negative_pending(
822 &mut self,
823 var: LocalNegativeImplVarId,
824 ) -> InferenceResult<()> {
825 if self.negative_impl_assignment.contains_key(&var) {
826 return Ok(());
827 }
828
829 let solution = match self.negative_impl_var_solution_set(var)? {
830 SolutionSet::None => {
831 self.negative_refuted.push(var);
832 return Ok(());
833 }
834 SolutionSet::Ambiguous(ambiguity) => {
835 self.negative_ambiguous.push((var, ambiguity));
836 return Ok(());
837 }
838 SolutionSet::Unique(solution) => solution,
839 };
840
841 self.assign_local_negative_impl(var, solution)?;
843
844 Ok(())
845 }
846
847 pub fn solution_set(&mut self) -> InferenceResult<SolutionSet<'db, ()>> {
850 self.solve()?;
851 if !self.refuted.is_empty() {
852 return Ok(SolutionSet::None);
853 }
854 if !self.negative_refuted.is_empty() {
855 return Ok(SolutionSet::None);
856 }
857 if let Some((_, ambiguity)) = self.ambiguous.first() {
858 return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
859 }
860 if let Some((_, ambiguity)) = self.negative_ambiguous.first() {
861 return Ok(SolutionSet::Ambiguous(ambiguity.clone()));
862 }
863 assert!(self.pending.is_empty(), "solution() called on an unsolved solver");
864 assert!(self.negative_pending.is_empty(), "solution() called on an unsolved solver");
865 Ok(SolutionSet::Unique(()))
866 }
867
    /// Finishes the inference without emitting diagnostics.
    ///
    /// Steps:
    /// 1. Solves all queued variables, conforming unassigned numeric-literal variables to
    ///    `felt252` one at a time and re-solving after each, until nothing changes.
    /// 2. Checks that every numeric-literal variable bound to a variable-free type is bound to a
    ///    valid literal type.
    /// 3. Rewrites the stored impl assignments with the current knowledge.
    /// 4. Fails if any variable remains undetermined.
    ///
    /// # Errors
    /// Returns `Err(ErrorSet)` immediately if an error was already recorded, and otherwise when
    /// any of the steps above fails.
    pub fn finalize_without_reporting(&mut self) -> Result<(), ErrorSet> {
        if self.error_status.is_err() {
            return Err(ErrorSet);
        }
        let felt_ty = self.db.core_info().felt252;

        // Fixed-point loop: defaulting one literal may allow further solving, so after each
        // defaulted literal we run `solve` again.
        loop {
            self.solve()?;
            let mut changed = false;
            for var_id in self.data.integer_literal_vars.clone() {
                if self.type_assignment.contains_key(&var_id) {
                    continue;
                }
                let var = self.type_vars[var_id.0];
                let placeholder = TypeLongId::NumericLiteral(var).intern(self.db);
                // Nothing to do if the literal already resolves to felt252.
                if self.rewrite(placeholder).no_err() == felt_ty {
                    continue;
                }
                self.conform_ty(placeholder, felt_ty).inspect_err(|_err_set| {
                    self.add_error_stable_ptr(InferenceVar::Type(var_id));
                })?;
                changed = true;
                break;
            }
            if !changed {
                break;
            }
        }
        assert!(
            self.pending.is_empty(),
            "pending should all be solved by this point. Guaranteed by solve()."
        );

        // Validate literal types; bounds that still contain variables are skipped.
        for var_id in self.data.integer_literal_vars.clone() {
            let Some(bound) = self.type_assignment.get(&var_id).copied() else {
                continue;
            };
            let bound = self.rewrite(bound).no_err();
            if !bound.is_var_free(self.db) || is_valid_literal_type(self.db, bound) {
                continue;
            }
            let err = self.set_error(InferenceError::InvalidLiteralType(bound));
            self.add_error_stable_ptr(InferenceVar::Type(var_id));
            return Err(err);
        }

        // Snapshot the assignments first, since rewriting borrows `self` mutably.
        let assignments: Vec<_> = self.impl_assignment.iter().map(|(k, v)| (*k, *v)).collect();
        for (var_id, impl_id) in assignments {
            let rewritten = self.rewrite(impl_id).no_err();
            if rewritten != impl_id {
                self.impl_assignment.insert(var_id, rewritten);
            }
        }

        let Some((var, err)) = self.first_undetermined_variable() else {
            return Ok(());
        };
        Err(self.set_error_on_var(err, var))
    }
939
940 pub fn finalize<'m>(
944 &'m mut self,
945 diagnostics: &mut SemanticDiagnostics<'db>,
946 stable_ptr: SyntaxStablePtrId<'db>,
947 ) {
948 if let Err(err_set) = self.finalize_without_reporting() {
949 let diag = self.report_on_pending_error(err_set, diagnostics, stable_ptr);
950
951 let ty_missing = TypeId::missing(self.db, diag);
952 for var in &self.data.type_vars {
953 self.data.type_assignment.entry(var.id).or_insert(ty_missing);
954 }
955 }
956 }
957
958 fn first_undetermined_variable(&mut self) -> Option<(InferenceVar, InferenceError<'db>)> {
962 if let Some(var) = self.refuted.first().copied() {
963 let impl_var = self.impl_var(var).clone();
964 let concrete_trait_id = impl_var.concrete_trait_id;
965 let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
966 return Some((
967 InferenceVar::Impl(var),
968 InferenceError::NoImplsFound(concrete_trait_id),
969 ));
970 }
971 if let Some(var) = self.negative_refuted.first().copied() {
972 let negative_impl_var = self.negative_impl_var(var).clone();
973 let concrete_trait_id = negative_impl_var.concrete_trait_id;
974 let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
975 return Some((
976 InferenceVar::NegativeImpl(var),
977 InferenceError::NoNegativeImplsFound(concrete_trait_id),
978 ));
979 }
980
981 let mut fallback_ret = None;
982 if let Some((var, ambiguity)) = self.ambiguous.first() {
983 let ret =
985 Some((InferenceVar::Impl(*var), InferenceError::Ambiguity(ambiguity.clone())));
986 if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
987 return ret;
988 } else {
989 fallback_ret = ret;
990 }
991 }
992 if let Some((var, ambiguity)) = self.negative_ambiguous.first() {
993 let ret = Some((
994 InferenceVar::NegativeImpl(*var),
995 InferenceError::Ambiguity(ambiguity.clone()),
996 ));
997 if !matches!(ambiguity, Ambiguity::WillNotInfer(_)) {
998 return ret;
999 } else {
1000 fallback_ret = ret;
1001 }
1002 }
1003 for (id, var) in self.type_vars.iter().enumerate() {
1004 if self.type_assignment(LocalTypeVarId(id)).is_none() {
1005 let ty = TypeLongId::Var(*var).intern(self.db);
1006 return Some((InferenceVar::Type(var.id), InferenceError::TypeNotInferred(ty)));
1007 }
1008 }
1009 for (id, var) in self.const_vars.iter().enumerate() {
1010 if !self.const_assignment.contains_key(&LocalConstVarId(id)) {
1011 let infernence_var = InferenceVar::Const(var.id);
1012 return Some((infernence_var, InferenceError::ConstNotInferred));
1013 }
1014 }
1015 fallback_ret
1016 }
1017
    /// Assigns `impl_id` to the local impl variable `var`.
    ///
    /// Steps:
    /// 1. Conforms the variable's concrete trait with the impl's concrete trait.
    /// 2. If the variable already has an assignment, conforms the two impls and returns that
    ///    result instead of reassigning.
    /// 3. Rejects an impl that contains the variable itself (`InferenceError::Cycle`).
    /// 4. Stores the assignment and conforms every trait item value recorded for the variable
    ///    (types, constants, impls) against the impl's actual items.
    ///
    /// Returns `impl_id` on success.
    fn assign_local_impl(
        &mut self,
        var: LocalImplVarId,
        impl_id: ImplId<'db>,
    ) -> InferenceResult<ImplId<'db>> {
        let concrete_trait = impl_id
            .concrete_trait(self.db)
            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
        self.conform_traits(self.impl_var(var).concrete_trait_id, concrete_trait)?;
        if let Some(other_impl) = self.impl_assignment(var) {
            return self.conform_impl(impl_id, other_impl);
        }
        // The variable-free check is done first, so the containment search runs only when
        // `impl_id` has variables.
        if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
        {
            let inference_var = InferenceVar::Impl(var);
            return Err(self.set_error_on_var(InferenceError::Cycle(inference_var), inference_var));
        }
        self.impl_assignment.insert(var, impl_id);
        // Recorded trait item values are consumed here and checked against the chosen impl.
        if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
            for (trait_type_id, ty) in mappings.types {
                let impl_ty = self
                    .db
                    .impl_type_concrete_implized(ImplTypeId::new(impl_id, trait_type_id, self.db))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                if let Err(err_set) = self.conform_ty(ty, impl_ty) {
                    let ty0 = self.rewrite(ty).no_err();
                    // The rewritten types are only used to make the error message readable.
                    let ty1 = self.rewrite(impl_ty).no_err();

                    // Overwrite the error status with a more specific `ImplTypeMismatch`,
                    // attributed to the impl variable's location.
                    // NOTE(review): presumably this replaces the error recorded by
                    // `conform_ty` — confirm.
                    let err = InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 };
                    self.error_status = Err(InferenceErrorStatus::Pending(PendingInferenceError {
                        err,
                        stable_ptr: self.stable_ptrs.get(&InferenceVar::Impl(var)).cloned(),
                    }));
                    return Err(err_set);
                }
            }
            for (trait_constant, constant_id) in mappings.constants {
                let concrete_impl_constant = self
                    .db
                    .impl_constant_concrete_implized_value(ImplConstantId::new(
                        impl_id,
                        trait_constant,
                        self.db,
                    ))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                self.conform_const(constant_id, concrete_impl_constant)?;
            }
            for (trait_impl, inner_impl_id) in mappings.impls {
                let concrete_impl_impl = self
                    .db
                    .impl_impl_concrete_implized(ImplImplId::new(impl_id, trait_impl, self.db))
                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
                self.conform_impl(inner_impl_id, concrete_impl_impl)?;
            }
        }
        Ok(impl_id)
    }
1077
1078 fn assign_impl(
1080 &mut self,
1081 var_id: ImplVarId<'db>,
1082 impl_id: ImplId<'db>,
1083 ) -> InferenceResult<ImplId<'db>> {
1084 let var = var_id.long(self.db);
1085 if var.inference_id != self.inference_id {
1086 return Err(self.set_error(InferenceError::ImplKindMismatch {
1087 impl0: ImplLongId::ImplVar(var_id).intern(self.db),
1088 impl1: impl_id,
1089 }));
1090 }
1091 self.assign_local_impl(var.id, impl_id)
1092 }
1093
1094 fn assign_local_negative_impl(
1096 &mut self,
1097 var: LocalNegativeImplVarId,
1098 neg_impl_id: NegativeImplId<'db>,
1099 ) -> InferenceResult<NegativeImplId<'db>> {
1100 let concrete_trait = neg_impl_id
1101 .concrete_trait(self.db)
1102 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1103 self.conform_traits(self.negative_impl_var(var).concrete_trait_id, concrete_trait)?;
1104 if let Some(other_impl) = self.negative_impl_assignment(var) {
1105 return self.conform_neg_impl(neg_impl_id, other_impl);
1106 }
1107 if !neg_impl_id.is_var_free(self.db)
1108 && self.negative_impl_contains_var(neg_impl_id, InferenceVar::NegativeImpl(var))
1109 {
1110 return Err(self.set_error(InferenceError::Cycle(InferenceVar::NegativeImpl(var))));
1111 }
1112 self.negative_impl_assignment.insert(var, neg_impl_id);
1113 Ok(neg_impl_id)
1114 }
1115
1116 fn assign_neg_impl(
1119 &mut self,
1120 var_id: NegativeImplVarId<'db>,
1121 neg_impl_id: NegativeImplId<'db>,
1122 ) -> InferenceResult<NegativeImplId<'db>> {
1123 let var = var_id.long(self.db);
1124 if var.inference_id != self.inference_id {
1125 return Err(self.set_error(InferenceError::NegativeImplKindMismatch {
1126 impl0: NegativeImplLongId::NegativeImplVar(var_id).intern(self.db),
1127 impl1: neg_impl_id,
1128 }));
1129 }
1130 self.assign_local_negative_impl(var.id, neg_impl_id)
1131 }
1132
1133 fn assign_ty(&mut self, var: TypeVar<'db>, ty: TypeId<'db>) -> InferenceResult<TypeId<'db>> {
1136 if var.inference_id != self.inference_id {
1137 return Err(self.set_error(InferenceError::TypeKindMismatch {
1138 ty0: TypeLongId::Var(var).intern(self.db),
1139 ty1: ty,
1140 }));
1141 }
1142 assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
1143 let inference_var = InferenceVar::Type(var.id);
1144 if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
1145 return Err(self.set_error_on_var(InferenceError::Cycle(inference_var), inference_var));
1146 }
1147 if let TypeLongId::Var(other) = ty.long(self.db)
1149 && other.inference_id == self.inference_id
1150 && other.id.0 > var.id.0
1151 {
1152 let var_ty = TypeLongId::Var(var).intern(self.db);
1153 self.type_assignment.insert(other.id, var_ty);
1154 return Ok(var_ty);
1155 }
1156 self.type_assignment.insert(var.id, ty);
1157 Ok(ty)
1158 }
1159
1160 fn assign_const(
1163 &mut self,
1164 var: ConstVar<'db>,
1165 id: ConstValueId<'db>,
1166 ) -> InferenceResult<ConstValueId<'db>> {
1167 if var.inference_id != self.inference_id {
1168 return Err(self.set_error(InferenceError::ConstKindMismatch {
1169 const0: ConstValue::Var(var, TypeId::missing(self.db, skip_diagnostic()))
1170 .intern(self.db),
1171 const1: id,
1172 }));
1173 }
1174
1175 self.const_assignment.insert(var.id, id);
1176 Ok(id)
1177 }
1178
1179 fn impl_var_solution_set(
1181 &mut self,
1182 var: LocalImplVarId,
1183 ) -> InferenceResult<SolutionSet<'db, ImplId<'db>>> {
1184 let impl_var = self.impl_var(var).clone();
1185 let concrete_trait_id = self.rewrite(impl_var.concrete_trait_id).no_err();
1187 self.impl_vars[impl_var.id.0].concrete_trait_id = concrete_trait_id;
1188 let impl_var_trait_item_mappings =
1189 self.impl_vars_trait_item_mappings.get(&var).cloned().unwrap_or_default();
1190 let solution_set = self.trait_solution_set(
1191 concrete_trait_id,
1192 impl_var_trait_item_mappings,
1193 impl_var.lookup_context,
1194 )?;
1195 Ok(match solution_set {
1196 SolutionSet::None => SolutionSet::None,
1197 SolutionSet::Unique((canonical_impl, canonicalizer)) => {
1198 SolutionSet::Unique(canonical_impl.embed(self, &canonicalizer))
1199 }
1200 SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
1201 })
1202 }
1203
1204 fn negative_impl_var_solution_set(
1206 &mut self,
1207 var: LocalNegativeImplVarId,
1208 ) -> InferenceResult<SolutionSet<'db, NegativeImplId<'db>>> {
1209 let negative_impl_var = self.negative_impl_var(var).clone();
1210 let concrete_trait_id = self.rewrite(negative_impl_var.concrete_trait_id).no_err();
1211
1212 let solution_set =
1213 self.validate_no_solution_set(concrete_trait_id, negative_impl_var.lookup_context)?;
1214 Ok(match solution_set {
1215 SolutionSet::Unique(concrete_trait_id) => {
1216 SolutionSet::Unique(NegativeImplLongId::Solved(concrete_trait_id).intern(self.db))
1217 }
1218 SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
1219 SolutionSet::None => SolutionSet::None,
1220 })
1221 }
1222
1223 fn validate_no_solution_set(
1225 &mut self,
1226 concrete_trait_id: ConcreteTraitId<'db>,
1227 lookup_context: ImplLookupContextId<'db>,
1228 ) -> InferenceResult<SolutionSet<'db, ConcreteTraitId<'db>>> {
1229 for negative_impl in &lookup_context.long(self.db).negative_impls {
1230 let generic_param = self
1231 .db
1232 .generic_param_semantic(*negative_impl)
1233 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
1234 if let GenericParam::NegImpl(neg_impl) = generic_param
1235 && Ok(concrete_trait_id) == neg_impl.concrete_trait
1236 {
1237 return Ok(SolutionSet::Unique(concrete_trait_id));
1238 }
1239 }
1240
1241 let generic_args = concrete_trait_id.generic_args(self.db);
1242 if generic_args.iter().any(|garg| {
1246 matches!(
1247 garg,
1248 GenericArgumentId::Type(ty)
1249 if !matches!(
1250 ty.long(self.db),
1251 TypeLongId::Closure(_) | TypeLongId::NumericLiteral(_)
1252 )
1253 && !ty.is_var_free(self.db)
1254 )
1255 }) {
1256 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1257 }
1258
1259 let mut neg_impl_generic_params = OrderedHashSet::default();
1260 let mut visited_types = OrderedHashSet::default();
1261 for garg in generic_args {
1262 if garg
1263 .extract_generic_params(self.db, &mut neg_impl_generic_params, &mut visited_types)
1264 .is_err()
1265 {
1266 return Ok(SolutionSet::Ambiguous(
1267 Ambiguity::NegativeImplWithUnsupportedExtractedArgs(*garg),
1268 ));
1269 }
1270 }
1271
1272 let solution_set = if neg_impl_generic_params.is_empty() {
1273 self.trait_solution_set(
1274 concrete_trait_id,
1275 ImplVarTraitItemMappings::default(),
1276 lookup_context,
1277 )?
1278 } else {
1279 let mut substitution: OrderedHashMap<GenericParamId<'db>, GenericArgumentId<'db>> =
1280 Default::default();
1281
1282 for param in neg_impl_generic_params {
1283 let garg = match param.kind(self.db) {
1284 GenericKind::Type => GenericArgumentId::Type(
1285 self.new_type_var(Some(param.stable_ptr(self.db).untyped())),
1286 ),
1287 GenericKind::Const | GenericKind::Impl | GenericKind::NegImpl => {
1288 return Ok(SolutionSet::Ambiguous(
1289 Ambiguity::NegativeImplWithUnsupportedGenericParam(param),
1290 ));
1291 }
1292 };
1293
1294 substitution.insert(param, garg);
1295 }
1296 let rewritten_concrete_trait_id =
1297 GenericSubstitution { param_to_arg: substitution.clone(), self_impl: None }
1298 .substitute(self.db, concrete_trait_id)
1299 .unwrap();
1300
1301 let solution_set = self.trait_solution_set(
1302 rewritten_concrete_trait_id,
1303 ImplVarTraitItemMappings::default(),
1304 lookup_context,
1305 )?;
1306
1307 let db = self.db;
1310 for (generic_param, garg) in substitution {
1311 let GenericArgumentId::Type(ty) = garg else {
1312 panic!("Expected a type variable");
1313 };
1314 let (TypeLongId::Var(var) | TypeLongId::NumericLiteral(var)) = ty.long(self.db)
1315 else {
1316 panic!("Expected a type variable");
1317 };
1318 self.type_assignment
1319 .entry(var.id)
1320 .or_insert_with(|| TypeLongId::GenericParameter(generic_param).intern(db));
1321 }
1322
1323 solution_set
1324 };
1325
1326 if !matches!(solution_set, SolutionSet::None) {
1327 return Ok(SolutionSet::None);
1328 }
1329
1330 Ok(SolutionSet::Unique(concrete_trait_id))
1331 }
1332
1333 pub fn trait_solution_set(
1335 &mut self,
1336 concrete_trait_id: ConcreteTraitId<'db>,
1337 impl_var_trait_item_mappings: ImplVarTraitItemMappings<'db>,
1338 lookup_context_id: ImplLookupContextId<'db>,
1339 ) -> InferenceResult<SolutionSet<'db, (CanonicalImpl<'db>, CanonicalMapping<'db>)>> {
1340 let impl_var_trait_item_mappings = self.rewrite(impl_var_trait_item_mappings).no_err();
1341 let concrete_trait_id = self.rewrite(concrete_trait_id).no_err();
1343 let mut lookup_context = lookup_context_id.long(self.db).clone();
1344 enrich_lookup_context(self.db, concrete_trait_id, &mut lookup_context);
1345
1346 let generic_args = concrete_trait_id.generic_args(self.db);
1348 match generic_args.first() {
1349 Some(GenericArgumentId::Type(ty)) => {
1350 if let TypeLongId::Var(_) = ty.long(self.db) {
1351 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1353 }
1354 }
1355 Some(GenericArgumentId::Impl(imp)) => {
1356 if let ImplLongId::ImplVar(_) = imp.long(self.db) {
1358 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1359 }
1360 }
1361 Some(GenericArgumentId::Constant(const_value)) => {
1362 if let ConstValue::Var(_, _) = const_value.long(self.db) {
1363 return Ok(SolutionSet::Ambiguous(Ambiguity::WillNotInfer(concrete_trait_id)));
1365 }
1366 }
1367 _ => {}
1368 };
1369 let (canonical_trait, canonicalizer) = CanonicalTrait::canonicalize(
1370 self.db,
1371 self.inference_id,
1372 concrete_trait_id,
1373 impl_var_trait_item_mappings,
1374 );
1375 let solution_set = match self.db.canonic_trait_solutions(
1378 canonical_trait,
1379 lookup_context.intern(self.db),
1380 (*self.data.impl_type_bounds).clone(),
1381 ) {
1382 Ok(solution_set) => solution_set,
1383 Err(err) => return Err(self.set_error(err)),
1384 };
1385 match solution_set {
1386 SolutionSet::None => Ok(SolutionSet::None),
1387 SolutionSet::Unique(canonical_impl) => {
1388 Ok(SolutionSet::Unique((canonical_impl, canonicalizer)))
1389 }
1390 SolutionSet::Ambiguous(ambiguity) => Ok(SolutionSet::Ambiguous(ambiguity)),
1391 }
1392 }
1393
1394 pub fn set_error(&mut self, err: InferenceError<'db>) -> ErrorSet {
1401 self.set_error_ex(err, None)
1402 }
1403
1404 pub fn set_error_ex(
1408 &mut self,
1409 err: InferenceError<'db>,
1410 stable_ptr: Option<SyntaxStablePtrId<'db>>,
1411 ) -> ErrorSet {
1412 if self.error_status.is_err() {
1413 return ErrorSet;
1414 }
1415 self.error_status = Err(if let InferenceError::Reported(diag_added) = err {
1416 InferenceErrorStatus::Consumed(diag_added)
1417 } else {
1418 InferenceErrorStatus::Pending(PendingInferenceError { err, stable_ptr })
1419 });
1420 ErrorSet
1421 }
1422
1423 pub fn set_error_on_var(&mut self, err: InferenceError<'db>, var: InferenceVar) -> ErrorSet {
1427 self.set_error_ex(err, self.stable_ptrs.get(&var).cloned())
1428 }
1429
1430 pub fn is_error_set(&self) -> InferenceResult<()> {
1432 self.error_status.as_ref().copied().map_err(|_| ErrorSet)
1433 }
1434
1435 fn add_error_stable_ptr(&mut self, var: InferenceVar) {
1437 let var_stable_ptr = self.stable_ptrs.get(&var).copied();
1438 if let Err(InferenceErrorStatus::Pending(PendingInferenceError { err: _, stable_ptr })) =
1439 &mut self.error_status
1440 && stable_ptr.is_none()
1441 {
1442 *stable_ptr = var_stable_ptr;
1443 }
1444 }
1445
1446 pub fn consume_error_without_reporting(
1452 &mut self,
1453 err_set: ErrorSet,
1454 ) -> Option<InferenceError<'db>> {
1455 Some(self.consume_error_inner(err_set, skip_diagnostic())?.err)
1456 }
1457
1458 pub fn consume_reported_error(&mut self, err_set: ErrorSet, diag_added: DiagnosticAdded) {
1465 self.consume_error_inner(err_set, diag_added);
1466 }
1467
1468 fn consume_error_inner(
1475 &mut self,
1476 _err_set: ErrorSet,
1477 diag_added: DiagnosticAdded,
1478 ) -> Option<PendingInferenceError<'db>> {
1479 match &mut self.error_status {
1480 Err(InferenceErrorStatus::Pending(error)) => {
1481 let pending_error = std::mem::replace(
1482 error,
1483 PendingInferenceError {
1484 err: InferenceError::Reported(diag_added),
1485 stable_ptr: None,
1486 },
1487 );
1488 self.error_status = Err(InferenceErrorStatus::Consumed(diag_added));
1489 Some(pending_error)
1490 }
1491 _ => None,
1493 }
1494 }
1495
1496 pub fn report_on_pending_error(
1502 &mut self,
1503 _err_set: ErrorSet,
1504 diagnostics: &mut SemanticDiagnostics<'db>,
1505 stable_ptr: SyntaxStablePtrId<'db>,
1506 ) -> DiagnosticAdded {
1507 let Err(state_error) = &self.error_status else {
1508 panic!("report_on_pending_error should be called only on error");
1509 };
1510 match state_error {
1511 InferenceErrorStatus::Consumed(diag_added) => *diag_added,
1512 InferenceErrorStatus::Pending(pending) => {
1513 let diag_added = match &pending.err {
1514 InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
1515 skip_diagnostic()
1520 }
1521 diag => diag.report(diagnostics, pending.stable_ptr.unwrap_or(stable_ptr)),
1522 };
1523 self.error_status = Err(InferenceErrorStatus::Consumed(diag_added));
1524 diag_added
1525 }
1526 }
1527 }
1528
1529 pub fn report_modified_if_pending(
1532 &mut self,
1533 err_set: ErrorSet,
1534 report: impl FnOnce() -> DiagnosticAdded,
1535 ) {
1536 if matches!(self.error_status, Err(InferenceErrorStatus::Pending { .. })) {
1537 self.consume_reported_error(err_set, report());
1538 }
1539 }
1540}
1541
/// Exposes the database held by the inference through `HasDb`.
impl<'a, 'mt> HasDb<&'a dyn Database> for Inference<'a, 'mt> {
    fn get_db(&self) -> &'a dyn Database {
        self.db
    }
}
// Macro-generated `SemanticRewriter` impls for `Inference` (presumably default structural
// rewrites; see the macro definitions). The `@exclude` list is exactly the set of types that
// get hand-written `SemanticRewriter` impls later in this file.
add_basic_rewrites!(<'a, 'mt>, Inference<'a, 'mt>, NoError, @exclude TypeLongId TypeId ImplLongId ImplId ConstValue NegativeImplLongId NegativeImplId);
add_expr_rewrites!(<'a, 'mt>, Inference<'a, 'mt>, NoError, @exclude);
add_rewrite!(<'a, 'mt>, Inference<'a, 'mt>, NoError, Ambiguity<'a>);
1550impl<'db, 'mt> SemanticRewriter<TypeId<'db>, NoError> for Inference<'db, 'mt> {
1551 fn internal_rewrite(&mut self, value: &mut TypeId<'db>) -> Result<RewriteResult, NoError> {
1552 if value.is_var_free(self.db) {
1553 return Ok(RewriteResult::NoChange);
1554 }
1555 value.default_rewrite(self)
1556 }
1557}
1558impl<'db, 'mt> SemanticRewriter<ImplId<'db>, NoError> for Inference<'db, 'mt> {
1559 fn internal_rewrite(&mut self, value: &mut ImplId<'db>) -> Result<RewriteResult, NoError> {
1560 if value.is_var_free(self.db) {
1561 return Ok(RewriteResult::NoChange);
1562 }
1563 value.default_rewrite(self)
1564 }
1565}
1566impl<'db, 'mt> SemanticRewriter<NegativeImplId<'db>, NoError> for Inference<'db, 'mt> {
1567 fn internal_rewrite(
1568 &mut self,
1569 value: &mut NegativeImplId<'db>,
1570 ) -> Result<RewriteResult, NoError> {
1571 if value.is_var_free(self.db) {
1572 return Ok(RewriteResult::NoChange);
1573 }
1574 value.default_rewrite(self)
1575 }
1576}
1577
impl<'db, 'mt> SemanticRewriter<TypeLongId<'db>, NoError> for Inference<'db, 'mt> {
    /// Rewrites a `TypeLongId`, resolving type variables and impl types.
    ///
    /// - `Var` / `NumericLiteral`: if the variable is assigned, the assigned type is itself
    ///   rewritten and substituted. When that rewrite changed it, the result is also stored back
    ///   into `type_assignment`, so later lookups do not follow the chain again. Unassigned
    ///   variables fall through to `default_rewrite`.
    /// - `ImplType`: an entry in `impl_type_bounds` takes precedence. Otherwise the impl type id
    ///   is rewritten and, depending on the impl:
    ///   - concrete impl: replaced by the concrete implized type (if the query succeeds);
    ///   - impl var: replaced by the var's rewritten impl type;
    ///   - generated impl: replaced by the type stored in the generated impl's items;
    ///   - any other impl: kept as the rewritten impl type.
    /// - Anything else: `default_rewrite`.
    fn internal_rewrite(&mut self, value: &mut TypeLongId<'db>) -> Result<RewriteResult, NoError> {
        match value {
            TypeLongId::Var(var) | TypeLongId::NumericLiteral(var) => {
                if let Some(type_id) = self.type_assignment.get(&var.id) {
                    let mut long_type_id = type_id.long(self.db).clone();
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_type_id)? {
                        // Cache the fully rewritten assignment.
                        *self.type_assignment.get_mut(&var.id).unwrap() =
                            long_type_id.clone().intern(self.db);
                    }
                    *value = long_type_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            TypeLongId::ImplType(impl_type_id) => {
                // A known bound for this impl type wins over resolving through the impl.
                if let Some(type_id) = self.impl_type_bounds.get(&((*impl_type_id).into())) {
                    *value = type_id.long(self.db).clone();
                    self.internal_rewrite(value)?;
                    return Ok(RewriteResult::Modified);
                }
                let impl_type_id_rewrite_result = self.internal_rewrite(impl_type_id)?;
                let impl_id = impl_type_id.impl_id();
                let trait_ty = impl_type_id.ty();
                return Ok(match impl_id.long(self.db) {
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_type_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        if let Ok(ty) = self.db.impl_type_concrete_implized(ImplTypeId::new(
                            impl_id, trait_ty, self.db,
                        )) {
                            *value = self.rewrite(ty).no_err().long(self.db).clone();
                            RewriteResult::Modified
                        } else {
                            impl_type_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        *value = self.rewritten_impl_type(*var, trait_ty).long(self.db).clone();
                        return Ok(RewriteResult::Modified);
                    }
                    ImplLongId::GeneratedImpl(generated) => {
                        // NOTE(review): `unwrap` panics if the generated impl has no entry for
                        // this trait type; presumably guaranteed by construction. Confirm.
                        *value = self
                            .rewrite(
                                *generated
                                    .long(self.db)
                                    .impl_items
                                    .0
                                    .get(&impl_type_id.ty())
                                    .unwrap(),
                            )
                            .no_err()
                            .long(self.db)
                            .clone();
                        RewriteResult::Modified
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
impl<'db, 'mt> SemanticRewriter<ConstValue<'db>, NoError> for Inference<'db, 'mt> {
    /// Rewrites a `ConstValue`, resolving const variables and impl constants.
    ///
    /// - `Var`: if assigned, the assigned value is rewritten and substituted. When that
    ///   rewrite changed it, the result is cached back into `const_assignment`. Unassigned vars
    ///   report `NoChange`.
    /// - `ImplConstant`: the impl constant id is rewritten, then:
    ///   - concrete impl: replaced by the concrete implized value (if the query succeeds);
    ///   - impl var: replaced by the var's rewritten impl constant;
    ///   - any other impl: kept as the rewritten impl constant.
    /// - Anything else: `default_rewrite`.
    fn internal_rewrite(&mut self, value: &mut ConstValue<'db>) -> Result<RewriteResult, NoError> {
        match value {
            ConstValue::Var(var, _) => {
                return Ok(if let Some(const_value_id) = self.const_assignment.get(&var.id) {
                    let mut const_value = const_value_id.long(self.db).clone();
                    if let RewriteResult::Modified = self.internal_rewrite(&mut const_value)? {
                        // Cache the fully rewritten assignment.
                        *self.const_assignment.get_mut(&var.id).unwrap() =
                            const_value.clone().intern(self.db);
                    }
                    *value = const_value;
                    RewriteResult::Modified
                } else {
                    RewriteResult::NoChange
                });
            }
            ConstValue::ImplConstant(impl_constant_id) => {
                let impl_constant_id_rewrite_result = self.internal_rewrite(impl_constant_id)?;
                let impl_id = impl_constant_id.impl_id();
                let trait_constant = impl_constant_id.trait_constant_id();
                return Ok(match impl_id.long(self.db) {
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::GeneratedImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_constant_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        if let Ok(constant) = self.db.impl_constant_concrete_implized_value(
                            ImplConstantId::new(impl_id, trait_constant, self.db),
                        ) {
                            *value = self.rewrite(constant).no_err().long(self.db).clone();
                            RewriteResult::Modified
                        } else {
                            impl_constant_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        *value = self
                            .rewritten_impl_constant(*var, trait_constant)
                            .long(self.db)
                            .clone();
                        return Ok(RewriteResult::Modified);
                    }
                });
            }
            _ => {}
        }
        value.default_rewrite(self)
    }
}
impl<'db, 'mt> SemanticRewriter<ImplLongId<'db>, NoError> for Inference<'db, 'mt> {
    /// Rewrites an `ImplLongId`, resolving impl variables and simplifying impls where possible.
    ///
    /// - `ImplVar`: if assigned, the assignment is rewritten and substituted. When that
    ///   rewrite changed it, the result is cached back into `impl_assignment`.
    /// - `GeneratedImpl`: rewritten structurally. Some generated impls are then replaced by a
    ///   real impl. This applies when the trait is `Drop`, `Copy`, `Destruct` or
    ///   `PanicDestruct`, and the first generic argument is a type that is not a closure,
    ///   type var, numeric literal or impl type. The real impl is the one `get_impl_at_context`
    ///   finds in a lookup context built from that type, if found.
    /// - `ImplImpl`: the impl-impl id is rewritten, then:
    ///   - concrete impl: replaced by the concrete implized impl (if the query succeeds);
    ///   - impl var: replaced by the var's rewritten impl-impl, if the concrete trait impl can
    ///     be computed.
    /// - Anything else: returned unchanged if var-free, otherwise `default_rewrite`.
    fn internal_rewrite(&mut self, value: &mut ImplLongId<'db>) -> Result<RewriteResult, NoError> {
        match value {
            ImplLongId::ImplVar(var) => {
                let long_id = var.long(self.db);
                // Only the local id is used for the assignment lookup.
                let impl_var_id = long_id.id;
                if let Some(impl_id) = self.impl_assignment(impl_var_id) {
                    let mut long_impl_id = impl_id.long(self.db).clone();
                    if let RewriteResult::Modified = self.internal_rewrite(&mut long_impl_id)? {
                        // Cache the fully rewritten assignment.
                        *self.impl_assignment.get_mut(&impl_var_id).unwrap() =
                            long_impl_id.clone().intern(self.db);
                    }
                    *value = long_impl_id;
                    return Ok(RewriteResult::Modified);
                }
            }
            ImplLongId::GeneratedImpl(generated) => {
                let inner_result = generated.default_rewrite(self)?;
                // Try to replace generated drop/copy/destruct impls with a real impl, when the
                // first generic argument is a type that can be looked up directly.
                let concrete_trait = generated.concrete_trait(self.db);
                if let Some(GenericArgumentId::Type(first_ty)) =
                    concrete_trait.generic_args(self.db).first().copied()
                    && !matches!(
                        first_ty.long(self.db),
                        TypeLongId::Closure(_)
                            | TypeLongId::Var(_)
                            | TypeLongId::NumericLiteral(_)
                            | TypeLongId::ImplType(_)
                    )
                {
                    let info = self.db.core_info();
                    let trait_id = concrete_trait.trait_id(self.db);
                    if [info.drop_trt, info.copy_trt, info.destruct_trt, info.panic_destruct_trt]
                        .contains(&trait_id)
                    {
                        let lookup_context =
                            ImplLookupContext::new_from_type(first_ty, self.db).intern(self.db);
                        if let Ok(real_impl) =
                            get_impl_at_context(self.db, lookup_context, concrete_trait, None)
                        {
                            *value = real_impl.long(self.db).clone();
                            return Ok(RewriteResult::Modified);
                        }
                    }
                }
                return Ok(inner_result);
            }
            ImplLongId::ImplImpl(impl_impl_id) => {
                let impl_impl_id_rewrite_result = self.internal_rewrite(impl_impl_id)?;
                let impl_id = impl_impl_id.impl_id();
                return Ok(match impl_id.long(self.db) {
                    ImplLongId::GenericParameter(_)
                    | ImplLongId::SelfImpl(_)
                    | ImplLongId::GeneratedImpl(_)
                    | ImplLongId::ImplImpl(_) => impl_impl_id_rewrite_result,
                    ImplLongId::Concrete(_) => {
                        if let Ok(imp) = self.db.impl_impl_concrete_implized(*impl_impl_id) {
                            *value = self.rewrite(imp).no_err().long(self.db).clone();
                            RewriteResult::Modified
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                    ImplLongId::ImplVar(var) => {
                        if let Ok(concrete_trait_impl) =
                            impl_impl_id.concrete_trait_impl_id(self.db)
                        {
                            *value = self
                                .rewritten_impl_impl(*var, concrete_trait_impl)
                                .long(self.db)
                                .clone();
                            return Ok(RewriteResult::Modified);
                        } else {
                            impl_impl_id_rewrite_result
                        }
                    }
                });
            }

            _ => {}
        }
        if value.is_var_free(self.db) {
            return Ok(RewriteResult::NoChange);
        }
        value.default_rewrite(self)
    }
}
1780
1781impl<'db, 'mt> SemanticRewriter<NegativeImplLongId<'db>, NoError> for Inference<'db, 'mt> {
1782 fn internal_rewrite(
1783 &mut self,
1784 value: &mut NegativeImplLongId<'db>,
1785 ) -> Result<RewriteResult, NoError> {
1786 if let NegativeImplLongId::NegativeImplVar(var) = value {
1787 let long_id = var.long(self.db);
1788 let neg_impl_var_id = long_id.id;
1790 if let Some(impl_id) = self.negative_impl_assignment(neg_impl_var_id) {
1791 let mut long_neg_impl_id = impl_id.long(self.db).clone();
1792 if let RewriteResult::Modified = self.internal_rewrite(&mut long_neg_impl_id)? {
1793 *self.negative_impl_assignment.get_mut(&neg_impl_var_id).unwrap() =
1794 long_neg_impl_id.clone().intern(self.db);
1795 }
1796 *value = long_neg_impl_id;
1797 return Ok(RewriteResult::Modified);
1798 }
1799 }
1800
1801 if value.is_var_free(self.db) {
1802 return Ok(RewriteResult::NoChange);
1803 }
1804 value.default_rewrite(self)
1805 }
1806}
1807
/// Rewriter that replaces occurrences of one `InferenceId` with another.
struct InferenceIdReplacer<'a> {
    db: &'a dyn Database,
    /// The inference id to look for.
    from_inference_id: InferenceId<'a>,
    /// The inference id written in its place.
    to_inference_id: InferenceId<'a>,
}
1813impl<'a> InferenceIdReplacer<'a> {
1814 fn new(
1815 db: &'a dyn Database,
1816 from_inference_id: InferenceId<'a>,
1817 to_inference_id: InferenceId<'a>,
1818 ) -> Self {
1819 Self { db, from_inference_id, to_inference_id }
1820 }
1821}
/// Exposes the replacer's database through `HasDb`.
impl<'a> HasDb<&'a dyn Database> for InferenceIdReplacer<'a> {
    fn get_db(&self) -> &'a dyn Database {
        self.db
    }
}
// Macro-generated `SemanticRewriter` impls for `InferenceIdReplacer` (presumably default
// structural rewrites; see the macro definitions). `InferenceId` is excluded because it has
// the hand-written impl that performs the actual replacement.
add_basic_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude InferenceId);
add_expr_rewrites!(<'a>, InferenceIdReplacer<'a>, NoError, @exclude);
add_rewrite!(<'a>, InferenceIdReplacer<'a>, NoError, Ambiguity<'a>);
1830impl<'a> SemanticRewriter<InferenceId<'a>, NoError> for InferenceIdReplacer<'a> {
1831 fn internal_rewrite(&mut self, value: &mut InferenceId<'a>) -> Result<RewriteResult, NoError> {
1832 if value == &self.from_inference_id {
1833 *value = self.to_inference_id;
1834 Ok(RewriteResult::Modified)
1835 } else {
1836 Ok(RewriteResult::NoChange)
1837 }
1838 }
1839}