1use cairo_lang_defs::ids::{ImplAliasId, ImplDefId, TraitFunctionId};
2use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
3use cairo_lang_utils::{Intern, extract_matches, require};
4use itertools::Itertools;
5
6use super::canonic::ResultNoErrEx;
7use super::conform::InferenceConform;
8use super::{Inference, InferenceError, InferenceResult};
9use crate::items::constant::ImplConstantId;
10use crate::items::functions::{GenericFunctionId, ImplGenericFunctionId};
11use crate::items::generics::{GenericParamConst, GenericParamSemantic};
12use crate::items::imp::{
13 GeneratedImplLongId, ImplId, ImplImplId, ImplLongId, ImplLookupContextId, ImplSemantic,
14 UninferredImpl,
15};
16use crate::items::impl_alias::ImplAliasSemantic;
17use crate::items::trt::{
18 ConcreteTraitConstantId, ConcreteTraitGenericFunctionId, ConcreteTraitImplId,
19 ConcreteTraitTypeId, TraitSemantic,
20};
21use crate::keyword::SELF_PARAM_KW;
22use crate::substitution::{GenericSubstitution, SemanticRewriter};
23use crate::types::ImplTypeId;
24use crate::{
25 ConcreteFunction, ConcreteImplLongId, ConcreteTraitId, ConcreteTraitLongId, FunctionId,
26 FunctionLongId, GenericArgumentId, GenericParam, TypeId, TypeLongId,
27};
28
29pub trait InferenceEmbeddings<'db> {
32 fn infer_impl(
33 &mut self,
34 uninferred_impl: UninferredImpl<'db>,
35 concrete_trait_id: ConcreteTraitId<'db>,
36 lookup_context: ImplLookupContextId<'db>,
37 stable_ptr: Option<SyntaxStablePtrId<'db>>,
38 ) -> InferenceResult<ImplId<'db>>;
39 fn infer_impl_def(
40 &mut self,
41 impl_def_id: ImplDefId<'db>,
42 concrete_trait_id: ConcreteTraitId<'db>,
43 lookup_context: ImplLookupContextId<'db>,
44 stable_ptr: Option<SyntaxStablePtrId<'db>>,
45 ) -> InferenceResult<ImplId<'db>>;
46 fn infer_impl_alias(
47 &mut self,
48 impl_alias_id: ImplAliasId<'db>,
49 concrete_trait_id: ConcreteTraitId<'db>,
50 lookup_context: ImplLookupContextId<'db>,
51 stable_ptr: Option<SyntaxStablePtrId<'db>>,
52 ) -> InferenceResult<ImplId<'db>>;
53 fn infer_generic_assignment(
54 &mut self,
55 generic_params: &[GenericParam<'db>],
56 generic_args: &[GenericArgumentId<'db>],
57 expected_generic_args: &[GenericArgumentId<'db>],
58 lookup_context: ImplLookupContextId<'db>,
59 stable_ptr: Option<SyntaxStablePtrId<'db>>,
60 ) -> InferenceResult<Vec<GenericArgumentId<'db>>>;
61 fn infer_generic_args(
62 &mut self,
63 generic_params: &[GenericParam<'db>],
64 lookup_context: ImplLookupContextId<'db>,
65 stable_ptr: Option<SyntaxStablePtrId<'db>>,
66 ) -> InferenceResult<Vec<GenericArgumentId<'db>>>;
67 fn infer_concrete_trait_by_self(
68 &mut self,
69 trait_function: TraitFunctionId<'db>,
70 self_ty: TypeId<'db>,
71 lookup_context: ImplLookupContextId<'db>,
72 stable_ptr: Option<SyntaxStablePtrId<'db>>,
73 inference_errors: &mut Vec<(TraitFunctionId<'db>, InferenceError<'db>)>,
74 ) -> Option<(ConcreteTraitId<'db>, usize)>;
75 fn infer_concrete_trait_by_self_without_errors(
76 &mut self,
77 trait_function: TraitFunctionId<'db>,
78 self_ty: TypeId<'db>,
79 lookup_context: ImplLookupContextId<'db>,
80 stable_ptr: Option<SyntaxStablePtrId<'db>>,
81 ) -> Option<(ConcreteTraitId<'db>, usize)>;
82 fn infer_generic_arg(
83 &mut self,
84 param: &GenericParam<'db>,
85 lookup_context: ImplLookupContextId<'db>,
86 stable_ptr: Option<SyntaxStablePtrId<'db>>,
87 ) -> InferenceResult<GenericArgumentId<'db>>;
88 fn infer_trait_function(
89 &mut self,
90 concrete_trait_function: ConcreteTraitGenericFunctionId<'db>,
91 lookup_context: ImplLookupContextId<'db>,
92 stable_ptr: Option<SyntaxStablePtrId<'db>>,
93 ) -> InferenceResult<FunctionId<'db>>;
94 fn infer_generic_function(
95 &mut self,
96 generic_function: GenericFunctionId<'db>,
97 lookup_context: ImplLookupContextId<'db>,
98 stable_ptr: Option<SyntaxStablePtrId<'db>>,
99 ) -> InferenceResult<FunctionId<'db>>;
100 fn infer_trait_generic_function(
101 &mut self,
102 concrete_trait_function: ConcreteTraitGenericFunctionId<'db>,
103 lookup_context: ImplLookupContextId<'db>,
104 stable_ptr: Option<SyntaxStablePtrId<'db>>,
105 ) -> ImplGenericFunctionId<'db>;
106 fn infer_trait_type(
107 &mut self,
108 concrete_trait_type: ConcreteTraitTypeId<'db>,
109 lookup_context: ImplLookupContextId<'db>,
110 stable_ptr: Option<SyntaxStablePtrId<'db>>,
111 ) -> TypeId<'db>;
112 fn infer_trait_constant(
113 &mut self,
114 concrete_trait_constant: ConcreteTraitConstantId<'db>,
115 lookup_context: ImplLookupContextId<'db>,
116 stable_ptr: Option<SyntaxStablePtrId<'db>>,
117 ) -> ImplConstantId<'db>;
118 fn infer_trait_impl(
119 &mut self,
120 concrete_trait_constant: ConcreteTraitImplId<'db>,
121 lookup_context: ImplLookupContextId<'db>,
122 stable_ptr: Option<SyntaxStablePtrId<'db>>,
123 ) -> ImplImplId<'db>;
124}
125
126impl<'db> InferenceEmbeddings<'db> for Inference<'db, '_> {
127 fn infer_impl(
129 &mut self,
130 uninferred_impl: UninferredImpl<'db>,
131 concrete_trait_id: ConcreteTraitId<'db>,
132 lookup_context: ImplLookupContextId<'db>,
133 stable_ptr: Option<SyntaxStablePtrId<'db>>,
134 ) -> InferenceResult<ImplId<'db>> {
135 let impl_id = match uninferred_impl {
136 UninferredImpl::Def(impl_def_id) => {
137 self.infer_impl_def(impl_def_id, concrete_trait_id, lookup_context, stable_ptr)?
138 }
139 UninferredImpl::ImplAlias(impl_alias_id) => {
140 self.infer_impl_alias(impl_alias_id, concrete_trait_id, lookup_context, stable_ptr)?
141 }
142 UninferredImpl::ImplImpl(impl_impl_id) => {
143 ImplLongId::ImplImpl(impl_impl_id).intern(self.db)
144 }
145 UninferredImpl::GenericParam(param_id) => {
146 let param = self
147 .db
148 .generic_param_semantic(param_id)
149 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
150 let param = extract_matches!(param, GenericParam::Impl);
151 let imp_concrete_trait_id = param.concrete_trait.unwrap();
152 self.conform_traits(concrete_trait_id, imp_concrete_trait_id)?;
153 ImplLongId::GenericParameter(param_id).intern(self.db)
154 }
155 UninferredImpl::GeneratedImpl(generated_impl) => {
156 let long_id = generated_impl.long(self.db);
157
158 self.infer_generic_args(&long_id.generic_params[..], lookup_context, stable_ptr)?;
160
161 ImplLongId::GeneratedImpl(
162 GeneratedImplLongId {
163 concrete_trait: long_id.concrete_trait,
164 generic_params: long_id.generic_params.clone(),
165 impl_items: long_id.impl_items.clone(),
166 }
167 .intern(self.db),
168 )
169 .intern(self.db)
170 }
171 };
172 Ok(impl_id)
173 }
174
175 fn infer_impl_def(
178 &mut self,
179 impl_def_id: ImplDefId<'db>,
180 concrete_trait_id: ConcreteTraitId<'db>,
181 lookup_context: ImplLookupContextId<'db>,
182 stable_ptr: Option<SyntaxStablePtrId<'db>>,
183 ) -> InferenceResult<ImplId<'db>> {
184 let imp_generic_params = self
185 .db
186 .impl_def_generic_params(impl_def_id)
187 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
188 let imp_concrete_trait = self
189 .db
190 .impl_def_concrete_trait(impl_def_id)
191 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
192 if imp_concrete_trait.trait_id(self.db) != concrete_trait_id.trait_id(self.db) {
193 return Err(self.set_error(InferenceError::TraitMismatch {
194 trt0: imp_concrete_trait.trait_id(self.db),
195 trt1: concrete_trait_id.trait_id(self.db),
196 }));
197 }
198
199 let long_concrete_trait = concrete_trait_id.long(self.db);
200 let long_imp_concrete_trait = imp_concrete_trait.long(self.db);
201 let generic_args = self.infer_generic_assignment(
202 imp_generic_params,
203 &long_imp_concrete_trait.generic_args,
204 &long_concrete_trait.generic_args,
205 lookup_context,
206 stable_ptr,
207 )?;
208 Ok(ImplLongId::Concrete(ConcreteImplLongId { impl_def_id, generic_args }.intern(self.db))
209 .intern(self.db))
210 }
211
212 fn infer_impl_alias(
215 &mut self,
216 impl_alias_id: ImplAliasId<'db>,
217 concrete_trait_id: ConcreteTraitId<'db>,
218 lookup_context: ImplLookupContextId<'db>,
219 stable_ptr: Option<SyntaxStablePtrId<'db>>,
220 ) -> InferenceResult<ImplId<'db>> {
221 let impl_alias_generic_params = self
222 .db
223 .impl_alias_generic_params(impl_alias_id)
224 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
225 let impl_id = self
226 .db
227 .impl_alias_resolved_impl(impl_alias_id)
228 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
229 let imp_concrete_trait = impl_id
230 .concrete_trait(self.db)
231 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
232 if imp_concrete_trait.trait_id(self.db) != concrete_trait_id.trait_id(self.db) {
233 return Err(self.set_error(InferenceError::TraitMismatch {
234 trt0: imp_concrete_trait.trait_id(self.db),
235 trt1: concrete_trait_id.trait_id(self.db),
236 }));
237 }
238
239 let long_concrete_trait = concrete_trait_id.long(self.db);
240 let long_imp_concrete_trait = imp_concrete_trait.long(self.db);
241 let generic_args = self.infer_generic_assignment(
242 &impl_alias_generic_params,
243 &long_imp_concrete_trait.generic_args,
244 &long_concrete_trait.generic_args,
245 lookup_context,
246 stable_ptr,
247 )?;
248
249 GenericSubstitution::new(&impl_alias_generic_params, &generic_args)
250 .substitute(self.db, impl_id)
251 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))
252 }
253
254 fn infer_generic_assignment(
258 &mut self,
259 generic_params: &[GenericParam<'db>],
260 generic_args: &[GenericArgumentId<'db>],
261 expected_generic_args: &[GenericArgumentId<'db>],
262 lookup_context: ImplLookupContextId<'db>,
263 stable_ptr: Option<SyntaxStablePtrId<'db>>,
264 ) -> InferenceResult<Vec<GenericArgumentId<'db>>> {
265 let new_generic_args =
266 self.infer_generic_args(generic_params, lookup_context, stable_ptr)?;
267 let substitution = GenericSubstitution::new(generic_params, &new_generic_args);
268 let generic_args = substitution
269 .substitute(self.db, generic_args.iter().copied().collect_vec())
270 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
271 self.conform_generic_args(&generic_args, expected_generic_args)?;
272 Ok(self.rewrite(new_generic_args).no_err())
273 }
274
275 fn infer_generic_args(
277 &mut self,
278 generic_params: &[GenericParam<'db>],
279 lookup_context: ImplLookupContextId<'db>,
280 stable_ptr: Option<SyntaxStablePtrId<'db>>,
281 ) -> InferenceResult<Vec<GenericArgumentId<'db>>> {
282 let mut generic_args = vec![];
283 let mut substitution = GenericSubstitution::default();
284 for generic_param in generic_params {
285 let generic_param = substitution
286 .substitute(self.db, generic_param.clone())
287 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
288 let generic_arg = self.infer_generic_arg(&generic_param, lookup_context, stable_ptr)?;
289 generic_args.push(generic_arg);
290 substitution.insert(generic_param.id(), generic_arg);
291 }
292 Ok(generic_args)
293 }
294
295 fn infer_concrete_trait_by_self(
303 &mut self,
304 trait_function: TraitFunctionId<'db>,
305 self_ty: TypeId<'db>,
306 lookup_context: ImplLookupContextId<'db>,
307 stable_ptr: Option<SyntaxStablePtrId<'db>>,
308 inference_errors: &mut Vec<(TraitFunctionId<'db>, InferenceError<'db>)>,
309 ) -> Option<(ConcreteTraitId<'db>, usize)> {
310 infer_concrete_trait_by_self(
311 self,
312 trait_function,
313 self_ty,
314 lookup_context,
315 stable_ptr,
316 inference_errors,
317 )
318 }
319 fn infer_concrete_trait_by_self_without_errors(
322 &mut self,
323 trait_function: TraitFunctionId<'db>,
324 self_ty: TypeId<'db>,
325 lookup_context: ImplLookupContextId<'db>,
326 stable_ptr: Option<SyntaxStablePtrId<'db>>,
327 ) -> Option<(ConcreteTraitId<'db>, usize)> {
328 infer_concrete_trait_by_self(
329 self,
330 trait_function,
331 self_ty,
332 lookup_context,
333 stable_ptr,
334 &mut vec![],
335 )
336 }
337
338 fn infer_generic_arg(
341 &mut self,
342 param: &GenericParam<'db>,
343 lookup_context: ImplLookupContextId<'db>,
344 stable_ptr: Option<SyntaxStablePtrId<'db>>,
345 ) -> InferenceResult<GenericArgumentId<'db>> {
346 match param {
347 GenericParam::Type(_) => Ok(GenericArgumentId::Type(self.new_type_var(stable_ptr))),
348 GenericParam::Impl(param) => {
349 let concrete_trait_id = param
350 .concrete_trait
351 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
352 let impl_id = self.new_impl_var(concrete_trait_id, stable_ptr, lookup_context);
353 for (trait_ty, ty1) in param.type_constraints.iter() {
354 let ty0 = self.reduce_impl_ty(ImplTypeId::new(impl_id, *trait_ty, self.db))?;
355 self.conform_ty(ty0, *ty1).ok();
357 }
358 Ok(GenericArgumentId::Impl(impl_id))
359 }
360 GenericParam::Const(GenericParamConst { ty, .. }) => {
361 Ok(GenericArgumentId::Constant(self.new_const_var(stable_ptr, *ty)))
362 }
363 GenericParam::NegImpl(param) => {
364 let concrete_trait_id = param
365 .concrete_trait
366 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
367 let impl_id =
368 self.new_negative_impl_var(concrete_trait_id, stable_ptr, lookup_context);
369 Ok(GenericArgumentId::NegImpl(impl_id))
370 }
371 }
372 }
373
374 fn infer_trait_function(
378 &mut self,
379 concrete_trait_function: ConcreteTraitGenericFunctionId<'db>,
380 lookup_context: ImplLookupContextId<'db>,
381 stable_ptr: Option<SyntaxStablePtrId<'db>>,
382 ) -> InferenceResult<FunctionId<'db>> {
383 let generic_function = GenericFunctionId::Impl(self.infer_trait_generic_function(
384 concrete_trait_function,
385 lookup_context,
386 stable_ptr,
387 ));
388 self.infer_generic_function(generic_function, lookup_context, stable_ptr)
389 }
390
391 fn infer_generic_function(
394 &mut self,
395 generic_function: GenericFunctionId<'db>,
396 lookup_context: ImplLookupContextId<'db>,
397 stable_ptr: Option<SyntaxStablePtrId<'db>>,
398 ) -> InferenceResult<FunctionId<'db>> {
399 let generic_params = generic_function
400 .generic_params(self.db)
401 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
402 let generic_args = self.infer_generic_args(generic_params, lookup_context, stable_ptr)?;
403 Ok(FunctionLongId { function: ConcreteFunction { generic_function, generic_args } }
404 .intern(self.db))
405 }
406
407 fn infer_trait_generic_function(
410 &mut self,
411 concrete_trait_function: ConcreteTraitGenericFunctionId<'db>,
412 lookup_context: ImplLookupContextId<'db>,
413 stable_ptr: Option<SyntaxStablePtrId<'db>>,
414 ) -> ImplGenericFunctionId<'db> {
415 let impl_id = self.new_impl_var(
416 concrete_trait_function.concrete_trait(self.db),
417 stable_ptr,
418 lookup_context,
419 );
420 ImplGenericFunctionId { impl_id, function: concrete_trait_function.trait_function(self.db) }
421 }
422
423 fn infer_trait_type(
426 &mut self,
427 concrete_trait_type: ConcreteTraitTypeId<'db>,
428 lookup_context: ImplLookupContextId<'db>,
429 stable_ptr: Option<SyntaxStablePtrId<'db>>,
430 ) -> TypeId<'db> {
431 let trait_type = concrete_trait_type.trait_type(self.db);
432 let impl_id = self.new_impl_var(
433 concrete_trait_type.concrete_trait(self.db),
434 stable_ptr,
435 lookup_context,
436 );
437 TypeLongId::ImplType(ImplTypeId::new(impl_id, trait_type, self.db)).intern(self.db)
438 }
439
440 fn infer_trait_constant(
443 &mut self,
444 concrete_trait_constant: ConcreteTraitConstantId<'db>,
445 lookup_context: ImplLookupContextId<'db>,
446 stable_ptr: Option<SyntaxStablePtrId<'db>>,
447 ) -> ImplConstantId<'db> {
448 let impl_id = self.new_impl_var(
449 concrete_trait_constant.concrete_trait(self.db),
450 stable_ptr,
451 lookup_context,
452 );
453
454 ImplConstantId::new(impl_id, concrete_trait_constant.trait_constant(self.db), self.db)
455 }
456
457 fn infer_trait_impl(
460 &mut self,
461 concrete_trait_impl: ConcreteTraitImplId<'db>,
462 lookup_context: ImplLookupContextId<'db>,
463 stable_ptr: Option<SyntaxStablePtrId<'db>>,
464 ) -> ImplImplId<'db> {
465 let impl_id = self.new_impl_var(
466 concrete_trait_impl.concrete_trait(self.db),
467 stable_ptr,
468 lookup_context,
469 );
470 let x = self.db;
471
472 ImplImplId::new(impl_id, concrete_trait_impl.trait_impl(x), x)
473 }
474}
475
476fn infer_concrete_trait_by_self<'r, 'db, 'mt>(
478 inference: &'r mut Inference<'db, 'mt>,
479 trait_function: TraitFunctionId<'db>,
480 self_ty: TypeId<'db>,
481 lookup_context: ImplLookupContextId<'db>,
482 stable_ptr: Option<SyntaxStablePtrId<'db>>,
483 inference_errors: &mut Vec<(TraitFunctionId<'db>, InferenceError<'db>)>,
484) -> Option<(ConcreteTraitId<'db>, usize)> {
485 let trait_id = trait_function.trait_id(inference.db);
486 let signature = inference.db.trait_function_signature(trait_function).ok()?;
487 let first_param = signature.params.first()?;
488 require(first_param.name.long(inference.db) == SELF_PARAM_KW)?;
489
490 let trait_generic_params = inference.db.trait_generic_params(trait_id).ok()?;
491 let trait_generic_args =
492 match inference.infer_generic_args(trait_generic_params, lookup_context, stable_ptr) {
493 Ok(generic_args) => generic_args,
494 Err(err_set) => {
495 if let Some(err) = inference.consume_error_without_reporting(err_set) {
496 inference_errors.push((trait_function, err));
497 }
498 return None;
499 }
500 };
501
502 let mut tmp_inference_data = inference.temporary_clone();
504 let mut tmp_inference = tmp_inference_data.inference(inference.db);
505 let function_generic_params =
506 tmp_inference.db.trait_function_generic_params(trait_function).ok()?;
507 let function_generic_args =
508 match tmp_inference.infer_generic_args(function_generic_params, lookup_context, stable_ptr) {
511 Ok(generic_args) => generic_args,
512 Err(err_set) => {
513 if let Some(err) = inference.consume_error_without_reporting(err_set) {
514 inference_errors.push((trait_function, err));
515 }
516 return None;
517 }
518 };
519
520 let trait_substitution = GenericSubstitution::new(trait_generic_params, &trait_generic_args);
521 let function_substitution =
522 GenericSubstitution::new(function_generic_params, &function_generic_args);
523 let substitution = trait_substitution.concat(function_substitution);
524
525 let fixed_param_ty = substitution.substitute(inference.db, first_param.ty).ok()?;
526 let (_, n_snapshots) = match inference.conform_ty_ex(self_ty, fixed_param_ty, true) {
527 Ok(conform) => conform,
528 Err(err_set) => {
529 if let Some(err) = inference.consume_error_without_reporting(err_set) {
530 inference_errors.push((trait_function, err));
531 }
532 return None;
533 }
534 };
535
536 let generic_args = inference.rewrite(trait_generic_args).no_err();
537
538 Some((ConcreteTraitLongId { trait_id, generic_args }.intern(inference.db), n_snapshots))
539}