1use cairo_lang_defs::ids::{ImplAliasId, ImplDefId, TraitFunctionId};
2use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
3use cairo_lang_utils::{Intern, LookupIntern, extract_matches, require};
4use itertools::Itertools;
5
6use super::canonic::ResultNoErrEx;
7use super::conform::InferenceConform;
8use super::{Inference, InferenceError, InferenceResult};
9use crate::items::constant::ImplConstantId;
10use crate::items::functions::{GenericFunctionId, ImplGenericFunctionId};
11use crate::items::generics::GenericParamConst;
12use crate::items::imp::{
13 GeneratedImplLongId, ImplId, ImplImplId, ImplLongId, ImplLookupContext, UninferredImpl,
14};
15use crate::items::trt::{
16 ConcreteTraitConstantId, ConcreteTraitGenericFunctionId, ConcreteTraitImplId,
17 ConcreteTraitTypeId,
18};
19use crate::keyword::SELF_PARAM_KW;
20use crate::substitution::{GenericSubstitution, SemanticRewriter};
21use crate::types::ImplTypeId;
22use crate::{
23 ConcreteFunction, ConcreteImplLongId, ConcreteTraitId, ConcreteTraitLongId, FunctionId,
24 FunctionLongId, GenericArgumentId, GenericParam, TypeId, TypeLongId,
25};
26
27pub trait InferenceEmbeddings {
30 fn infer_impl(
31 &mut self,
32 uninferred_impl: UninferredImpl,
33 concrete_trait_id: ConcreteTraitId,
34 lookup_context: &ImplLookupContext,
35 stable_ptr: Option<SyntaxStablePtrId>,
36 ) -> InferenceResult<ImplId>;
37 fn infer_impl_def(
38 &mut self,
39 impl_def_id: ImplDefId,
40 concrete_trait_id: ConcreteTraitId,
41 lookup_context: &ImplLookupContext,
42 stable_ptr: Option<SyntaxStablePtrId>,
43 ) -> InferenceResult<ImplId>;
44 fn infer_impl_alias(
45 &mut self,
46 impl_alias_id: ImplAliasId,
47 concrete_trait_id: ConcreteTraitId,
48 lookup_context: &ImplLookupContext,
49 stable_ptr: Option<SyntaxStablePtrId>,
50 ) -> InferenceResult<ImplId>;
51 fn infer_generic_assignment(
52 &mut self,
53 generic_params: &[GenericParam],
54 generic_args: &[GenericArgumentId],
55 expected_generic_args: &[GenericArgumentId],
56 lookup_context: &ImplLookupContext,
57 stable_ptr: Option<SyntaxStablePtrId>,
58 ) -> InferenceResult<Vec<GenericArgumentId>>;
59 fn infer_generic_args(
60 &mut self,
61 generic_params: &[GenericParam],
62 lookup_context: &ImplLookupContext,
63 stable_ptr: Option<SyntaxStablePtrId>,
64 ) -> InferenceResult<Vec<GenericArgumentId>>;
65 fn infer_concrete_trait_by_self(
66 &mut self,
67 trait_function: TraitFunctionId,
68 self_ty: TypeId,
69 lookup_context: &ImplLookupContext,
70 stable_ptr: Option<SyntaxStablePtrId>,
71 inference_error_cb: impl FnOnce(InferenceError),
72 ) -> Option<(ConcreteTraitId, usize)>;
73 fn infer_generic_arg(
74 &mut self,
75 param: &GenericParam,
76 lookup_context: ImplLookupContext,
77 stable_ptr: Option<SyntaxStablePtrId>,
78 ) -> InferenceResult<GenericArgumentId>;
79 fn infer_trait_function(
80 &mut self,
81 concrete_trait_function: ConcreteTraitGenericFunctionId,
82 lookup_context: &ImplLookupContext,
83 stable_ptr: Option<SyntaxStablePtrId>,
84 ) -> InferenceResult<FunctionId>;
85 fn infer_generic_function(
86 &mut self,
87 generic_function: GenericFunctionId,
88 lookup_context: &ImplLookupContext,
89 stable_ptr: Option<SyntaxStablePtrId>,
90 ) -> InferenceResult<FunctionId>;
91 fn infer_trait_generic_function(
92 &mut self,
93 concrete_trait_function: ConcreteTraitGenericFunctionId,
94 lookup_context: &ImplLookupContext,
95 stable_ptr: Option<SyntaxStablePtrId>,
96 ) -> ImplGenericFunctionId;
97 fn infer_trait_type(
98 &mut self,
99 concrete_trait_type: ConcreteTraitTypeId,
100 lookup_context: &ImplLookupContext,
101 stable_ptr: Option<SyntaxStablePtrId>,
102 ) -> TypeId;
103 fn infer_trait_constant(
104 &mut self,
105 concrete_trait_constant: ConcreteTraitConstantId,
106 lookup_context: &ImplLookupContext,
107 stable_ptr: Option<SyntaxStablePtrId>,
108 ) -> ImplConstantId;
109 fn infer_trait_impl(
110 &mut self,
111 concrete_trait_constant: ConcreteTraitImplId,
112 lookup_context: &ImplLookupContext,
113 stable_ptr: Option<SyntaxStablePtrId>,
114 ) -> ImplImplId;
115}
116
117impl InferenceEmbeddings for Inference<'_> {
118 fn infer_impl(
120 &mut self,
121 uninferred_impl: UninferredImpl,
122 concrete_trait_id: ConcreteTraitId,
123 lookup_context: &ImplLookupContext,
124 stable_ptr: Option<SyntaxStablePtrId>,
125 ) -> InferenceResult<ImplId> {
126 let impl_id = match uninferred_impl {
127 UninferredImpl::Def(impl_def_id) => {
128 self.infer_impl_def(impl_def_id, concrete_trait_id, lookup_context, stable_ptr)?
129 }
130 UninferredImpl::ImplAlias(impl_alias_id) => {
131 self.infer_impl_alias(impl_alias_id, concrete_trait_id, lookup_context, stable_ptr)?
132 }
133 UninferredImpl::ImplImpl(impl_impl_id) => {
134 ImplLongId::ImplImpl(impl_impl_id).intern(self.db)
135 }
136 UninferredImpl::GenericParam(param_id) => {
137 let param = self
138 .db
139 .generic_param_semantic(param_id)
140 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
141 let param = extract_matches!(param, GenericParam::Impl);
142 let imp_concrete_trait_id = param.concrete_trait.unwrap();
143 self.conform_traits(concrete_trait_id, imp_concrete_trait_id)?;
144 ImplLongId::GenericParameter(param_id).intern(self.db)
145 }
146 UninferredImpl::GeneratedImpl(generated_impl) => {
147 let long_id = generated_impl.lookup_intern(self.db);
148
149 self.infer_generic_args(&long_id.generic_params[..], lookup_context, stable_ptr)?;
151
152 ImplLongId::GeneratedImpl(
153 GeneratedImplLongId {
154 concrete_trait: long_id.concrete_trait,
155 generic_params: long_id.generic_params,
156 impl_items: long_id.impl_items,
157 }
158 .intern(self.db),
159 )
160 .intern(self.db)
161 }
162 };
163 Ok(impl_id)
164 }
165
166 fn infer_impl_def(
169 &mut self,
170 impl_def_id: ImplDefId,
171 concrete_trait_id: ConcreteTraitId,
172 lookup_context: &ImplLookupContext,
173 stable_ptr: Option<SyntaxStablePtrId>,
174 ) -> InferenceResult<ImplId> {
175 let imp_generic_params = self
176 .db
177 .impl_def_generic_params(impl_def_id)
178 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
179 let imp_concrete_trait = self
180 .db
181 .impl_def_concrete_trait(impl_def_id)
182 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
183 if imp_concrete_trait.trait_id(self.db) != concrete_trait_id.trait_id(self.db) {
184 return Err(self.set_error(InferenceError::TraitMismatch {
185 trt0: imp_concrete_trait.trait_id(self.db),
186 trt1: concrete_trait_id.trait_id(self.db),
187 }));
188 }
189
190 let long_concrete_trait = concrete_trait_id.lookup_intern(self.db);
191 let long_imp_concrete_trait = imp_concrete_trait.lookup_intern(self.db);
192 let generic_args = self.infer_generic_assignment(
193 &imp_generic_params,
194 &long_imp_concrete_trait.generic_args,
195 &long_concrete_trait.generic_args,
196 lookup_context,
197 stable_ptr,
198 )?;
199 Ok(ImplLongId::Concrete(ConcreteImplLongId { impl_def_id, generic_args }.intern(self.db))
200 .intern(self.db))
201 }
202
203 fn infer_impl_alias(
206 &mut self,
207 impl_alias_id: ImplAliasId,
208 concrete_trait_id: ConcreteTraitId,
209 lookup_context: &ImplLookupContext,
210 stable_ptr: Option<SyntaxStablePtrId>,
211 ) -> InferenceResult<ImplId> {
212 let impl_alias_generic_params = self
213 .db
214 .impl_alias_generic_params(impl_alias_id)
215 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
216 let impl_id = self
217 .db
218 .impl_alias_resolved_impl(impl_alias_id)
219 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
220 let imp_concrete_trait = impl_id
221 .concrete_trait(self.db)
222 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
223 if imp_concrete_trait.trait_id(self.db) != concrete_trait_id.trait_id(self.db) {
224 return Err(self.set_error(InferenceError::TraitMismatch {
225 trt0: imp_concrete_trait.trait_id(self.db),
226 trt1: concrete_trait_id.trait_id(self.db),
227 }));
228 }
229
230 let long_concrete_trait = concrete_trait_id.lookup_intern(self.db);
231 let long_imp_concrete_trait = imp_concrete_trait.lookup_intern(self.db);
232 let generic_args = self.infer_generic_assignment(
233 &impl_alias_generic_params,
234 &long_imp_concrete_trait.generic_args,
235 &long_concrete_trait.generic_args,
236 lookup_context,
237 stable_ptr,
238 )?;
239
240 GenericSubstitution::new(&impl_alias_generic_params, &generic_args)
241 .substitute(self.db, impl_id)
242 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))
243 }
244
245 fn infer_generic_assignment(
249 &mut self,
250 generic_params: &[GenericParam],
251 generic_args: &[GenericArgumentId],
252 expected_generic_args: &[GenericArgumentId],
253 lookup_context: &ImplLookupContext,
254 stable_ptr: Option<SyntaxStablePtrId>,
255 ) -> InferenceResult<Vec<GenericArgumentId>> {
256 let new_generic_args =
257 self.infer_generic_args(generic_params, lookup_context, stable_ptr)?;
258 let substitution = GenericSubstitution::new(generic_params, &new_generic_args);
259 let generic_args = substitution
260 .substitute(self.db, generic_args.iter().copied().collect_vec())
261 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
262 self.conform_generic_args(&generic_args, expected_generic_args)?;
263 Ok(self.rewrite(new_generic_args).no_err())
264 }
265
266 fn infer_generic_args(
268 &mut self,
269 generic_params: &[GenericParam],
270 lookup_context: &ImplLookupContext,
271 stable_ptr: Option<SyntaxStablePtrId>,
272 ) -> InferenceResult<Vec<GenericArgumentId>> {
273 let mut generic_args = vec![];
274 let mut substitution = GenericSubstitution::default();
275 for generic_param in generic_params {
276 let generic_param = substitution
277 .substitute(self.db, generic_param.clone())
278 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
279 let generic_arg =
280 self.infer_generic_arg(&generic_param, lookup_context.clone(), stable_ptr)?;
281 generic_args.push(generic_arg);
282 substitution.insert(generic_param.id(), generic_arg);
283 }
284 Ok(generic_args)
285 }
286
287 fn infer_concrete_trait_by_self(
295 &mut self,
296 trait_function: TraitFunctionId,
297 self_ty: TypeId,
298 lookup_context: &ImplLookupContext,
299 stable_ptr: Option<SyntaxStablePtrId>,
300 inference_error_cb: impl FnOnce(InferenceError),
301 ) -> Option<(ConcreteTraitId, usize)> {
302 let trait_id = trait_function.trait_id(self.db);
303 let signature = self.db.trait_function_signature(trait_function).ok()?;
304 let first_param = signature.params.into_iter().next()?;
305 require(first_param.name == SELF_PARAM_KW)?;
306
307 let trait_generic_params = self.db.trait_generic_params(trait_id).ok()?;
308 let trait_generic_args =
309 match self.infer_generic_args(&trait_generic_params, lookup_context, stable_ptr) {
310 Ok(generic_args) => generic_args,
311 Err(err_set) => {
312 if let Some(err) = self.consume_error_without_reporting(err_set) {
313 inference_error_cb(err);
314 }
315 return None;
316 }
317 };
318
319 let mut tmp_inference_data = self.temporary_clone();
321 let mut tmp_inference = tmp_inference_data.inference(self.db);
322 let function_generic_params =
323 tmp_inference.db.trait_function_generic_params(trait_function).ok()?;
324 let function_generic_args =
325 match tmp_inference.infer_generic_args(&function_generic_params, lookup_context, stable_ptr) {
328 Ok(generic_args) => generic_args,
329 Err(err_set) => {
330 if let Some(err) = self.consume_error_without_reporting(err_set) {
331 inference_error_cb(err);
332 }
333 return None;
334 }
335 };
336
337 let trait_substitution =
338 GenericSubstitution::new(&trait_generic_params, &trait_generic_args);
339 let function_substitution =
340 GenericSubstitution::new(&function_generic_params, &function_generic_args);
341 let substitution = trait_substitution.concat(function_substitution);
342
343 let fixed_param_ty = substitution.substitute(self.db, first_param.ty).ok()?;
344 let (_, n_snapshots) = match self.conform_ty_ex(self_ty, fixed_param_ty, true) {
345 Ok(conform) => conform,
346 Err(err_set) => {
347 if let Some(err) = self.consume_error_without_reporting(err_set) {
348 inference_error_cb(err);
349 }
350 return None;
351 }
352 };
353
354 let generic_args = self.rewrite(trait_generic_args).no_err();
355
356 Some((ConcreteTraitLongId { trait_id, generic_args }.intern(self.db), n_snapshots))
357 }
358
359 fn infer_generic_arg(
362 &mut self,
363 param: &GenericParam,
364 lookup_context: ImplLookupContext,
365 stable_ptr: Option<SyntaxStablePtrId>,
366 ) -> InferenceResult<GenericArgumentId> {
367 match param {
368 GenericParam::Type(_) => Ok(GenericArgumentId::Type(self.new_type_var(stable_ptr))),
369 GenericParam::Impl(param) => {
370 let concrete_trait_id = param
371 .concrete_trait
372 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
373 let impl_id = self.new_impl_var(concrete_trait_id, stable_ptr, lookup_context);
374 for (trait_ty, ty1) in param.type_constraints.iter() {
375 let ty0 = self.reduce_impl_ty(ImplTypeId::new(impl_id, *trait_ty, self.db))?;
376 self.conform_ty(ty0, *ty1).ok();
378 }
379 Ok(GenericArgumentId::Impl(impl_id))
380 }
381 GenericParam::Const(GenericParamConst { ty, .. }) => {
382 Ok(GenericArgumentId::Constant(self.new_const_var(stable_ptr, *ty)))
383 }
384 GenericParam::NegImpl(_) => Ok(GenericArgumentId::NegImpl),
385 }
386 }
387
388 fn infer_trait_function(
392 &mut self,
393 concrete_trait_function: ConcreteTraitGenericFunctionId,
394 lookup_context: &ImplLookupContext,
395 stable_ptr: Option<SyntaxStablePtrId>,
396 ) -> InferenceResult<FunctionId> {
397 let generic_function = GenericFunctionId::Impl(self.infer_trait_generic_function(
398 concrete_trait_function,
399 lookup_context,
400 stable_ptr,
401 ));
402 self.infer_generic_function(generic_function, lookup_context, stable_ptr)
403 }
404
405 fn infer_generic_function(
408 &mut self,
409 generic_function: GenericFunctionId,
410 lookup_context: &ImplLookupContext,
411 stable_ptr: Option<SyntaxStablePtrId>,
412 ) -> InferenceResult<FunctionId> {
413 let generic_params = generic_function
414 .generic_params(self.db)
415 .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
416 let generic_args = self.infer_generic_args(&generic_params, lookup_context, stable_ptr)?;
417 Ok(FunctionLongId { function: ConcreteFunction { generic_function, generic_args } }
418 .intern(self.db))
419 }
420
421 fn infer_trait_generic_function(
424 &mut self,
425 concrete_trait_function: ConcreteTraitGenericFunctionId,
426 lookup_context: &ImplLookupContext,
427 stable_ptr: Option<SyntaxStablePtrId>,
428 ) -> ImplGenericFunctionId {
429 let impl_id = self.new_impl_var(
430 concrete_trait_function.concrete_trait(self.db),
431 stable_ptr,
432 lookup_context.clone(),
433 );
434 ImplGenericFunctionId { impl_id, function: concrete_trait_function.trait_function(self.db) }
435 }
436
437 fn infer_trait_type(
440 &mut self,
441 concrete_trait_type: ConcreteTraitTypeId,
442 lookup_context: &ImplLookupContext,
443 stable_ptr: Option<SyntaxStablePtrId>,
444 ) -> TypeId {
445 let impl_id = self.new_impl_var(
446 concrete_trait_type.concrete_trait(self.db),
447 stable_ptr,
448 lookup_context.clone(),
449 );
450 TypeLongId::ImplType(ImplTypeId::new(
451 impl_id,
452 concrete_trait_type.trait_type(self.db),
453 self.db,
454 ))
455 .intern(self.db)
456 }
457
458 fn infer_trait_constant(
461 &mut self,
462 concrete_trait_constant: ConcreteTraitConstantId,
463 lookup_context: &ImplLookupContext,
464 stable_ptr: Option<SyntaxStablePtrId>,
465 ) -> ImplConstantId {
466 let impl_id = self.new_impl_var(
467 concrete_trait_constant.concrete_trait(self.db),
468 stable_ptr,
469 lookup_context.clone(),
470 );
471
472 ImplConstantId::new(impl_id, concrete_trait_constant.trait_constant(self.db), self.db)
473 }
474
475 fn infer_trait_impl(
478 &mut self,
479 concrete_trait_impl: ConcreteTraitImplId,
480 lookup_context: &ImplLookupContext,
481 stable_ptr: Option<SyntaxStablePtrId>,
482 ) -> ImplImplId {
483 let impl_id = self.new_impl_var(
484 concrete_trait_impl.concrete_trait(self.db),
485 stable_ptr,
486 lookup_context.clone(),
487 );
488
489 ImplImplId::new(impl_id, concrete_trait_impl.trait_impl(self.db), self.db)
490 }
491}