cairo_lang_semantic/expr/inference/
solver.rs1use cairo_lang_debug::DebugWithDb;
2use cairo_lang_defs::ids::LanguageElementId;
3use cairo_lang_proc_macros::SemanticObject;
4use cairo_lang_utils::LookupIntern;
5use itertools::Itertools;
6
7use super::canonic::{CanonicalImpl, CanonicalMapping, CanonicalTrait, MapperError, ResultNoErrEx};
8use super::conform::InferenceConform;
9use super::infers::InferenceEmbeddings;
10use super::{
11 ImplVarTraitItemMappings, InferenceData, InferenceError, InferenceId, InferenceResult,
12 InferenceVar, LocalImplVarId,
13};
14use crate::db::SemanticGroup;
15use crate::items::constant::ImplConstantId;
16use crate::items::imp::{
17 ImplId, ImplImplId, ImplLookupContext, UninferredImpl, find_candidates_at_context,
18 find_closure_generated_candidate,
19};
20use crate::substitution::SemanticRewriter;
21use crate::types::ImplTypeId;
22use crate::{ConcreteTraitId, GenericArgumentId, TypeId, TypeLongId};
23
24#[derive(Clone, PartialEq, Eq, Debug)]
26pub enum SolutionSet<T> {
27 None,
28 Unique(T),
29 Ambiguous(Ambiguity),
30}
31
32#[derive(Clone, Debug, Eq, Hash, PartialEq, SemanticObject)]
34pub enum Ambiguity {
35 MultipleImplsFound {
36 concrete_trait_id: ConcreteTraitId,
37 impls: Vec<ImplId>,
38 },
39 FreeVariable {
40 impl_id: ImplId,
41 #[dont_rewrite]
42 var: InferenceVar,
43 },
44 WillNotInfer(ConcreteTraitId),
45 NegativeImplWithUnresolvedGenericArgs {
46 impl_id: ImplId,
47 ty: TypeId,
48 },
49}
50impl Ambiguity {
51 pub fn format(&self, db: &(dyn SemanticGroup + 'static)) -> String {
52 match self {
53 Ambiguity::MultipleImplsFound { concrete_trait_id, impls } => {
54 let impls_str =
55 impls.iter().map(|imp| format!("`{}`", imp.format(db.upcast()))).join(", ");
56 format!(
57 "Trait `{:?}` has multiple implementations, in: {impls_str}",
58 concrete_trait_id.debug(db)
59 )
60 }
61 Ambiguity::FreeVariable { impl_id, var: _ } => {
62 format!("Candidate impl {:?} has an unused generic parameter.", impl_id.debug(db),)
63 }
64 Ambiguity::WillNotInfer(concrete_trait_id) => {
65 format!(
66 "Cannot infer trait {:?}. First generic argument must be known.",
67 concrete_trait_id.debug(db)
68 )
69 }
70 Ambiguity::NegativeImplWithUnresolvedGenericArgs { impl_id, ty } => format!(
71 "Cannot infer negative impl in `{}` as it contains the unresolved type `{}`",
72 impl_id.format(db),
73 ty.format(db)
74 ),
75 }
76 }
77}
78
79pub fn canonic_trait_solutions(
82 db: &dyn SemanticGroup,
83 canonical_trait: CanonicalTrait,
84 lookup_context: ImplLookupContext,
85) -> Result<SolutionSet<CanonicalImpl>, InferenceError> {
86 let mut concrete_trait_id = canonical_trait.id;
87 if !concrete_trait_id.is_fully_concrete(db) {
90 let mut solver = Solver::new(db, canonical_trait, lookup_context.clone());
91 match solver.solution_set(db) {
92 SolutionSet::None => {}
93 SolutionSet::Unique(imp) => {
94 concrete_trait_id =
95 imp.0.concrete_trait(db).expect("A solved impl must have a concrete trait");
96 }
97 SolutionSet::Ambiguous(ambiguity) => {
98 return Ok(SolutionSet::Ambiguous(ambiguity));
99 }
100 }
101 }
102 let mut solver = Solver::new(
104 db,
105 CanonicalTrait { id: concrete_trait_id, mappings: ImplVarTraitItemMappings::default() },
106 lookup_context,
107 );
108
109 Ok(solver.solution_set(db))
110}
111
112pub fn canonic_trait_solutions_cycle(
114 _db: &dyn SemanticGroup,
115 _cycle: &salsa::Cycle,
116 _canonical_trait: &CanonicalTrait,
117 _lookup_context: &ImplLookupContext,
118) -> Result<SolutionSet<CanonicalImpl>, InferenceError> {
119 Err(InferenceError::Cycle(InferenceVar::Impl(LocalImplVarId(0))))
120}
121
122pub fn enrich_lookup_context(
124 db: &dyn SemanticGroup,
125 concrete_trait_id: ConcreteTraitId,
126 lookup_context: &mut ImplLookupContext,
127) {
128 lookup_context.insert_module(concrete_trait_id.trait_id(db).module_file_id(db.upcast()).0);
129 let generic_args = concrete_trait_id.generic_args(db);
130 for generic_arg in &generic_args {
132 if let GenericArgumentId::Type(ty) = generic_arg {
133 match ty.lookup_intern(db) {
134 TypeLongId::Concrete(concrete) => {
135 lookup_context
136 .insert_module(concrete.generic_type(db).module_file_id(db.upcast()).0);
137 }
138 TypeLongId::Coupon(function_id) => {
139 if let Some(module_file_id) =
140 function_id.get_concrete(db).generic_function.module_file_id(db)
141 {
142 lookup_context.insert_module(module_file_id.0);
143 }
144 }
145 TypeLongId::ImplType(impl_type_id) => {
146 lookup_context.insert_impl(impl_type_id.impl_id());
147 }
148 _ => (),
149 }
150 }
151 }
152}
153
154#[derive(Debug)]
156pub struct Solver {
157 pub canonical_trait: CanonicalTrait,
158 pub lookup_context: ImplLookupContext,
159 candidate_solvers: Vec<CandidateSolver>,
160}
161impl Solver {
162 fn new(
163 db: &dyn SemanticGroup,
164 canonical_trait: CanonicalTrait,
165 lookup_context: ImplLookupContext,
166 ) -> Self {
167 let filter = canonical_trait.id.filter(db);
168 let mut candidates =
169 find_candidates_at_context(db, &lookup_context, &filter).unwrap_or_default();
170 find_closure_generated_candidate(db, canonical_trait.id)
171 .map(|candidate| candidates.insert(candidate));
172 let candidate_solvers = candidates
173 .into_iter()
174 .filter_map(|candidate| {
175 CandidateSolver::new(db, &canonical_trait, candidate, &lookup_context).ok()
176 })
177 .collect();
178
179 Self { canonical_trait, lookup_context, candidate_solvers }
180 }
181
182 pub fn solution_set(&mut self, db: &dyn SemanticGroup) -> SolutionSet<CanonicalImpl> {
183 let mut unique_solution: Option<CanonicalImpl> = None;
184 for candidate_solver in &mut self.candidate_solvers {
185 let Ok(candidate_solution_set) = candidate_solver.solution_set(db) else {
186 continue;
187 };
188
189 let candidate_solution = match candidate_solution_set {
190 SolutionSet::None => continue,
191 SolutionSet::Unique(candidate_solution) => candidate_solution,
192 SolutionSet::Ambiguous(ambiguity) => return SolutionSet::Ambiguous(ambiguity),
193 };
194 if let Some(unique_solution) = unique_solution {
195 if unique_solution.0 != candidate_solution.0 {
199 return SolutionSet::Ambiguous(Ambiguity::MultipleImplsFound {
200 concrete_trait_id: self.canonical_trait.id,
201 impls: vec![unique_solution.0, candidate_solution.0],
202 });
203 }
204 }
205 unique_solution = Some(candidate_solution);
206 }
207 unique_solution.map(SolutionSet::Unique).unwrap_or(SolutionSet::None)
208 }
209}
210
211#[derive(Debug)]
213pub struct CandidateSolver {
214 pub candidate: UninferredImpl,
215 inference_data: InferenceData,
216 canonical_embedding: CanonicalMapping,
217 candidate_impl: ImplId,
218 pub lookup_context: ImplLookupContext,
219}
220impl CandidateSolver {
221 fn new(
222 db: &dyn SemanticGroup,
223 canonical_trait: &CanonicalTrait,
224 candidate: UninferredImpl,
225 lookup_context: &ImplLookupContext,
226 ) -> InferenceResult<CandidateSolver> {
227 let mut inference_data: InferenceData = InferenceData::new(InferenceId::Canonical);
228 let mut inference = inference_data.inference(db);
229 let (canonical_trait, canonical_embedding) = canonical_trait.embed(&mut inference);
230
231 if let UninferredImpl::GeneratedImpl(imp) = candidate {
234 inference.conform_traits(imp.lookup_intern(db).concrete_trait, canonical_trait.id)?;
235 }
236
237 let mut lookup_context = lookup_context.clone();
239 lookup_context.insert_lookup_scope(db, &candidate);
240 let candidate_impl =
242 inference.infer_impl(candidate, canonical_trait.id, &lookup_context, None)?;
243 for (trait_type, ty) in canonical_trait.mappings.types.iter() {
244 let mapped_ty =
245 inference.reduce_impl_ty(ImplTypeId::new(candidate_impl, *trait_type, db))?;
246
247 inference.conform_ty(mapped_ty, *ty)?;
249 }
250 for (trait_const, const_id) in canonical_trait.mappings.constants.iter() {
251 let mapped_const_id = inference.reduce_impl_constant(ImplConstantId::new(
252 candidate_impl,
253 *trait_const,
254 db,
255 ))?;
256 inference.conform_const(mapped_const_id, *const_id)?;
258 }
259
260 for (trait_impl, impl_id) in canonical_trait.mappings.impls.iter() {
261 let mapped_impl_id =
262 inference.reduce_impl_impl(ImplImplId::new(candidate_impl, *trait_impl, db))?;
263 inference.conform_impl(mapped_impl_id, *impl_id)?;
265 }
266
267 Ok(CandidateSolver {
268 candidate,
269 inference_data,
270 canonical_embedding,
271 candidate_impl,
272 lookup_context,
273 })
274 }
275 fn solution_set(
276 &mut self,
277 db: &dyn SemanticGroup,
278 ) -> InferenceResult<SolutionSet<CanonicalImpl>> {
279 let mut inference = self.inference_data.inference(db);
280 let solution_set = inference.solution_set()?;
281 Ok(match solution_set {
282 SolutionSet::None => SolutionSet::None,
283 SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
284 SolutionSet::Unique(_) => {
285 let candidate_impl = inference.rewrite(self.candidate_impl).no_err();
286 match CanonicalImpl::canonicalize(db, candidate_impl, &self.canonical_embedding) {
287 Ok(canonical_impl) => {
288 inference.validate_neg_impls(&self.lookup_context, canonical_impl)?
289 }
290 Err(MapperError(var)) => {
291 return Ok(SolutionSet::Ambiguous(Ambiguity::FreeVariable {
292 impl_id: candidate_impl,
293 var,
294 }));
295 }
296 }
297 }
298 })
299 }
300}