Skip to main content

cairo_lang_semantic/expr/inference/
solver.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use cairo_lang_debug::DebugWithDb;
5use cairo_lang_defs::ids::{GenericParamId, LanguageElementId};
6use cairo_lang_proc_macros::SemanticObject;
7use cairo_lang_utils::Intern;
8use cairo_lang_utils::ordered_hash_map::Entry;
9use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
10use itertools::{Itertools, chain, zip_eq};
11use salsa::Database;
12
13use super::canonic::{CanonicalImpl, CanonicalTrait, MapperError, ResultNoErrEx};
14use super::conform::InferenceConform;
15use super::infers::InferenceEmbeddings;
16use super::{
17    ImplVarTraitItemMappings, InferenceData, InferenceError, InferenceId, InferenceResult,
18    InferenceVar, LocalImplVarId,
19};
20use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
21use crate::items::imp::{
22    ImplId, ImplImplId, ImplLongId, ImplLookupContext, ImplLookupContextId, ImplSemantic,
23    UninferredImpl, UninferredImplById, find_candidates_at_context,
24    find_closure_generated_candidate, find_integer_literal_generated_candidate,
25};
26use crate::items::trt::TraitSemantic;
27use crate::substitution::{GenericSubstitution, SemanticRewriter};
28use crate::types::{ImplTypeById, ImplTypeId};
29use crate::{
30    ConcreteImplLongId, ConcreteTraitId, GenericArgumentId, GenericParam, TypeId, TypeLongId,
31};
32
33/// A generic solution set for an inference constraint system.
34#[derive(Clone, PartialEq, Eq, Debug)]
35pub enum SolutionSet<'db, T> {
36    None,
37    Unique(T),
38    Ambiguous(Ambiguity<'db>),
39}
40
41// Somewhat taken from the salsa::Update derive macro.
42unsafe impl<'db, T: salsa::Update> salsa::Update for SolutionSet<'db, T> {
43    unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
44        let old_pointer = unsafe { &mut *old_pointer };
45        match (old_pointer, new_value) {
46            (SolutionSet::None, SolutionSet::None) => false,
47            (SolutionSet::Unique(u1), SolutionSet::Unique(u2)) => unsafe {
48                salsa::plumbing::UpdateDispatch::<T>::maybe_update(u1, u2)
49            },
50            (SolutionSet::Ambiguous(ambiguity), SolutionSet::Ambiguous(ambiguity2)) => unsafe {
51                salsa::plumbing::UpdateDispatch::<Ambiguity<'db>>::maybe_update(
52                    ambiguity, ambiguity2,
53                )
54            },
55            (old_pointer, new_value) => {
56                *old_pointer = new_value;
57                true
58            }
59        }
60    }
61}
62
63/// Describes the kinds of inference ambiguities.
64#[derive(Clone, Debug, Eq, Hash, PartialEq, SemanticObject, salsa::Update)]
65pub enum Ambiguity<'db> {
66    MultipleImplsFound {
67        concrete_trait_id: ConcreteTraitId<'db>,
68        impls: Vec<ImplId<'db>>,
69    },
70    FreeVariable {
71        impl_id: ImplId<'db>,
72        #[dont_rewrite]
73        var: InferenceVar,
74    },
75    WillNotInfer(ConcreteTraitId<'db>),
76    NegativeImplWithUnsupportedExtractedArgs(GenericArgumentId<'db>),
77    NegativeImplWithUnsupportedGenericParam(GenericParamId<'db>),
78}
79
80impl<'db> Ambiguity<'db> {
81    pub fn format(&self, db: &dyn Database) -> String {
82        match self {
83            Ambiguity::MultipleImplsFound { concrete_trait_id, impls } => {
84                let impls_str = impls.iter().map(|imp| format!("`{}`", imp.format(db))).join(", ");
85                format!(
86                    "Trait `{:?}` has multiple implementations, in: {impls_str}",
87                    concrete_trait_id.debug(db)
88                )
89            }
90            Ambiguity::FreeVariable { impl_id, var: _ } => {
91                format!("Candidate impl {:?} has an unused generic parameter.", impl_id.debug(db),)
92            }
93            Ambiguity::WillNotInfer(concrete_trait_id) => {
94                format!(
95                    "Cannot infer trait {:?}. First generic argument must be known.",
96                    concrete_trait_id.debug(db)
97                )
98            }
99            Ambiguity::NegativeImplWithUnsupportedExtractedArgs(garg) => {
100                format!("Negative impl has an unsupported generic argument {:?}.", garg.debug(db),)
101            }
102            Ambiguity::NegativeImplWithUnsupportedGenericParam(gparam) => {
103                format!("Negative impl has an unsupported generic argument {:?}.", gparam.debug(db),)
104            }
105        }
106    }
107}
108
109/// Implementation of [SemanticSolver::canonic_trait_solutions].
110/// Assumes the lookup context is already enriched by [enrich_lookup_context].
111pub fn canonic_trait_solutions<'db>(
112    db: &'db dyn Database,
113    canonical_trait: CanonicalTrait<'db>,
114    lookup_context: ImplLookupContextId<'db>,
115    impl_type_bounds: BTreeMap<ImplTypeById<'db>, TypeId<'db>>,
116) -> Result<SolutionSet<'db, CanonicalImpl<'db>>, InferenceError<'db>> {
117    let mut concrete_trait_id = canonical_trait.id;
118    let impl_type_bounds = Arc::new(impl_type_bounds);
119    // If the trait is not fully concrete, we might be able to use the trait's items to find a
120    // more concrete trait.
121    if !concrete_trait_id.is_fully_concrete(db) && !canonical_trait.mappings.is_empty() {
122        match solve_canonical_trait(db, canonical_trait, lookup_context, impl_type_bounds.clone()) {
123            SolutionSet::None => {}
124            SolutionSet::Unique(imp) => {
125                concrete_trait_id =
126                    imp.0.concrete_trait(db).expect("A solved impl must have a concrete trait");
127            }
128            SolutionSet::Ambiguous(ambiguity) => {
129                return Ok(SolutionSet::Ambiguous(ambiguity));
130            }
131        }
132    }
133    // Solve the trait without the trait items, so we'd be able to find conflicting impls.
134    Ok(solve_canonical_trait(
135        db,
136        CanonicalTrait { id: concrete_trait_id, mappings: ImplVarTraitItemMappings::default() },
137        lookup_context,
138        impl_type_bounds,
139    ))
140}
141
142/// Query implementation of [SemanticSolver::canonic_trait_solutions].
143/// Assumes the lookup context is already enriched by [enrich_lookup_context].
144#[salsa::tracked(cycle_result=canonic_trait_solutions_cycle)]
145pub fn canonic_trait_solutions_tracked<'db>(
146    db: &'db dyn Database,
147    canonical_trait: CanonicalTrait<'db>,
148    lookup_context: ImplLookupContextId<'db>,
149    impl_type_bounds: BTreeMap<ImplTypeById<'db>, TypeId<'db>>,
150) -> Result<SolutionSet<'db, CanonicalImpl<'db>>, InferenceError<'db>> {
151    canonic_trait_solutions(db, canonical_trait, lookup_context, impl_type_bounds)
152}
153
154/// Cycle handling for [canonic_trait_solutions].
155pub fn canonic_trait_solutions_cycle<'db>(
156    _db: &dyn Database,
157    _id: salsa::Id,
158    _canonical_trait: CanonicalTrait<'db>,
159    _lookup_context: ImplLookupContextId<'db>,
160    _impl_type_bounds: BTreeMap<ImplTypeById<'db>, TypeId<'db>>,
161) -> Result<SolutionSet<'db, CanonicalImpl<'db>>, InferenceError<'db>> {
162    Err(InferenceError::Cycle(InferenceVar::Impl(LocalImplVarId(0))))
163}
164
165/// Adds the defining module of the trait and the generic arguments to the lookup context.
166pub fn enrich_lookup_context<'db>(
167    db: &'db dyn Database,
168    concrete_trait_id: ConcreteTraitId<'db>,
169    lookup_context: &mut ImplLookupContext<'db>,
170) {
171    lookup_context.insert_module(concrete_trait_id.trait_id(db).parent_module(db), db);
172    let generic_args = concrete_trait_id.generic_args(db);
173    // Add the defining module of the generic args to the lookup.
174    for generic_arg in generic_args {
175        if let GenericArgumentId::Type(ty) = generic_arg {
176            enrich_lookup_context_with_ty(db, *ty, lookup_context);
177        }
178    }
179    lookup_context.strip_for_trait_id(db, concrete_trait_id.trait_id(db));
180}
181
182/// Adds the defining module of the type to the lookup context.
183pub fn enrich_lookup_context_with_ty<'db>(
184    db: &'db dyn Database,
185    ty: TypeId<'db>,
186    lookup_context: &mut ImplLookupContext<'db>,
187) {
188    match ty.long(db) {
189        TypeLongId::ImplType(impl_type_id) => {
190            lookup_context.insert_impl(impl_type_id.impl_id(), db);
191        }
192        long_ty => {
193            if let Some(module_id) = long_ty.module_id(db) {
194                lookup_context.insert_module(module_id, db);
195            }
196        }
197    }
198}
199
200/// Attempts to solve a `canonical_trait`. Will try to find candidates in the given
201/// `lookup_context`.
202fn solve_canonical_trait<'db>(
203    db: &'db dyn Database,
204    canonical_trait: CanonicalTrait<'db>,
205    lookup_context: ImplLookupContextId<'db>,
206    impl_type_bounds: Arc<BTreeMap<ImplTypeById<'db>, TypeId<'db>>>,
207) -> SolutionSet<'db, CanonicalImpl<'db>> {
208    let filter = canonical_trait.id.filter(db);
209    // For `NumericLiteral` target types with Drop/Copy/Destruct/PanicDestruct, only the
210    // synthetic candidate should match. Otherwise concrete-headed candidates like `felt252Drop`
211    // would "claim" the integer literal and default its type via trait resolution.
212    let candidates = if let Some(integer_literal_candidate) =
213        find_integer_literal_generated_candidate(db, canonical_trait.id)
214    {
215        OrderedHashSet::from_iter([UninferredImplById(integer_literal_candidate)])
216    } else {
217        let mut set = find_candidates_at_context(db, lookup_context, filter).unwrap_or_default();
218        find_closure_generated_candidate(db, canonical_trait.id)
219            .map(|candidate| set.insert(UninferredImplById(candidate)));
220        set
221    };
222
223    let mut unique_solution: Option<CanonicalImpl<'_>> = None;
224    for candidate in candidates.into_iter() {
225        let Ok(candidate_solution_set) = solve_candidate(
226            db,
227            &canonical_trait,
228            candidate.0,
229            lookup_context,
230            impl_type_bounds.clone(),
231        ) else {
232            continue;
233        };
234
235        let candidate_solution = match candidate_solution_set {
236            SolutionSet::None => continue,
237            SolutionSet::Unique(candidate_solution) => candidate_solution,
238            SolutionSet::Ambiguous(ambiguity) => return SolutionSet::Ambiguous(ambiguity),
239        };
240        if let Some(unique_solution) = unique_solution {
241            // There might be multiple unique solutions from different candidates that are
242            // solved to the same impl id (e.g. finding it near the trait, and
243            // through an impl alias). This is valid.
244            if unique_solution.0 != candidate_solution.0 {
245                return SolutionSet::Ambiguous(Ambiguity::MultipleImplsFound {
246                    concrete_trait_id: canonical_trait.id,
247                    impls: vec![unique_solution.0, candidate_solution.0],
248                });
249            }
250        }
251        unique_solution = Some(candidate_solution);
252    }
253    unique_solution.map(SolutionSet::Unique).unwrap_or(SolutionSet::None)
254}
255
256/// Attempts to solve `candidate` as the requested `canonical_trait`.
257fn solve_candidate<'db>(
258    db: &'db dyn Database,
259    canonical_trait: &CanonicalTrait<'db>,
260    candidate: UninferredImpl<'db>,
261    lookup_context: ImplLookupContextId<'db>,
262    impl_type_bounds: Arc<BTreeMap<ImplTypeById<'db>, TypeId<'db>>>,
263) -> InferenceResult<SolutionSet<'db, CanonicalImpl<'db>>> {
264    let Ok(candidate_concrete_trait) = candidate.concrete_trait(db) else {
265        return Err(super::ErrorSet);
266    };
267    // If the candidate is fully concrete, or it's a generic which is var free, there is nothing
268    // to substitute. A generic param may not be var free, if it contains impl types.
269    let candidate_final = matches!(candidate, UninferredImpl::GenericParam(_))
270        && candidate_concrete_trait.is_var_free(db)
271        || candidate_concrete_trait.is_fully_concrete(db);
272    let target_final = canonical_trait.id.is_var_free(db);
273    let mut lite_inference = LiteInference::new(db);
274    if candidate_final && target_final && candidate_concrete_trait != canonical_trait.id {
275        return Err(super::ErrorSet);
276    }
277
278    let mut res = lite_inference.can_conform_generic_args(
279        (candidate_concrete_trait.generic_args(db), candidate_final),
280        (canonical_trait.id.generic_args(db), target_final),
281    );
282
283    // If the candidate is a generic param, its trait is final and not substituted.
284    if matches!(candidate, UninferredImpl::GenericParam(_))
285        && !lite_inference.substitution.is_empty()
286    {
287        return Err(super::ErrorSet);
288    }
289
290    // If the trait has trait types, we default to using inference.
291    if res == CanConformResult::Accepted {
292        let Ok(trait_types) = db.trait_types(canonical_trait.id.trait_id(db)) else {
293            return Err(super::ErrorSet);
294        };
295        if !trait_types.is_empty() && !canonical_trait.mappings.types.is_empty() {
296            res = CanConformResult::InferenceRequired;
297        }
298    }
299
300    // Add the defining module of the candidate to the lookup.
301    let mut lookup_context = lookup_context.long(db).clone();
302    lookup_context.insert_lookup_scope(db, &candidate);
303    let lookup_context = lookup_context.intern(db);
304    if res == CanConformResult::Rejected {
305        return Err(super::ErrorSet);
306    } else if CanConformResult::Accepted == res {
307        match candidate {
308            UninferredImpl::Def(impl_def_id) => {
309                let imp_generic_params =
310                    db.impl_def_generic_params(impl_def_id).map_err(|_| super::ErrorSet)?;
311
312                match lite_inference.infer_generic_assignment(
313                    imp_generic_params,
314                    lookup_context,
315                    impl_type_bounds.clone(),
316                ) {
317                    Ok(SolutionSet::None) => {
318                        return Ok(SolutionSet::None);
319                    }
320                    Ok(SolutionSet::Ambiguous(ambiguity)) => {
321                        return Ok(SolutionSet::Ambiguous(ambiguity));
322                    }
323                    Ok(SolutionSet::Unique(generic_args)) => {
324                        let concrete_impl =
325                            ConcreteImplLongId { impl_def_id, generic_args }.intern(db);
326                        let impl_id = ImplLongId::Concrete(concrete_impl).intern(db);
327                        return Ok(SolutionSet::Unique(CanonicalImpl(impl_id)));
328                    }
329                    _ => {}
330                }
331            }
332            UninferredImpl::GenericParam(generic_param_id) => {
333                let impl_id = ImplLongId::GenericParameter(generic_param_id).intern(db);
334                return Ok(SolutionSet::Unique(CanonicalImpl(impl_id)));
335            }
336            // TODO(TomerStarkware): Try to solve for impl alias without inference.
337            UninferredImpl::ImplAlias(_) => {}
338            UninferredImpl::ImplImpl(_) | UninferredImpl::GeneratedImpl(_) => {}
339        }
340    }
341
342    let mut inference_data: InferenceData<'_> = InferenceData::new(InferenceId::Canonical);
343    let mut inference = inference_data.inference(db);
344    inference.data.impl_type_bounds = impl_type_bounds;
345    let (canonical_trait, canonical_embedding) = canonical_trait.embed(&mut inference);
346
347    // If the closure params are not var free, we cannot infer the negative impl.
348    // We use the canonical trait to concretize the closure params.
349    if let UninferredImpl::GeneratedImpl(imp) = candidate {
350        inference.conform_traits(imp.long(db).concrete_trait, canonical_trait.id)?;
351    }
352
353    // Instantiate the candidate in the inference table.
354    let candidate_impl =
355        inference.infer_impl(candidate, canonical_trait.id, lookup_context, None)?;
356    for (trait_type, ty) in canonical_trait.mappings.types.iter() {
357        let mapped_ty =
358            inference.reduce_impl_ty(ImplTypeId::new(candidate_impl, *trait_type, db))?;
359
360        // Conform the candidate's type to the trait's type.
361        inference.conform_ty(mapped_ty, *ty)?;
362    }
363    for (trait_const, const_id) in canonical_trait.mappings.constants.iter() {
364        let mapped_const_id = inference.reduce_impl_constant(ImplConstantId::new(
365            candidate_impl,
366            *trait_const,
367            db,
368        ))?;
369        // Conform the candidate's constant to the trait's constant.
370        inference.conform_const(mapped_const_id, *const_id)?;
371    }
372
373    for (trait_impl, impl_id) in canonical_trait.mappings.impls.iter() {
374        let mapped_impl_id =
375            inference.reduce_impl_impl(ImplImplId::new(candidate_impl, *trait_impl, db))?;
376        // Conform the candidate's impl to the trait's impl.
377        inference.conform_impl(mapped_impl_id, *impl_id)?;
378    }
379
380    let mut inference = inference_data.inference(db);
381    let solution_set = inference.solution_set()?;
382    Ok(match solution_set {
383        SolutionSet::None => SolutionSet::None,
384        SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
385        SolutionSet::Unique(_) => {
386            let candidate_impl = inference.rewrite(candidate_impl).no_err();
387            match CanonicalImpl::canonicalize(db, candidate_impl, &canonical_embedding) {
388                Ok(canonical_impl) => SolutionSet::Unique(canonical_impl),
389                Err(MapperError(var)) => {
390                    return Ok(SolutionSet::Ambiguous(Ambiguity::FreeVariable {
391                        impl_id: candidate_impl,
392                        var,
393                    }));
394                }
395            }
396        }
397    })
398}
399
/// Verdict of the quick, inference-free conformance checks (`can_conform_*`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CanConformResult {
    /// The two sides conform without needing inference.
    Accepted,
    /// The check is inconclusive; full inference must decide.
    InferenceRequired,
    /// The two sides can never conform.
    Rejected,
}
407
408impl CanConformResult {
409    fn fold(iter: impl IntoIterator<Item = CanConformResult>) -> CanConformResult {
410        let mut res = CanConformResult::Accepted; // Start with a default value of Accepted
411        for item in iter {
412            match item {
413                CanConformResult::Rejected => return CanConformResult::Rejected,
414                CanConformResult::Accepted => continue,
415                CanConformResult::InferenceRequired => {
416                    res = CanConformResult::InferenceRequired;
417                }
418            }
419        }
420        res // Return the final result
421    }
422}
423/// An inference without 'vars' that can be used to solve canonical traits which do not contain
424/// 'vars' or associated items.
425struct LiteInference<'db> {
426    db: &'db dyn Database,
427    substitution: GenericSubstitution<'db>,
428}
429
430impl<'db> LiteInference<'db> {
431    fn new(db: &'db dyn Database) -> Self {
432        LiteInference { db, substitution: GenericSubstitution::default() }
433    }
434
435    /// Tries to infer the generic arguments of the trait from the given params.
436    /// If the inference fails (i.e., requires full inference), returns an error.
437    fn infer_generic_assignment(
438        &mut self,
439        params: &[GenericParam<'db>],
440        lookup_context: ImplLookupContextId<'db>,
441        impl_type_bounds: Arc<BTreeMap<ImplTypeById<'db>, TypeId<'db>>>,
442    ) -> InferenceResult<SolutionSet<'db, Vec<GenericArgumentId<'db>>>> {
443        let mut generic_args = Vec::with_capacity(params.len());
444        for param in params {
445            match param {
446                GenericParam::Type(generic_param_type) => {
447                    if self.substitution.contains_key(&generic_param_type.id) {
448                        generic_args.push(*self.substitution.get(&generic_param_type.id).unwrap());
449                    } else {
450                        // If the type is not in the substitution, we cannot solve it without
451                        // inference.
452                        return Err(super::ErrorSet);
453                    }
454                }
455                GenericParam::Const(generic_param_const) => {
456                    if self.substitution.contains_key(&generic_param_const.id) {
457                        generic_args.push(*self.substitution.get(&generic_param_const.id).unwrap());
458                    } else {
459                        // If the const is not in the substitution, we cannot solve it without
460                        // inference.
461                        return Err(super::ErrorSet);
462                    }
463                }
464                GenericParam::Impl(generic_param_impl) => {
465                    if !generic_param_impl.type_constraints.is_empty() {
466                        return Err(super::ErrorSet);
467                    }
468                    if self.substitution.contains_key(&generic_param_impl.id) {
469                        generic_args.push(*self.substitution.get(&generic_param_impl.id).unwrap());
470                        continue;
471                    }
472
473                    let Ok(Ok(imp_concrete_trait_id)) =
474                        self.substitution.substitute(self.db, generic_param_impl.concrete_trait)
475                    else {
476                        return Err(super::ErrorSet);
477                    };
478                    let canonical_trait = CanonicalTrait {
479                        id: imp_concrete_trait_id,
480                        mappings: ImplVarTraitItemMappings::default(),
481                    };
482                    let mut inner_context = lookup_context.long(self.db).clone();
483                    enrich_lookup_context(self.db, imp_concrete_trait_id, &mut inner_context);
484                    let Ok(solution) = self.db.canonic_trait_solutions(
485                        canonical_trait,
486                        inner_context.intern(self.db),
487                        (*impl_type_bounds).clone(),
488                    ) else {
489                        return Err(super::ErrorSet);
490                    };
491                    match solution {
492                        SolutionSet::None => return Ok(SolutionSet::None),
493                        SolutionSet::Unique(imp) => {
494                            self.substitution
495                                .insert(generic_param_impl.id, GenericArgumentId::Impl(imp.0));
496                            generic_args.push(GenericArgumentId::Impl(imp.0));
497                        }
498                        SolutionSet::Ambiguous(ambiguity) => {
499                            return Ok(SolutionSet::Ambiguous(ambiguity));
500                        }
501                    }
502                }
503                GenericParam::NegImpl(_) => return Err(super::ErrorSet),
504            }
505        }
506        Ok(SolutionSet::Unique(generic_args))
507    }
508
509    /// Checks if the generic arguments of the candidate could be conformed to the generic args of
510    /// the trait and if the trait or the candidate contain vars (which would require solving
511    /// using inference).
512    fn can_conform_generic_args(
513        &mut self,
514        (candidate_args, candidate_final): (&[GenericArgumentId<'db>], bool),
515        (target_args, target_final): (&[GenericArgumentId<'db>], bool),
516    ) -> CanConformResult {
517        CanConformResult::fold(zip_eq(candidate_args, target_args).map(
518            |(candidate_arg, target_arg)| {
519                self.can_conform_generic_arg(
520                    (*candidate_arg, candidate_final),
521                    (*target_arg, target_final),
522                )
523            },
524        ))
525    }
526
527    /// Checks if a [GenericArgumentId] of the candidate could be conformed to a [GenericArgumentId]
528    /// of the trait.
529    fn can_conform_generic_arg(
530        &mut self,
531        (candidate_arg, mut candidate_final): (GenericArgumentId<'db>, bool),
532        (target_arg, mut target_final): (GenericArgumentId<'db>, bool),
533    ) -> CanConformResult {
534        if candidate_arg == target_arg {
535            return CanConformResult::Accepted;
536        }
537        candidate_final = candidate_final || candidate_arg.is_fully_concrete(self.db);
538        target_final = target_final || target_arg.is_var_free(self.db);
539        if candidate_final && target_final {
540            return CanConformResult::Rejected;
541        }
542        match (candidate_arg, target_arg) {
543            (GenericArgumentId::Type(candidate), GenericArgumentId::Type(target)) => {
544                self.can_conform_ty((candidate, candidate_final), (target, target_final))
545            }
546            (GenericArgumentId::Constant(candidate), GenericArgumentId::Constant(target)) => {
547                self.can_conform_const((candidate, candidate_final), (target, target_final))
548            }
549            (GenericArgumentId::Impl(candidate), GenericArgumentId::Impl(target)) => {
550                self.can_conform_impl((candidate, candidate_final), (target, target_final))
551            }
552            (GenericArgumentId::NegImpl(_), GenericArgumentId::NegImpl(_)) => {
553                CanConformResult::InferenceRequired
554            }
555            _ => CanConformResult::Rejected,
556        }
557    }
558
559    /// Checks if a generic arg [TypeId] of the candidate could be conformed to a generic arg
560    /// [TypeId] of the trait.
561    fn can_conform_ty(
562        &mut self,
563        (candidate_ty, mut candidate_final): (TypeId<'db>, bool),
564        (target_ty, mut target_final): (TypeId<'db>, bool),
565    ) -> CanConformResult {
566        if candidate_ty == target_ty {
567            return CanConformResult::Accepted;
568        }
569        candidate_final = candidate_final || candidate_ty.is_fully_concrete(self.db);
570        target_final = target_final || target_ty.is_var_free(self.db);
571        if candidate_final && target_final {
572            return CanConformResult::Rejected;
573        }
574        let target_long_ty = target_ty.long(self.db);
575
576        if let TypeLongId::Var(_) | TypeLongId::NumericLiteral(_) = target_long_ty {
577            return CanConformResult::InferenceRequired;
578        }
579
580        let long_ty_candidate = candidate_ty.long(self.db);
581
582        match (long_ty_candidate, target_long_ty) {
583            (TypeLongId::Concrete(candidate), TypeLongId::Concrete(target)) => {
584                if candidate.generic_type(self.db) != target.generic_type(self.db) {
585                    return CanConformResult::Rejected;
586                }
587
588                self.can_conform_generic_args(
589                    (&candidate.generic_args(self.db), candidate_final),
590                    (&target.generic_args(self.db), target_final),
591                )
592            }
593            (TypeLongId::Concrete(_), _) => CanConformResult::Rejected,
594            (TypeLongId::Tuple(candidate_tys), TypeLongId::Tuple(target_tys)) => {
595                if candidate_tys.len() != target_tys.len() {
596                    return CanConformResult::Rejected;
597                }
598
599                CanConformResult::fold(zip_eq(candidate_tys, target_tys).map(
600                    |(candidate_subty, target_subty)| {
601                        self.can_conform_ty(
602                            (*candidate_subty, candidate_final),
603                            (*target_subty, target_final),
604                        )
605                    },
606                ))
607            }
608            (TypeLongId::Tuple(_), _) => CanConformResult::Rejected,
609            (TypeLongId::Closure(candidate), TypeLongId::Closure(target)) => {
610                if candidate.params_location != target.params_location {
611                    return CanConformResult::Rejected;
612                }
613
614                let params_check = CanConformResult::fold(
615                    zip_eq(candidate.param_tys.clone(), target.param_tys.clone()).map(
616                        |(candidate_subty, target_subty)| {
617                            self.can_conform_ty(
618                                (candidate_subty, candidate_final),
619                                (target_subty, target_final),
620                            )
621                        },
622                    ),
623                );
624                if params_check == CanConformResult::Rejected {
625                    return CanConformResult::Rejected;
626                }
627                let captured_types_check = CanConformResult::fold(
628                    zip_eq(candidate.captured_types.clone(), target.captured_types.clone()).map(
629                        |(candidate_subty, target_subty)| {
630                            self.can_conform_ty(
631                                (candidate_subty, candidate_final),
632                                (target_subty, target_final),
633                            )
634                        },
635                    ),
636                );
637                if captured_types_check == CanConformResult::Rejected {
638                    return CanConformResult::Rejected;
639                }
640                let return_type_check = self.can_conform_ty(
641                    (candidate.ret_ty, candidate_final),
642                    (target.ret_ty, target_final),
643                );
644                if return_type_check == CanConformResult::Rejected {
645                    return CanConformResult::Rejected;
646                }
647                if params_check == CanConformResult::InferenceRequired
648                    || captured_types_check == CanConformResult::InferenceRequired
649                    || return_type_check == CanConformResult::InferenceRequired
650                {
651                    return CanConformResult::InferenceRequired;
652                }
653                CanConformResult::Accepted
654            }
655            (TypeLongId::Closure(_), _) => CanConformResult::Rejected,
656            (
657                TypeLongId::FixedSizeArray { type_id: candidate_type_id, size: candidate_size },
658                TypeLongId::FixedSizeArray { type_id: target_type_id, size: target_size },
659            ) => CanConformResult::fold([
660                self.can_conform_const(
661                    (*candidate_size, candidate_final),
662                    (*target_size, target_final),
663                ),
664                self.can_conform_ty(
665                    (*candidate_type_id, candidate_final),
666                    (*target_type_id, target_final),
667                ),
668            ]),
669            (TypeLongId::FixedSizeArray { type_id: _, size: _ }, _) => CanConformResult::Rejected,
670            (TypeLongId::Snapshot(candidate_inner_ty), TypeLongId::Snapshot(target_inner_ty)) => {
671                self.can_conform_ty(
672                    (*candidate_inner_ty, candidate_final),
673                    (*target_inner_ty, target_final),
674                )
675            }
676            (TypeLongId::Snapshot(_), _) => CanConformResult::Rejected,
677            (TypeLongId::GenericParameter(param), _) => {
678                let mut res = CanConformResult::Accepted;
679                // if param not in substitution add it otherwise make sure it equal target_ty
680                match self.substitution.entry(*param) {
681                    Entry::Occupied(entry) => {
682                        if let GenericArgumentId::Type(existing_ty) = entry.get() {
683                            if *existing_ty != target_ty {
684                                res = CanConformResult::Rejected;
685                            }
686                            if !existing_ty.is_var_free(self.db) {
687                                return CanConformResult::InferenceRequired;
688                            }
689                        } else {
690                            res = CanConformResult::Rejected;
691                        }
692                    }
693                    Entry::Vacant(e) => {
694                        e.insert(GenericArgumentId::Type(target_ty));
695                    }
696                }
697
698                if target_ty.is_var_free(self.db) {
699                    res
700                } else {
701                    CanConformResult::InferenceRequired
702                }
703            }
704            (
705                TypeLongId::Var(_)
706                | TypeLongId::NumericLiteral(_)
707                | TypeLongId::ImplType(_)
708                | TypeLongId::Missing(_)
709                | TypeLongId::Coupon(_),
710                _,
711            ) => CanConformResult::InferenceRequired,
712        }
713    }
714
715    /// Checks if a generic arg [ImplId] of the candidate could be conformed to a generic arg
716    /// [ImplId] of the trait.
717    fn can_conform_impl(
718        &mut self,
719        (candidate_impl, mut candidate_final): (ImplId<'db>, bool),
720        (target_impl, mut target_final): (ImplId<'db>, bool),
721    ) -> CanConformResult {
722        let long_impl_trait = target_impl.long(self.db);
723        if candidate_impl == target_impl {
724            return CanConformResult::Accepted;
725        }
726        candidate_final = candidate_final || candidate_impl.is_fully_concrete(self.db);
727        target_final = target_final || target_impl.is_var_free(self.db);
728        if candidate_final && target_final {
729            return CanConformResult::Rejected;
730        }
731        if let ImplLongId::ImplVar(_) = long_impl_trait {
732            return CanConformResult::InferenceRequired;
733        }
734        match (candidate_impl.long(self.db), long_impl_trait) {
735            (ImplLongId::Concrete(candidate), ImplLongId::Concrete(target)) => {
736                let candidate = candidate.long(self.db);
737                let target = target.long(self.db);
738                if candidate.impl_def_id != target.impl_def_id {
739                    return CanConformResult::Rejected;
740                }
741                let candidate_args = candidate.generic_args.clone();
742                let target_args = target.generic_args.clone();
743                self.can_conform_generic_args(
744                    (&candidate_args, candidate_final),
745                    (&target_args, target_final),
746                )
747            }
748            (ImplLongId::Concrete(_), _) => CanConformResult::Rejected,
749            (ImplLongId::GenericParameter(param), _) => {
750                let mut res = CanConformResult::Accepted;
751                // if param not in substitution add it otherwise make sure it equal target_ty
752                match self.substitution.entry(*param) {
753                    Entry::Occupied(entry) => {
754                        if let GenericArgumentId::Impl(existing_impl) = entry.get() {
755                            if *existing_impl != target_impl {
756                                res = CanConformResult::Rejected;
757                            }
758                            if !existing_impl.is_var_free(self.db) {
759                                return CanConformResult::InferenceRequired;
760                            }
761                        } else {
762                            res = CanConformResult::Rejected;
763                        }
764                    }
765                    Entry::Vacant(e) => {
766                        e.insert(GenericArgumentId::Impl(target_impl));
767                    }
768                }
769
770                if target_impl.is_var_free(self.db) {
771                    res
772                } else {
773                    CanConformResult::InferenceRequired
774                }
775            }
776            (
777                ImplLongId::ImplVar(_)
778                | ImplLongId::ImplImpl(_)
779                | ImplLongId::SelfImpl(_)
780                | ImplLongId::GeneratedImpl(_),
781                _,
782            ) => CanConformResult::InferenceRequired,
783        }
784    }
785
786    /// Checks if a generic arg [ConstValueId] of the candidate could be conformed to a generic arg
787    /// [ConstValueId] of the trait.
788    fn can_conform_const(
789        &mut self,
790        (candidate_id, mut candidate_final): (ConstValueId<'db>, bool),
791        (target_id, mut target_final): (ConstValueId<'db>, bool),
792    ) -> CanConformResult {
793        if candidate_id == target_id {
794            return CanConformResult::Accepted;
795        }
796        candidate_final = candidate_final || candidate_id.is_fully_concrete(self.db);
797        target_final = target_final || target_id.is_var_free(self.db);
798        if candidate_final && target_final {
799            return CanConformResult::Rejected;
800        }
801        let target_long_const = target_id.long(self.db);
802        if let ConstValue::Var(_, _) = target_long_const {
803            return CanConformResult::InferenceRequired;
804        }
805        match (candidate_id.long(self.db), target_long_const) {
806            (
807                ConstValue::Int(big_int, type_id),
808                ConstValue::Int(target_big_int, target_type_id),
809            ) => {
810                if big_int != target_big_int {
811                    return CanConformResult::Rejected;
812                }
813                self.can_conform_ty((*type_id, candidate_final), (*target_type_id, target_final))
814            }
815            (ConstValue::Int(_, _), _) => CanConformResult::Rejected,
816            (
817                ConstValue::Struct(const_values, type_id),
818                ConstValue::Struct(target_const_values, target_type_id),
819            ) => {
820                if const_values.len() != target_const_values.len() {
821                    return CanConformResult::Rejected;
822                };
823                CanConformResult::fold(chain!(
824                    [self.can_conform_ty(
825                        (*type_id, candidate_final),
826                        (*target_type_id, target_final)
827                    )],
828                    zip_eq(const_values, target_const_values).map(
829                        |(const_value, target_const_value)| {
830                            self.can_conform_const(
831                                (*const_value, candidate_final),
832                                (*target_const_value, target_final),
833                            )
834                        }
835                    )
836                ))
837            }
838            (ConstValue::Struct(_, _), _) => CanConformResult::Rejected,
839
840            (
841                ConstValue::Enum(concrete_variant, const_value),
842                ConstValue::Enum(target_concrete_variant, target_const_value),
843            ) => CanConformResult::fold([
844                self.can_conform_ty(
845                    (concrete_variant.ty, candidate_final),
846                    (target_concrete_variant.ty, target_final),
847                ),
848                self.can_conform_const(
849                    (*const_value, candidate_final),
850                    (*target_const_value, target_final),
851                ),
852            ]),
853            (ConstValue::Enum(_, _), _) => CanConformResult::Rejected,
854            (ConstValue::NonZero(const_value), ConstValue::NonZero(target_const_value)) => self
855                .can_conform_const(
856                    (*const_value, candidate_final),
857                    (*target_const_value, target_final),
858                ),
859            (ConstValue::NonZero(_), _) => CanConformResult::Rejected,
860            (ConstValue::Generic(param), _) => {
861                let mut res = CanConformResult::Accepted;
862                match self.substitution.entry(*param) {
863                    Entry::Occupied(entry) => {
864                        if let GenericArgumentId::Constant(existing_const) = entry.get() {
865                            if *existing_const != target_id {
866                                res = CanConformResult::Rejected;
867                            }
868
869                            if !existing_const.is_var_free(self.db) {
870                                return CanConformResult::InferenceRequired;
871                            }
872                        } else {
873                            res = CanConformResult::Rejected;
874                        }
875                    }
876                    Entry::Vacant(e) => {
877                        e.insert(GenericArgumentId::Constant(target_id));
878                    }
879                }
880                if target_id.is_var_free(self.db) {
881                    res
882                } else {
883                    CanConformResult::InferenceRequired
884                }
885            }
886            (ConstValue::ImplConstant(_) | ConstValue::Var(_, _) | ConstValue::Missing(_), _) => {
887                CanConformResult::InferenceRequired
888            }
889        }
890    }
891}
892
893/// Trait for solver-related semantic queries.
894pub trait SemanticSolver<'db>: Database {
895    /// Returns the solution set for a canonical trait.
896    fn canonic_trait_solutions(
897        &'db self,
898        canonical_trait: CanonicalTrait<'db>,
899        lookup_context: ImplLookupContextId<'db>,
900        impl_type_bounds: BTreeMap<ImplTypeById<'db>, TypeId<'db>>,
901    ) -> Result<SolutionSet<'db, CanonicalImpl<'db>>, InferenceError<'db>> {
902        canonic_trait_solutions_tracked(
903            self.as_dyn_database(),
904            canonical_trait,
905            lookup_context,
906            impl_type_bounds,
907        )
908    }
909}
910impl<'db, T: Database + ?Sized> SemanticSolver<'db> for T {}