// cairo_lang_semantic/expr/inference/solver.rs

1use cairo_lang_debug::DebugWithDb;
2use cairo_lang_defs::ids::LanguageElementId;
3use cairo_lang_proc_macros::SemanticObject;
4use cairo_lang_utils::LookupIntern;
5use itertools::Itertools;
6
7use super::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, MapperError, ResultNoErrEx};
8use super::conform::InferenceConform;
9use super::infers::InferenceEmbeddings;
10use super::{
11    ImplVarTraitItemMappings, InferenceData, InferenceError, InferenceId, InferenceResult,
12    InferenceVar, LocalImplVarId,
13};
14use crate::db::SemanticGroup;
15use crate::items::constant::ImplConstantId;
16use crate::items::imp::{
17    ImplId, ImplImplId, ImplLookupContext, UninferredImpl, find_candidates_at_context,
18    find_closure_generated_candidate,
19};
20use crate::substitution::SemanticRewriter;
21use crate::types::ImplTypeId;
22use crate::{ConcreteTraitId, GenericArgumentId, TypeId, TypeLongId};
23
24/// A generic solution set for an inference constraint system.
25#[derive(Clone, PartialEq, Eq, Debug)]
26pub enum SolutionSet<T> {
27    None,
28    Unique(T),
29    Ambiguous(Ambiguity),
30}
31
32/// Describes the kinds of inference ambiguities.
33#[derive(Clone, Debug, Eq, Hash, PartialEq, SemanticObject)]
34pub enum Ambiguity {
35    MultipleImplsFound {
36        concrete_trait_id: ConcreteTraitId,
37        impls: Vec<ImplId>,
38    },
39    FreeVariable {
40        impl_id: ImplId,
41        #[dont_rewrite]
42        var: InferenceVar,
43    },
44    WillNotInfer(ConcreteTraitId),
45    NegativeImplWithUnresolvedGenericArgs {
46        impl_id: ImplId,
47        ty: TypeId,
48    },
49}
50impl Ambiguity {
51    pub fn format(&self, db: &(dyn SemanticGroup + 'static)) -> String {
52        match self {
53            Ambiguity::MultipleImplsFound { concrete_trait_id, impls } => {
54                let impls_str =
55                    impls.iter().map(|imp| format!("`{}`", imp.format(db.upcast()))).join(", ");
56                format!(
57                    "Trait `{:?}` has multiple implementations, in: {impls_str}",
58                    concrete_trait_id.debug(db)
59                )
60            }
61            Ambiguity::FreeVariable { impl_id, var: _ } => {
62                format!("Candidate impl {:?} has an unused generic parameter.", impl_id.debug(db),)
63            }
64            Ambiguity::WillNotInfer(concrete_trait_id) => {
65                format!(
66                    "Cannot infer trait {:?}. First generic argument must be known.",
67                    concrete_trait_id.debug(db)
68                )
69            }
70            Ambiguity::NegativeImplWithUnresolvedGenericArgs { impl_id, ty } => format!(
71                "Cannot infer negative impl in `{}` as it contains the unresolved type `{}`",
72                impl_id.format(db),
73                ty.format(db)
74            ),
75        }
76    }
77}
78
79/// Query implementation of [SemanticGroup::canonic_trait_solutions].
80/// Assumes the lookup context is already enriched by [enrich_lookup_context].
81pub fn canonic_trait_solutions(
82    db: &dyn SemanticGroup,
83    canonical_trait: CanonicalTrait,
84    lookup_context: ImplLookupContext,
85) -> Result<SolutionSet<CanonicalImpl>, InferenceError> {
86    let mut concrete_trait_id = canonical_trait.id;
87    // If the trait is not fully concrete, we might be able to use the trait's items to find a
88    // more concrete trait.
89    if !concrete_trait_id.is_fully_concrete(db) {
90        let mut solver = Solver::new(db, canonical_trait, lookup_context.clone());
91        match solver.solution_set(db) {
92            SolutionSet::None => {}
93            SolutionSet::Unique(imp) => {
94                concrete_trait_id =
95                    imp.0.concrete_trait(db).expect("A solved impl must have a concrete trait");
96            }
97            SolutionSet::Ambiguous(ambiguity) => {
98                return Ok(SolutionSet::Ambiguous(ambiguity));
99            }
100        }
101    }
102    // Solve the trait without the trait items, so we'd be able to find conflicting impls.
103    let mut solver = Solver::new(
104        db,
105        CanonicalTrait { id: concrete_trait_id, mappings: ImplVarTraitItemMappings::default() },
106        lookup_context,
107    );
108
109    Ok(solver.solution_set(db))
110}
111
112/// Cycle handling for [canonic_trait_solutions].
113pub fn canonic_trait_solutions_cycle(
114    _db: &dyn SemanticGroup,
115    _cycle: &salsa::Cycle,
116    _canonical_trait: &CanonicalTrait,
117    _lookup_context: &ImplLookupContext,
118) -> Result<SolutionSet<CanonicalImpl>, InferenceError> {
119    Err(InferenceError::Cycle(InferenceVar::Impl(LocalImplVarId(0))))
120}
121
122/// Adds the defining module of the trait and the generic arguments to the lookup context.
123pub fn enrich_lookup_context(
124    db: &dyn SemanticGroup,
125    concrete_trait_id: ConcreteTraitId,
126    lookup_context: &mut ImplLookupContext,
127) {
128    lookup_context.insert_module(concrete_trait_id.trait_id(db).module_file_id(db.upcast()).0);
129    let generic_args = concrete_trait_id.generic_args(db);
130    // Add the defining module of the generic args to the lookup.
131    for generic_arg in &generic_args {
132        if let GenericArgumentId::Type(ty) = generic_arg {
133            match ty.lookup_intern(db) {
134                TypeLongId::Concrete(concrete) => {
135                    lookup_context
136                        .insert_module(concrete.generic_type(db).module_file_id(db.upcast()).0);
137                }
138                TypeLongId::Coupon(function_id) => {
139                    if let Some(module_file_id) =
140                        function_id.get_concrete(db).generic_function.module_file_id(db)
141                    {
142                        lookup_context.insert_module(module_file_id.0);
143                    }
144                }
145                TypeLongId::ImplType(impl_type_id) => {
146                    lookup_context.insert_impl(impl_type_id.impl_id());
147                }
148                _ => (),
149            }
150        }
151    }
152}
153
154/// A canonical trait solver.
155#[derive(Debug)]
156pub struct Solver {
157    pub canonical_trait: CanonicalTrait,
158    pub lookup_context: ImplLookupContext,
159    candidate_solvers: Vec<CandidateSolver>,
160}
161impl Solver {
162    fn new(
163        db: &dyn SemanticGroup,
164        canonical_trait: CanonicalTrait,
165        lookup_context: ImplLookupContext,
166    ) -> Self {
167        let filter = canonical_trait.id.filter(db);
168        let mut candidates =
169            find_candidates_at_context(db, &lookup_context, &filter).unwrap_or_default();
170        find_closure_generated_candidate(db, canonical_trait.id)
171            .map(|candidate| candidates.insert(candidate));
172        let candidate_solvers = candidates
173            .into_iter()
174            .filter_map(|candidate| {
175                CandidateSolver::new(db, &canonical_trait, candidate, &lookup_context).ok()
176            })
177            .collect();
178
179        Self { canonical_trait, lookup_context, candidate_solvers }
180    }
181
182    pub fn solution_set(&mut self, db: &dyn SemanticGroup) -> SolutionSet<CanonicalImpl> {
183        let mut unique_solution: Option<CanonicalImpl> = None;
184        for candidate_solver in &mut self.candidate_solvers {
185            let Ok(candidate_solution_set) = candidate_solver.solution_set(db) else {
186                continue;
187            };
188
189            let candidate_solution = match candidate_solution_set {
190                SolutionSet::None => continue,
191                SolutionSet::Unique(candidate_solution) => candidate_solution,
192                SolutionSet::Ambiguous(ambiguity) => return SolutionSet::Ambiguous(ambiguity),
193            };
194            if let Some(unique_solution) = unique_solution {
195                // There might be multiple unique solutions from different candidates that are
196                // solved to the same impl id (e.g. finding it near the trait, and
197                // through an impl alias). This is valid.
198                if unique_solution.0 != candidate_solution.0 {
199                    return SolutionSet::Ambiguous(Ambiguity::MultipleImplsFound {
200                        concrete_trait_id: self.canonical_trait.id,
201                        impls: vec![unique_solution.0, candidate_solution.0],
202                    });
203                }
204            }
205            unique_solution = Some(candidate_solution);
206        }
207        unique_solution.map(SolutionSet::Unique).unwrap_or(SolutionSet::None)
208    }
209}
210
/// A solver for a single candidate impl of a canonical trait.
///
/// [CandidateSolver::new] embeds the canonical trait into an inference table owned by this
/// solver and instantiates the candidate there. The actual solving happens in
/// [CandidateSolver::solution_set].
#[derive(Debug)]
pub struct CandidateSolver {
    /// The candidate impl being tried.
    pub candidate: UninferredImpl,
    /// The inference table owned by this solver (created with `InferenceId::Canonical`).
    inference_data: InferenceData,
    /// The mapping returned when the canonical trait was embedded. It is used to canonicalize
    /// the solved impl.
    canonical_embedding: CanonicalMapping,
    /// The candidate as instantiated in `inference_data`.
    candidate_impl: ImplId,
    /// The caller's lookup context, extended with the candidate's lookup scope.
    pub lookup_context: ImplLookupContext,
}
impl CandidateSolver {
    /// Creates a solver for `candidate` against `canonical_trait`.
    ///
    /// Steps:
    /// 1. Embed `canonical_trait` into a fresh inference table.
    /// 2. For a generated impl, conform its concrete trait with the embedded trait.
    /// 3. Instantiate the candidate.
    /// 4. Conform each of the candidate's trait types, constants and impls with the
    ///    corresponding mappings of the embedded trait.
    ///
    /// Returns an error if any of these inference steps fails. [Solver::new] drops such
    /// candidates.
    fn new(
        db: &dyn SemanticGroup,
        canonical_trait: &CanonicalTrait,
        candidate: UninferredImpl,
        lookup_context: &ImplLookupContext,
    ) -> InferenceResult<CandidateSolver> {
        let mut inference_data: InferenceData = InferenceData::new(InferenceId::Canonical);
        let mut inference = inference_data.inference(db);
        // Shadows the parameter with the embedded trait; the mapping is kept so the solved impl
        // can be canonicalized later.
        let (canonical_trait, canonical_embedding) = canonical_trait.embed(&mut inference);

        // If the closure params are not var free, we cannot infer the negative impl.
        // We use the canonical trait to concretize the closure params.
        if let UninferredImpl::GeneratedImpl(imp) = candidate {
            inference.conform_traits(imp.lookup_intern(db).concrete_trait, canonical_trait.id)?;
        }

        // Add the defining module of the candidate to the lookup.
        let mut lookup_context = lookup_context.clone();
        lookup_context.insert_lookup_scope(db, &candidate);
        // Instantiate the candidate in the inference table.
        let candidate_impl =
            inference.infer_impl(candidate, canonical_trait.id, &lookup_context, None)?;
        for (trait_type, ty) in canonical_trait.mappings.types.iter() {
            let mapped_ty =
                inference.reduce_impl_ty(ImplTypeId::new(candidate_impl, *trait_type, db))?;

            // Conform the candidate's type to the trait's type.
            inference.conform_ty(mapped_ty, *ty)?;
        }
        for (trait_const, const_id) in canonical_trait.mappings.constants.iter() {
            let mapped_const_id = inference.reduce_impl_constant(ImplConstantId::new(
                candidate_impl,
                *trait_const,
                db,
            ))?;
            // Conform the candidate's constant to the trait's constant.
            inference.conform_const(mapped_const_id, *const_id)?;
        }

        for (trait_impl, impl_id) in canonical_trait.mappings.impls.iter() {
            let mapped_impl_id =
                inference.reduce_impl_impl(ImplImplId::new(candidate_impl, *trait_impl, db))?;
            // Conform the candidate's impl to the trait's impl.
            inference.conform_impl(mapped_impl_id, *impl_id)?;
        }

        Ok(CandidateSolver {
            candidate,
            inference_data,
            canonical_embedding,
            candidate_impl,
            lookup_context,
        })
    }
    /// Runs inference for this candidate.
    ///
    /// `None` and `Ambiguous` results from the inference table are passed through unchanged.
    ///
    /// On a unique result, the instantiated candidate is rewritten with the inferred values and
    /// then canonicalized:
    /// - If canonicalization hits a free variable, the result is [Ambiguity::FreeVariable].
    /// - Otherwise negative impls are validated against this solver's lookup context, and that
    ///   validation produces the result.
    ///
    /// Errors from inference or from negative-impl validation are propagated.
    fn solution_set(
        &mut self,
        db: &dyn SemanticGroup,
    ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
        let mut inference = self.inference_data.inference(db);
        let solution_set = inference.solution_set()?;
        Ok(match solution_set {
            SolutionSet::None => SolutionSet::None,
            SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
            SolutionSet::Unique(_) => {
                // `no_err` comes from `ResultNoErrEx`; presumably this rewrite cannot fail —
                // confirm.
                let candidate_impl = inference.rewrite(self.candidate_impl).no_err();
                match CanonicalImpl::canonicalize(db, candidate_impl, &self.canonical_embedding) {
                    Ok(canonical_impl) => {
                        inference.validate_neg_impls(&self.lookup_context, canonical_impl)?
                    }
                    Err(MapperError(var)) => {
                        return Ok(SolutionSet::Ambiguous(Ambiguity::FreeVariable {
                            impl_id: candidate_impl,
                            var,
                        }));
                    }
                }
            }
        })
    }
}
300}