Skip to main content

cairo_lang_semantic/expr/inference/
infers.rs

1use cairo_lang_defs::ids::{ImplAliasId, ImplDefId, TraitFunctionId};
2use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
3use cairo_lang_utils::{Intern, LookupIntern, extract_matches, require};
4use itertools::Itertools;
5
6use super::canonic::ResultNoErrEx;
7use super::conform::InferenceConform;
8use super::{Inference, InferenceError, InferenceResult};
9use crate::items::constant::ImplConstantId;
10use crate::items::functions::{GenericFunctionId, ImplGenericFunctionId};
11use crate::items::generics::GenericParamConst;
12use crate::items::imp::{
13    GeneratedImplLongId, ImplId, ImplImplId, ImplLongId, ImplLookupContext, UninferredImpl,
14};
15use crate::items::trt::{
16    ConcreteTraitConstantId, ConcreteTraitGenericFunctionId, ConcreteTraitImplId,
17    ConcreteTraitTypeId,
18};
19use crate::keyword::SELF_PARAM_KW;
20use crate::substitution::{GenericSubstitution, SemanticRewriter};
21use crate::types::ImplTypeId;
22use crate::{
23    ConcreteFunction, ConcreteImplLongId, ConcreteTraitId, ConcreteTraitLongId, FunctionId,
24    FunctionLongId, GenericArgumentId, GenericParam, TypeId, TypeLongId,
25};
26
27/// Functions for embedding generic semantic objects in an existing [Inference] object, by
28/// introducing new variables.
29pub trait InferenceEmbeddings {
30    fn infer_impl(
31        &mut self,
32        uninferred_impl: UninferredImpl,
33        concrete_trait_id: ConcreteTraitId,
34        lookup_context: &ImplLookupContext,
35        stable_ptr: Option<SyntaxStablePtrId>,
36    ) -> InferenceResult<ImplId>;
37    fn infer_impl_def(
38        &mut self,
39        impl_def_id: ImplDefId,
40        concrete_trait_id: ConcreteTraitId,
41        lookup_context: &ImplLookupContext,
42        stable_ptr: Option<SyntaxStablePtrId>,
43    ) -> InferenceResult<ImplId>;
44    fn infer_impl_alias(
45        &mut self,
46        impl_alias_id: ImplAliasId,
47        concrete_trait_id: ConcreteTraitId,
48        lookup_context: &ImplLookupContext,
49        stable_ptr: Option<SyntaxStablePtrId>,
50    ) -> InferenceResult<ImplId>;
51    fn infer_generic_assignment(
52        &mut self,
53        generic_params: &[GenericParam],
54        generic_args: &[GenericArgumentId],
55        expected_generic_args: &[GenericArgumentId],
56        lookup_context: &ImplLookupContext,
57        stable_ptr: Option<SyntaxStablePtrId>,
58    ) -> InferenceResult<Vec<GenericArgumentId>>;
59    fn infer_generic_args(
60        &mut self,
61        generic_params: &[GenericParam],
62        lookup_context: &ImplLookupContext,
63        stable_ptr: Option<SyntaxStablePtrId>,
64    ) -> InferenceResult<Vec<GenericArgumentId>>;
65    fn infer_concrete_trait_by_self(
66        &mut self,
67        trait_function: TraitFunctionId,
68        self_ty: TypeId,
69        lookup_context: &ImplLookupContext,
70        stable_ptr: Option<SyntaxStablePtrId>,
71        inference_error_cb: impl FnOnce(InferenceError),
72    ) -> Option<(ConcreteTraitId, usize)>;
73    fn infer_generic_arg(
74        &mut self,
75        param: &GenericParam,
76        lookup_context: ImplLookupContext,
77        stable_ptr: Option<SyntaxStablePtrId>,
78    ) -> InferenceResult<GenericArgumentId>;
79    fn infer_trait_function(
80        &mut self,
81        concrete_trait_function: ConcreteTraitGenericFunctionId,
82        lookup_context: &ImplLookupContext,
83        stable_ptr: Option<SyntaxStablePtrId>,
84    ) -> InferenceResult<FunctionId>;
85    fn infer_generic_function(
86        &mut self,
87        generic_function: GenericFunctionId,
88        lookup_context: &ImplLookupContext,
89        stable_ptr: Option<SyntaxStablePtrId>,
90    ) -> InferenceResult<FunctionId>;
91    fn infer_trait_generic_function(
92        &mut self,
93        concrete_trait_function: ConcreteTraitGenericFunctionId,
94        lookup_context: &ImplLookupContext,
95        stable_ptr: Option<SyntaxStablePtrId>,
96    ) -> ImplGenericFunctionId;
97    fn infer_trait_type(
98        &mut self,
99        concrete_trait_type: ConcreteTraitTypeId,
100        lookup_context: &ImplLookupContext,
101        stable_ptr: Option<SyntaxStablePtrId>,
102    ) -> TypeId;
103    fn infer_trait_constant(
104        &mut self,
105        concrete_trait_constant: ConcreteTraitConstantId,
106        lookup_context: &ImplLookupContext,
107        stable_ptr: Option<SyntaxStablePtrId>,
108    ) -> ImplConstantId;
109    fn infer_trait_impl(
110        &mut self,
111        concrete_trait_constant: ConcreteTraitImplId,
112        lookup_context: &ImplLookupContext,
113        stable_ptr: Option<SyntaxStablePtrId>,
114    ) -> ImplImplId;
115}
116
117impl InferenceEmbeddings for Inference<'_> {
118    /// Infers all the variables required to make an uninferred impl provide a concrete trait.
119    fn infer_impl(
120        &mut self,
121        uninferred_impl: UninferredImpl,
122        concrete_trait_id: ConcreteTraitId,
123        lookup_context: &ImplLookupContext,
124        stable_ptr: Option<SyntaxStablePtrId>,
125    ) -> InferenceResult<ImplId> {
126        let impl_id = match uninferred_impl {
127            UninferredImpl::Def(impl_def_id) => {
128                self.infer_impl_def(impl_def_id, concrete_trait_id, lookup_context, stable_ptr)?
129            }
130            UninferredImpl::ImplAlias(impl_alias_id) => {
131                self.infer_impl_alias(impl_alias_id, concrete_trait_id, lookup_context, stable_ptr)?
132            }
133            UninferredImpl::ImplImpl(impl_impl_id) => {
134                ImplLongId::ImplImpl(impl_impl_id).intern(self.db)
135            }
136            UninferredImpl::GenericParam(param_id) => {
137                let param = self
138                    .db
139                    .generic_param_semantic(param_id)
140                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
141                let param = extract_matches!(param, GenericParam::Impl);
142                let imp_concrete_trait_id = param.concrete_trait.unwrap();
143                self.conform_traits(concrete_trait_id, imp_concrete_trait_id)?;
144                ImplLongId::GenericParameter(param_id).intern(self.db)
145            }
146            UninferredImpl::GeneratedImpl(generated_impl) => {
147                let long_id = generated_impl.lookup_intern(self.db);
148
149                // Only making sure the args can be inferred - as they are unused later.
150                self.infer_generic_args(&long_id.generic_params[..], lookup_context, stable_ptr)?;
151
152                ImplLongId::GeneratedImpl(
153                    GeneratedImplLongId {
154                        concrete_trait: long_id.concrete_trait,
155                        generic_params: long_id.generic_params,
156                        impl_items: long_id.impl_items,
157                    }
158                    .intern(self.db),
159                )
160                .intern(self.db)
161            }
162        };
163        Ok(impl_id)
164    }
165
166    /// Infers all the variables required to make an impl (possibly with free generic params)
167    /// provide a concrete trait.
168    fn infer_impl_def(
169        &mut self,
170        impl_def_id: ImplDefId,
171        concrete_trait_id: ConcreteTraitId,
172        lookup_context: &ImplLookupContext,
173        stable_ptr: Option<SyntaxStablePtrId>,
174    ) -> InferenceResult<ImplId> {
175        let imp_generic_params = self
176            .db
177            .impl_def_generic_params(impl_def_id)
178            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
179        let imp_concrete_trait = self
180            .db
181            .impl_def_concrete_trait(impl_def_id)
182            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
183        if imp_concrete_trait.trait_id(self.db) != concrete_trait_id.trait_id(self.db) {
184            return Err(self.set_error(InferenceError::TraitMismatch {
185                trt0: imp_concrete_trait.trait_id(self.db),
186                trt1: concrete_trait_id.trait_id(self.db),
187            }));
188        }
189
190        let long_concrete_trait = concrete_trait_id.lookup_intern(self.db);
191        let long_imp_concrete_trait = imp_concrete_trait.lookup_intern(self.db);
192        let generic_args = self.infer_generic_assignment(
193            &imp_generic_params,
194            &long_imp_concrete_trait.generic_args,
195            &long_concrete_trait.generic_args,
196            lookup_context,
197            stable_ptr,
198        )?;
199        Ok(ImplLongId::Concrete(ConcreteImplLongId { impl_def_id, generic_args }.intern(self.db))
200            .intern(self.db))
201    }
202
203    /// Infers all the variables required to make an impl alias (possibly with free generic params)
204    /// provide a concrete trait.
205    fn infer_impl_alias(
206        &mut self,
207        impl_alias_id: ImplAliasId,
208        concrete_trait_id: ConcreteTraitId,
209        lookup_context: &ImplLookupContext,
210        stable_ptr: Option<SyntaxStablePtrId>,
211    ) -> InferenceResult<ImplId> {
212        let impl_alias_generic_params = self
213            .db
214            .impl_alias_generic_params(impl_alias_id)
215            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
216        let impl_id = self
217            .db
218            .impl_alias_resolved_impl(impl_alias_id)
219            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
220        let imp_concrete_trait = impl_id
221            .concrete_trait(self.db)
222            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
223        if imp_concrete_trait.trait_id(self.db) != concrete_trait_id.trait_id(self.db) {
224            return Err(self.set_error(InferenceError::TraitMismatch {
225                trt0: imp_concrete_trait.trait_id(self.db),
226                trt1: concrete_trait_id.trait_id(self.db),
227            }));
228        }
229
230        let long_concrete_trait = concrete_trait_id.lookup_intern(self.db);
231        let long_imp_concrete_trait = imp_concrete_trait.lookup_intern(self.db);
232        let generic_args = self.infer_generic_assignment(
233            &impl_alias_generic_params,
234            &long_imp_concrete_trait.generic_args,
235            &long_concrete_trait.generic_args,
236            lookup_context,
237            stable_ptr,
238        )?;
239
240        GenericSubstitution::new(&impl_alias_generic_params, &generic_args)
241            .substitute(self.db, impl_id)
242            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))
243    }
244
245    /// Chooses and assignment to generic_params s.t. generic_args will be substituted to
246    /// expected_generic_args.
247    /// Returns the generic_params assignment.
248    fn infer_generic_assignment(
249        &mut self,
250        generic_params: &[GenericParam],
251        generic_args: &[GenericArgumentId],
252        expected_generic_args: &[GenericArgumentId],
253        lookup_context: &ImplLookupContext,
254        stable_ptr: Option<SyntaxStablePtrId>,
255    ) -> InferenceResult<Vec<GenericArgumentId>> {
256        let new_generic_args =
257            self.infer_generic_args(generic_params, lookup_context, stable_ptr)?;
258        let substitution = GenericSubstitution::new(generic_params, &new_generic_args);
259        let generic_args = substitution
260            .substitute(self.db, generic_args.iter().copied().collect_vec())
261            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
262        self.conform_generic_args(&generic_args, expected_generic_args)?;
263        Ok(self.rewrite(new_generic_args).no_err())
264    }
265
266    /// Infers all generic_arguments given the parameters.
267    fn infer_generic_args(
268        &mut self,
269        generic_params: &[GenericParam],
270        lookup_context: &ImplLookupContext,
271        stable_ptr: Option<SyntaxStablePtrId>,
272    ) -> InferenceResult<Vec<GenericArgumentId>> {
273        let mut generic_args = vec![];
274        let mut substitution = GenericSubstitution::default();
275        for generic_param in generic_params {
276            let generic_param = substitution
277                .substitute(self.db, generic_param.clone())
278                .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
279            let generic_arg =
280                self.infer_generic_arg(&generic_param, lookup_context.clone(), stable_ptr)?;
281            generic_args.push(generic_arg);
282            substitution.insert(generic_param.id(), generic_arg);
283        }
284        Ok(generic_args)
285    }
286
287    /// Tries to infer a trait function as a method for `self_ty`.
288    /// Supports snapshot coercions.
289    ///
290    /// Returns the deduced type and the number of snapshots that need to be added to it.
291    ///
292    /// `inference_error_cb` is called for inference errors, but they are not reported here as
293    /// diagnostics. The caller has to make sure the diagnostics are reported appropriately.
294    fn infer_concrete_trait_by_self(
295        &mut self,
296        trait_function: TraitFunctionId,
297        self_ty: TypeId,
298        lookup_context: &ImplLookupContext,
299        stable_ptr: Option<SyntaxStablePtrId>,
300        inference_error_cb: impl FnOnce(InferenceError),
301    ) -> Option<(ConcreteTraitId, usize)> {
302        let trait_id = trait_function.trait_id(self.db);
303        let signature = self.db.trait_function_signature(trait_function).ok()?;
304        let first_param = signature.params.into_iter().next()?;
305        require(first_param.name == SELF_PARAM_KW)?;
306
307        let trait_generic_params = self.db.trait_generic_params(trait_id).ok()?;
308        let trait_generic_args =
309            match self.infer_generic_args(&trait_generic_params, lookup_context, stable_ptr) {
310                Ok(generic_args) => generic_args,
311                Err(err_set) => {
312                    if let Some(err) = self.consume_error_without_reporting(err_set) {
313                        inference_error_cb(err);
314                    }
315                    return None;
316                }
317            };
318
319        // TODO(yuval): Try to not temporary clone.
320        let mut tmp_inference_data = self.temporary_clone();
321        let mut tmp_inference = tmp_inference_data.inference(self.db);
322        let function_generic_params =
323            tmp_inference.db.trait_function_generic_params(trait_function).ok()?;
324        let function_generic_args =
325            // TODO(yuval): consider getting the substitution from inside `infer_generic_args`
326            // instead of creating it again here.
327            match tmp_inference.infer_generic_args(&function_generic_params, lookup_context, stable_ptr) {
328                Ok(generic_args) => generic_args,
329                Err(err_set) => {
330                    if let Some(err) = self.consume_error_without_reporting(err_set) {
331                        inference_error_cb(err);
332                    }
333                    return None;
334                }
335            };
336
337        let trait_substitution =
338            GenericSubstitution::new(&trait_generic_params, &trait_generic_args);
339        let function_substitution =
340            GenericSubstitution::new(&function_generic_params, &function_generic_args);
341        let substitution = trait_substitution.concat(function_substitution);
342
343        let fixed_param_ty = substitution.substitute(self.db, first_param.ty).ok()?;
344        let (_, n_snapshots) = match self.conform_ty_ex(self_ty, fixed_param_ty, true) {
345            Ok(conform) => conform,
346            Err(err_set) => {
347                if let Some(err) = self.consume_error_without_reporting(err_set) {
348                    inference_error_cb(err);
349                }
350                return None;
351            }
352        };
353
354        let generic_args = self.rewrite(trait_generic_args).no_err();
355
356        Some((ConcreteTraitLongId { trait_id, generic_args }.intern(self.db), n_snapshots))
357    }
358
359    /// Infers a generic argument to be passed as a generic parameter.
360    /// Allocates a new inference variable of the correct kind, and wraps in a generic argument.
361    fn infer_generic_arg(
362        &mut self,
363        param: &GenericParam,
364        lookup_context: ImplLookupContext,
365        stable_ptr: Option<SyntaxStablePtrId>,
366    ) -> InferenceResult<GenericArgumentId> {
367        match param {
368            GenericParam::Type(_) => Ok(GenericArgumentId::Type(self.new_type_var(stable_ptr))),
369            GenericParam::Impl(param) => {
370                let concrete_trait_id = param
371                    .concrete_trait
372                    .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
373                let impl_id = self.new_impl_var(concrete_trait_id, stable_ptr, lookup_context);
374                for (trait_ty, ty1) in param.type_constraints.iter() {
375                    let ty0 = self.reduce_impl_ty(ImplTypeId::new(impl_id, *trait_ty, self.db))?;
376                    // Conforming the type will always work as the impl is a new inference variable.
377                    self.conform_ty(ty0, *ty1).ok();
378                }
379                Ok(GenericArgumentId::Impl(impl_id))
380            }
381            GenericParam::Const(GenericParamConst { ty, .. }) => {
382                Ok(GenericArgumentId::Constant(self.new_const_var(stable_ptr, *ty)))
383            }
384            GenericParam::NegImpl(_) => Ok(GenericArgumentId::NegImpl),
385        }
386    }
387
388    /// Infers the impl to be substituted instead of a trait for a given trait function,
389    /// and the generic arguments to be passed to the function.
390    /// Returns the resulting impl function.
391    fn infer_trait_function(
392        &mut self,
393        concrete_trait_function: ConcreteTraitGenericFunctionId,
394        lookup_context: &ImplLookupContext,
395        stable_ptr: Option<SyntaxStablePtrId>,
396    ) -> InferenceResult<FunctionId> {
397        let generic_function = GenericFunctionId::Impl(self.infer_trait_generic_function(
398            concrete_trait_function,
399            lookup_context,
400            stable_ptr,
401        ));
402        self.infer_generic_function(generic_function, lookup_context, stable_ptr)
403    }
404
405    /// Infers generic arguments to be passed to a generic function.
406    /// Returns the resulting specialized function.
407    fn infer_generic_function(
408        &mut self,
409        generic_function: GenericFunctionId,
410        lookup_context: &ImplLookupContext,
411        stable_ptr: Option<SyntaxStablePtrId>,
412    ) -> InferenceResult<FunctionId> {
413        let generic_params = generic_function
414            .generic_params(self.db)
415            .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?;
416        let generic_args = self.infer_generic_args(&generic_params, lookup_context, stable_ptr)?;
417        Ok(FunctionLongId { function: ConcreteFunction { generic_function, generic_args } }
418            .intern(self.db))
419    }
420
421    /// Infers the impl to be substituted instead of a trait for a given trait function.
422    /// Returns the resulting impl generic function.
423    fn infer_trait_generic_function(
424        &mut self,
425        concrete_trait_function: ConcreteTraitGenericFunctionId,
426        lookup_context: &ImplLookupContext,
427        stable_ptr: Option<SyntaxStablePtrId>,
428    ) -> ImplGenericFunctionId {
429        let impl_id = self.new_impl_var(
430            concrete_trait_function.concrete_trait(self.db),
431            stable_ptr,
432            lookup_context.clone(),
433        );
434        ImplGenericFunctionId { impl_id, function: concrete_trait_function.trait_function(self.db) }
435    }
436
437    /// Infers the impl to be substituted instead of a trait for a given trait type.
438    /// Returns the resulting impl type.
439    fn infer_trait_type(
440        &mut self,
441        concrete_trait_type: ConcreteTraitTypeId,
442        lookup_context: &ImplLookupContext,
443        stable_ptr: Option<SyntaxStablePtrId>,
444    ) -> TypeId {
445        let impl_id = self.new_impl_var(
446            concrete_trait_type.concrete_trait(self.db),
447            stable_ptr,
448            lookup_context.clone(),
449        );
450        TypeLongId::ImplType(ImplTypeId::new(
451            impl_id,
452            concrete_trait_type.trait_type(self.db),
453            self.db,
454        ))
455        .intern(self.db)
456    }
457
458    /// Infers the impl to be substituted instead of a trait for a given trait constant.
459    /// Returns the resulting impl constant.
460    fn infer_trait_constant(
461        &mut self,
462        concrete_trait_constant: ConcreteTraitConstantId,
463        lookup_context: &ImplLookupContext,
464        stable_ptr: Option<SyntaxStablePtrId>,
465    ) -> ImplConstantId {
466        let impl_id = self.new_impl_var(
467            concrete_trait_constant.concrete_trait(self.db),
468            stable_ptr,
469            lookup_context.clone(),
470        );
471
472        ImplConstantId::new(impl_id, concrete_trait_constant.trait_constant(self.db), self.db)
473    }
474
475    /// Infers the impl to be substituted instead of a trait for a given trait impl.
476    /// Returns the resulting impl impl.
477    fn infer_trait_impl(
478        &mut self,
479        concrete_trait_impl: ConcreteTraitImplId,
480        lookup_context: &ImplLookupContext,
481        stable_ptr: Option<SyntaxStablePtrId>,
482    ) -> ImplImplId {
483        let impl_id = self.new_impl_var(
484            concrete_trait_impl.concrete_trait(self.db),
485            stable_ptr,
486            lookup_context.clone(),
487        );
488
489        ImplImplId::new(impl_id, concrete_trait_impl.trait_impl(self.db), self.db)
490    }
491}