// cairo_lang_semantic/expr/inference/solver.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use cairo_lang_debug::DebugWithDb;
5use cairo_lang_defs::ids::LanguageElementId;
6use cairo_lang_proc_macros::SemanticObject;
7use cairo_lang_utils::Intern;
8use cairo_lang_utils::ordered_hash_map::Entry;
9use itertools::{Itertools, chain, zip_eq};
10use salsa::Database;
11
12use super::canonic::{CanonicalImpl, CanonicalTrait, MapperError, ResultNoErrEx};
13use super::conform::InferenceConform;
14use super::infers::InferenceEmbeddings;
15use super::{
16    ImplVarTraitItemMappings, InferenceData, InferenceError, InferenceId, InferenceResult,
17    InferenceVar, LocalImplVarId,
18};
19use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
20use crate::items::imp::{
21    ImplId, ImplImplId, ImplLongId, ImplLookupContext, ImplLookupContextId, ImplSemantic,
22    UninferredImpl, UninferredImplById, find_candidates_at_context,
23    find_closure_generated_candidate,
24};
25use crate::items::trt::TraitSemantic;
26use crate::substitution::{GenericSubstitution, SemanticRewriter};
27use crate::types::{ImplTypeById, ImplTypeId};
28use crate::{
29    ConcreteImplLongId, ConcreteTraitId, GenericArgumentId, GenericParam, TypeId, TypeLongId,
30};
31
32/// A generic solution set for an inference constraint system.
33#[derive(Clone, PartialEq, Eq, Debug)]
34pub enum SolutionSet<'db, T> {
35    None,
36    Unique(T),
37    Ambiguous(Ambiguity<'db>),
38}
39
40// Somewhat taken from the salsa::Update derive macro.
41unsafe impl<'db, T: salsa::Update> salsa::Update for SolutionSet<'db, T> {
42    unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
43        let old_pointer = unsafe { &mut *old_pointer };
44        match (old_pointer, new_value) {
45            (SolutionSet::None, SolutionSet::None) => false,
46            (SolutionSet::Unique(u1), SolutionSet::Unique(u2)) => unsafe {
47                salsa::plumbing::UpdateDispatch::<T>::maybe_update(u1, u2)
48            },
49            (SolutionSet::Ambiguous(ambiguity), SolutionSet::Ambiguous(ambiguity2)) => unsafe {
50                salsa::plumbing::UpdateDispatch::<Ambiguity<'db>>::maybe_update(
51                    ambiguity, ambiguity2,
52                )
53            },
54            (old_pointer, new_value) => {
55                *old_pointer = new_value;
56                true
57            }
58        }
59    }
60}
61
62/// Describes the kinds of inference ambiguities.
63#[derive(Clone, Debug, Eq, Hash, PartialEq, SemanticObject, salsa::Update)]
64pub enum Ambiguity<'db> {
65    MultipleImplsFound {
66        concrete_trait_id: ConcreteTraitId<'db>,
67        impls: Vec<ImplId<'db>>,
68    },
69    FreeVariable {
70        impl_id: ImplId<'db>,
71        #[dont_rewrite]
72        var: InferenceVar,
73    },
74    WillNotInfer(ConcreteTraitId<'db>),
75    NegativeImplWithUnresolvedGenericArgs {
76        concrete_trait_id: ConcreteTraitId<'db>,
77        ty: TypeId<'db>,
78    },
79}
80impl<'db> Ambiguity<'db> {
81    pub fn format(&self, db: &dyn Database) -> String {
82        match self {
83            Ambiguity::MultipleImplsFound { concrete_trait_id, impls } => {
84                let impls_str = impls.iter().map(|imp| format!("`{}`", imp.format(db))).join(", ");
85                format!(
86                    "Trait `{:?}` has multiple implementations, in: {impls_str}",
87                    concrete_trait_id.debug(db)
88                )
89            }
90            Ambiguity::FreeVariable { impl_id, var: _ } => {
91                format!("Candidate impl {:?} has an unused generic parameter.", impl_id.debug(db),)
92            }
93            Ambiguity::WillNotInfer(concrete_trait_id) => {
94                format!(
95                    "Cannot infer trait {:?}. First generic argument must be known.",
96                    concrete_trait_id.debug(db)
97                )
98            }
99            Ambiguity::NegativeImplWithUnresolvedGenericArgs { concrete_trait_id, ty } => {
100                format!(
101                    "Cannot infer negative impl in `{:?}` as it contains the unresolved type `{}`",
102                    concrete_trait_id.debug(db),
103                    ty.format(db)
104                )
105            }
106        }
107    }
108}
109
110/// Implementation of [SemanticSolver::canonic_trait_solutions].
111/// Assumes the lookup context is already enriched by [enrich_lookup_context].
112pub fn canonic_trait_solutions<'db>(
113    db: &'db dyn Database,
114    canonical_trait: CanonicalTrait<'db>,
115    lookup_context: ImplLookupContextId<'db>,
116    impl_type_bounds: BTreeMap<ImplTypeById<'db>, TypeId<'db>>,
117) -> Result<SolutionSet<'db, CanonicalImpl<'db>>, InferenceError<'db>> {
118    let mut concrete_trait_id = canonical_trait.id;
119    let impl_type_bounds = Arc::new(impl_type_bounds);
120    // If the trait is not fully concrete, we might be able to use the trait's items to find a
121    // more concrete trait.
122    if !concrete_trait_id.is_fully_concrete(db) && !canonical_trait.mappings.is_empty() {
123        match solve_canonical_trait(db, canonical_trait, lookup_context, impl_type_bounds.clone()) {
124            SolutionSet::None => {}
125            SolutionSet::Unique(imp) => {
126                concrete_trait_id =
127                    imp.0.concrete_trait(db).expect("A solved impl must have a concrete trait");
128            }
129            SolutionSet::Ambiguous(ambiguity) => {
130                return Ok(SolutionSet::Ambiguous(ambiguity));
131            }
132        }
133    }
134    // Solve the trait without the trait items, so we'd be able to find conflicting impls.
135    Ok(solve_canonical_trait(
136        db,
137        CanonicalTrait { id: concrete_trait_id, mappings: ImplVarTraitItemMappings::default() },
138        lookup_context,
139        impl_type_bounds,
140    ))
141}
142
143/// Query implementation of [SemanticSolver::canonic_trait_solutions].
144/// Assumes the lookup context is already enriched by [enrich_lookup_context].
145#[salsa::tracked(cycle_result=canonic_trait_solutions_cycle)]
146pub fn canonic_trait_solutions_tracked<'db>(
147    db: &'db dyn Database,
148    canonical_trait: CanonicalTrait<'db>,
149    lookup_context: ImplLookupContextId<'db>,
150    impl_type_bounds: BTreeMap<ImplTypeById<'db>, TypeId<'db>>,
151) -> Result<SolutionSet<'db, CanonicalImpl<'db>>, InferenceError<'db>> {
152    canonic_trait_solutions(db, canonical_trait, lookup_context, impl_type_bounds)
153}
154
155/// Cycle handling for [canonic_trait_solutions].
156pub fn canonic_trait_solutions_cycle<'db>(
157    _db: &dyn Database,
158    _canonical_trait: CanonicalTrait<'db>,
159    _lookup_context: ImplLookupContextId<'db>,
160    _impl_type_bounds: BTreeMap<ImplTypeById<'db>, TypeId<'db>>,
161) -> Result<SolutionSet<'db, CanonicalImpl<'db>>, InferenceError<'db>> {
162    Err(InferenceError::Cycle(InferenceVar::Impl(LocalImplVarId(0))))
163}
164
165/// Adds the defining module of the trait and the generic arguments to the lookup context.
166pub fn enrich_lookup_context<'db>(
167    db: &'db dyn Database,
168    concrete_trait_id: ConcreteTraitId<'db>,
169    lookup_context: &mut ImplLookupContext<'db>,
170) {
171    lookup_context.insert_module(concrete_trait_id.trait_id(db).parent_module(db), db);
172    let generic_args = concrete_trait_id.generic_args(db);
173    // Add the defining module of the generic args to the lookup.
174    for generic_arg in generic_args {
175        if let GenericArgumentId::Type(ty) = generic_arg {
176            enrich_lookup_context_with_ty(db, *ty, lookup_context);
177        }
178    }
179    lookup_context.strip_for_trait_id(db, concrete_trait_id.trait_id(db));
180}
181
182/// Adds the defining module of the type to the lookup context.
183pub fn enrich_lookup_context_with_ty<'db>(
184    db: &'db dyn Database,
185    ty: TypeId<'db>,
186    lookup_context: &mut ImplLookupContext<'db>,
187) {
188    match ty.long(db) {
189        TypeLongId::ImplType(impl_type_id) => {
190            lookup_context.insert_impl(impl_type_id.impl_id(), db);
191        }
192        long_ty => {
193            if let Some(module_id) = long_ty.module_id(db) {
194                lookup_context.insert_module(module_id, db);
195            }
196        }
197    }
198}
199
200/// Attempts to solve a `canonical_trait`. Will try to find candidates in the given
201/// `lookup_context`.
202fn solve_canonical_trait<'db>(
203    db: &'db dyn Database,
204    canonical_trait: CanonicalTrait<'db>,
205    lookup_context: ImplLookupContextId<'db>,
206    impl_type_bounds: Arc<BTreeMap<ImplTypeById<'db>, TypeId<'db>>>,
207) -> SolutionSet<'db, CanonicalImpl<'db>> {
208    let filter = canonical_trait.id.filter(db);
209    let mut candidates = find_candidates_at_context(db, lookup_context, filter).unwrap_or_default();
210    find_closure_generated_candidate(db, canonical_trait.id)
211        .map(|candidate| candidates.insert(UninferredImplById(candidate)));
212
213    let mut unique_solution: Option<CanonicalImpl<'_>> = None;
214    for candidate in candidates.into_iter() {
215        let Ok(candidate_solution_set) = solve_candidate(
216            db,
217            &canonical_trait,
218            candidate.0,
219            lookup_context,
220            impl_type_bounds.clone(),
221        ) else {
222            continue;
223        };
224
225        let candidate_solution = match candidate_solution_set {
226            SolutionSet::None => continue,
227            SolutionSet::Unique(candidate_solution) => candidate_solution,
228            SolutionSet::Ambiguous(ambiguity) => return SolutionSet::Ambiguous(ambiguity),
229        };
230        if let Some(unique_solution) = unique_solution {
231            // There might be multiple unique solutions from different candidates that are
232            // solved to the same impl id (e.g. finding it near the trait, and
233            // through an impl alias). This is valid.
234            if unique_solution.0 != candidate_solution.0 {
235                return SolutionSet::Ambiguous(Ambiguity::MultipleImplsFound {
236                    concrete_trait_id: canonical_trait.id,
237                    impls: vec![unique_solution.0, candidate_solution.0],
238                });
239            }
240        }
241        unique_solution = Some(candidate_solution);
242    }
243    unique_solution.map(SolutionSet::Unique).unwrap_or(SolutionSet::None)
244}
245
246/// Attempts to solve `candidate` as the requested `canonical_trait`.
247fn solve_candidate<'db>(
248    db: &'db dyn Database,
249    canonical_trait: &CanonicalTrait<'db>,
250    candidate: UninferredImpl<'db>,
251    lookup_context: ImplLookupContextId<'db>,
252    impl_type_bounds: Arc<BTreeMap<ImplTypeById<'db>, TypeId<'db>>>,
253) -> InferenceResult<SolutionSet<'db, CanonicalImpl<'db>>> {
254    let Ok(candidate_concrete_trait) = candidate.concrete_trait(db) else {
255        return Err(super::ErrorSet);
256    };
257    // If the candidate is fully concrete, or its a generic which is var free, there is nothing
258    // to substitute. A generic param may not be var free, if it contains impl types.
259    let candidate_final = matches!(candidate, UninferredImpl::GenericParam(_))
260        && candidate_concrete_trait.is_var_free(db)
261        || candidate_concrete_trait.is_fully_concrete(db);
262    let target_final = canonical_trait.id.is_var_free(db);
263    let mut lite_inference = LiteInference::new(db);
264    if candidate_final && target_final && candidate_concrete_trait != canonical_trait.id {
265        return Err(super::ErrorSet);
266    }
267
268    let mut res = lite_inference.can_conform_generic_args(
269        (candidate_concrete_trait.generic_args(db), candidate_final),
270        (canonical_trait.id.generic_args(db), target_final),
271    );
272
273    // If the candidate is a generic param, its trait is final and not substituted.
274    if matches!(candidate, UninferredImpl::GenericParam(_))
275        && !lite_inference.substitution.is_empty()
276    {
277        return Err(super::ErrorSet);
278    }
279
280    // If the trait has trait types, we default to using inference.
281    if res == CanConformResult::Accepted {
282        let Ok(trait_types) = db.trait_types(canonical_trait.id.trait_id(db)) else {
283            return Err(super::ErrorSet);
284        };
285        if !trait_types.is_empty() && !canonical_trait.mappings.types.is_empty() {
286            res = CanConformResult::InferenceRequired;
287        }
288    }
289
290    // Add the defining module of the candidate to the lookup.
291    let mut lookup_context = lookup_context.long(db).clone();
292    lookup_context.insert_lookup_scope(db, &candidate);
293    let lookup_context = lookup_context.intern(db);
294    if res == CanConformResult::Rejected {
295        return Err(super::ErrorSet);
296    } else if CanConformResult::Accepted == res {
297        match candidate {
298            UninferredImpl::Def(impl_def_id) => {
299                let imp_generic_params =
300                    db.impl_def_generic_params(impl_def_id).map_err(|_| super::ErrorSet)?;
301
302                match lite_inference.infer_generic_assignment(
303                    imp_generic_params,
304                    lookup_context,
305                    impl_type_bounds.clone(),
306                ) {
307                    Ok(SolutionSet::None) => {
308                        return Ok(SolutionSet::None);
309                    }
310                    Ok(SolutionSet::Ambiguous(ambiguity)) => {
311                        return Ok(SolutionSet::Ambiguous(ambiguity));
312                    }
313                    Ok(SolutionSet::Unique(generic_args)) => {
314                        let concrete_impl =
315                            ConcreteImplLongId { impl_def_id, generic_args }.intern(db);
316                        let impl_id = ImplLongId::Concrete(concrete_impl).intern(db);
317                        return Ok(SolutionSet::Unique(CanonicalImpl(impl_id)));
318                    }
319                    _ => {}
320                }
321            }
322            UninferredImpl::GenericParam(generic_param_id) => {
323                let impl_id = ImplLongId::GenericParameter(generic_param_id).intern(db);
324                return Ok(SolutionSet::Unique(CanonicalImpl(impl_id)));
325            }
326            // TODO(TomerStarkware): Try to solve for impl alias without inference.
327            UninferredImpl::ImplAlias(_) => {}
328            UninferredImpl::ImplImpl(_) | UninferredImpl::GeneratedImpl(_) => {}
329        }
330    }
331
332    let mut inference_data: InferenceData<'_> = InferenceData::new(InferenceId::Canonical);
333    let mut inference = inference_data.inference(db);
334    inference.data.impl_type_bounds = impl_type_bounds;
335    let (canonical_trait, canonical_embedding) = canonical_trait.embed(&mut inference);
336
337    // If the closure params are not var free, we cannot infer the negative impl.
338    // We use the canonical trait to concretize the closure params.
339    if let UninferredImpl::GeneratedImpl(imp) = candidate {
340        inference.conform_traits(imp.long(db).concrete_trait, canonical_trait.id)?;
341    }
342
343    // Instantiate the candidate in the inference table.
344    let candidate_impl =
345        inference.infer_impl(candidate, canonical_trait.id, lookup_context, None)?;
346    for (trait_type, ty) in canonical_trait.mappings.types.iter() {
347        let mapped_ty =
348            inference.reduce_impl_ty(ImplTypeId::new(candidate_impl, *trait_type, db))?;
349
350        // Conform the candidate's type to the trait's type.
351        inference.conform_ty(mapped_ty, *ty)?;
352    }
353    for (trait_const, const_id) in canonical_trait.mappings.constants.iter() {
354        let mapped_const_id = inference.reduce_impl_constant(ImplConstantId::new(
355            candidate_impl,
356            *trait_const,
357            db,
358        ))?;
359        // Conform the candidate's constant to the trait's constant.
360        inference.conform_const(mapped_const_id, *const_id)?;
361    }
362
363    for (trait_impl, impl_id) in canonical_trait.mappings.impls.iter() {
364        let mapped_impl_id =
365            inference.reduce_impl_impl(ImplImplId::new(candidate_impl, *trait_impl, db))?;
366        // Conform the candidate's impl to the trait's impl.
367        inference.conform_impl(mapped_impl_id, *impl_id)?;
368    }
369
370    let mut inference = inference_data.inference(db);
371    let solution_set = inference.solution_set()?;
372    Ok(match solution_set {
373        SolutionSet::None => SolutionSet::None,
374        SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
375        SolutionSet::Unique(_) => {
376            let candidate_impl = inference.rewrite(candidate_impl).no_err();
377            match CanonicalImpl::canonicalize(db, candidate_impl, &canonical_embedding) {
378                Ok(canonical_impl) => SolutionSet::Unique(canonical_impl),
379                Err(MapperError(var)) => {
380                    return Ok(SolutionSet::Ambiguous(Ambiguity::FreeVariable {
381                        impl_id: candidate_impl,
382                        var,
383                    }));
384                }
385            }
386        }
387    })
388}
389
/// Result of the lightweight `can_conform` checks performed by [LiteInference].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CanConformResult {
    /// The candidate conforms without needing inference.
    Accepted,
    /// Conformance cannot be decided without running full inference.
    InferenceRequired,
    /// The candidate definitely does not conform.
    Rejected,
}

impl CanConformResult {
    /// Combines a sequence of results into one.
    ///
    /// Consumption of `iter` stops at the first `Rejected`, which is returned. Otherwise the
    /// result is `InferenceRequired` if any item required inference, and `Accepted` if every
    /// item (or no item at all) was accepted.
    fn fold(iter: impl IntoIterator<Item = CanConformResult>) -> CanConformResult {
        iter.into_iter()
            .try_fold(CanConformResult::Accepted, |acc, item| match item {
                // `None` short-circuits `try_fold`, mapped to `Rejected` below.
                CanConformResult::Rejected => None,
                CanConformResult::InferenceRequired => Some(CanConformResult::InferenceRequired),
                CanConformResult::Accepted => Some(acc),
            })
            .unwrap_or(CanConformResult::Rejected)
    }
}
413/// An inference without 'vars' that can be used to solve canonical traits which do not contain
414/// 'vars' or associated items.
415struct LiteInference<'db> {
416    db: &'db dyn Database,
417    substitution: GenericSubstitution<'db>,
418}
419
420impl<'db> LiteInference<'db> {
421    fn new(db: &'db dyn Database) -> Self {
422        LiteInference { db, substitution: GenericSubstitution::default() }
423    }
424
425    /// Tries to infer the generic arguments of the trait from the given params.
426    /// If the inference fails (i.e., requires full inference), returns an error.
427    fn infer_generic_assignment(
428        &mut self,
429        params: &[GenericParam<'db>],
430        lookup_context: ImplLookupContextId<'db>,
431        impl_type_bounds: Arc<BTreeMap<ImplTypeById<'db>, TypeId<'db>>>,
432    ) -> InferenceResult<SolutionSet<'db, Vec<GenericArgumentId<'db>>>> {
433        let mut generic_args = Vec::with_capacity(params.len());
434        for param in params {
435            match param {
436                GenericParam::Type(generic_param_type) => {
437                    if self.substitution.contains_key(&generic_param_type.id) {
438                        generic_args.push(*self.substitution.get(&generic_param_type.id).unwrap());
439                    } else {
440                        // If the type is not in the substitution, we cannot solve it without
441                        // inference.
442                        return Err(super::ErrorSet);
443                    }
444                }
445                GenericParam::Const(generic_param_const) => {
446                    if self.substitution.contains_key(&generic_param_const.id) {
447                        generic_args.push(*self.substitution.get(&generic_param_const.id).unwrap());
448                    } else {
449                        // If the const is not in the substitution, we cannot solve it without
450                        // inference.
451                        return Err(super::ErrorSet);
452                    }
453                }
454                GenericParam::Impl(generic_param_impl) => {
455                    if !generic_param_impl.type_constraints.is_empty() {
456                        return Err(super::ErrorSet);
457                    }
458                    if self.substitution.contains_key(&generic_param_impl.id) {
459                        generic_args.push(*self.substitution.get(&generic_param_impl.id).unwrap());
460                        continue;
461                    }
462
463                    let Ok(Ok(imp_concrete_trait_id)) =
464                        self.substitution.substitute(self.db, generic_param_impl.concrete_trait)
465                    else {
466                        return Err(super::ErrorSet);
467                    };
468                    let canonical_trait = CanonicalTrait {
469                        id: imp_concrete_trait_id,
470                        mappings: ImplVarTraitItemMappings::default(),
471                    };
472                    let mut inner_context = lookup_context.long(self.db).clone();
473                    enrich_lookup_context(self.db, imp_concrete_trait_id, &mut inner_context);
474                    let Ok(solution) = self.db.canonic_trait_solutions(
475                        canonical_trait,
476                        inner_context.intern(self.db),
477                        (*impl_type_bounds).clone(),
478                    ) else {
479                        return Err(super::ErrorSet);
480                    };
481                    match solution {
482                        SolutionSet::None => return Ok(SolutionSet::None),
483                        SolutionSet::Unique(imp) => {
484                            self.substitution
485                                .insert(generic_param_impl.id, GenericArgumentId::Impl(imp.0));
486                            generic_args.push(GenericArgumentId::Impl(imp.0));
487                        }
488                        SolutionSet::Ambiguous(ambiguity) => {
489                            return Ok(SolutionSet::Ambiguous(ambiguity));
490                        }
491                    }
492                }
493                GenericParam::NegImpl(_) => return Err(super::ErrorSet),
494            }
495        }
496        Ok(SolutionSet::Unique(generic_args))
497    }
498
499    /// Checks if the generic arguments of the candidate could be conformed to the generic args of
500    /// the trait and if the trait or the candidate contain vars (which would require solving
501    /// using inference).
502    fn can_conform_generic_args(
503        &mut self,
504        (candidate_args, candidate_final): (&[GenericArgumentId<'db>], bool),
505        (target_args, target_final): (&[GenericArgumentId<'db>], bool),
506    ) -> CanConformResult {
507        CanConformResult::fold(zip_eq(candidate_args, target_args).map(
508            |(candidate_arg, target_arg)| {
509                self.can_conform_generic_arg(
510                    (*candidate_arg, candidate_final),
511                    (*target_arg, target_final),
512                )
513            },
514        ))
515    }
516
517    /// Checks if a [GenericArgumentId] of the candidate could be conformed to a [GenericArgumentId]
518    /// of the trait.
519    fn can_conform_generic_arg(
520        &mut self,
521        (candidate_arg, mut candidate_final): (GenericArgumentId<'db>, bool),
522        (target_arg, mut target_final): (GenericArgumentId<'db>, bool),
523    ) -> CanConformResult {
524        if candidate_arg == target_arg {
525            return CanConformResult::Accepted;
526        }
527        candidate_final = candidate_final || candidate_arg.is_fully_concrete(self.db);
528        target_final = target_final || target_arg.is_var_free(self.db);
529        if candidate_final && target_final {
530            return CanConformResult::Rejected;
531        }
532        match (candidate_arg, target_arg) {
533            (GenericArgumentId::Type(candidate), GenericArgumentId::Type(target)) => {
534                self.can_conform_ty((candidate, candidate_final), (target, target_final))
535            }
536            (GenericArgumentId::Constant(candidate), GenericArgumentId::Constant(target)) => {
537                self.can_conform_const((candidate, candidate_final), (target, target_final))
538            }
539            (GenericArgumentId::Impl(candidate), GenericArgumentId::Impl(target)) => {
540                self.can_conform_impl((candidate, candidate_final), (target, target_final))
541            }
542            (GenericArgumentId::NegImpl(_), GenericArgumentId::NegImpl(_)) => {
543                CanConformResult::InferenceRequired
544            }
545            _ => CanConformResult::Rejected,
546        }
547    }
548
549    /// Checks if a generic arg [TypeId] of the candidate could be conformed to a generic arg
550    /// [TypeId] of the trait.
551    fn can_conform_ty(
552        &mut self,
553        (candidate_ty, mut candidate_final): (TypeId<'db>, bool),
554        (target_ty, mut target_final): (TypeId<'db>, bool),
555    ) -> CanConformResult {
556        if candidate_ty == target_ty {
557            return CanConformResult::Accepted;
558        }
559        candidate_final = candidate_final || candidate_ty.is_fully_concrete(self.db);
560        target_final = target_final || target_ty.is_var_free(self.db);
561        if candidate_final && target_final {
562            return CanConformResult::Rejected;
563        }
564        let target_long_ty = target_ty.long(self.db);
565
566        if let TypeLongId::Var(_) = target_long_ty {
567            return CanConformResult::InferenceRequired;
568        }
569
570        let long_ty_candidate = candidate_ty.long(self.db);
571
572        match (long_ty_candidate, target_long_ty) {
573            (TypeLongId::Concrete(candidate), TypeLongId::Concrete(target)) => {
574                if candidate.generic_type(self.db) != target.generic_type(self.db) {
575                    return CanConformResult::Rejected;
576                }
577
578                self.can_conform_generic_args(
579                    (&candidate.generic_args(self.db), candidate_final),
580                    (&target.generic_args(self.db), target_final),
581                )
582            }
583            (TypeLongId::Concrete(_), _) => CanConformResult::Rejected,
584            (TypeLongId::Tuple(candidate_tys), TypeLongId::Tuple(target_tys)) => {
585                if candidate_tys.len() != target_tys.len() {
586                    return CanConformResult::Rejected;
587                }
588
589                CanConformResult::fold(zip_eq(candidate_tys, target_tys).map(
590                    |(candidate_subty, target_subty)| {
591                        self.can_conform_ty(
592                            (*candidate_subty, candidate_final),
593                            (*target_subty, target_final),
594                        )
595                    },
596                ))
597            }
598            (TypeLongId::Tuple(_), _) => CanConformResult::Rejected,
599            (TypeLongId::Closure(candidate), TypeLongId::Closure(target)) => {
600                if candidate.wrapper_location != target.wrapper_location {
601                    return CanConformResult::Rejected;
602                }
603
604                let params_check = CanConformResult::fold(
605                    zip_eq(candidate.param_tys.clone(), target.param_tys.clone()).map(
606                        |(candidate_subty, target_subty)| {
607                            self.can_conform_ty(
608                                (candidate_subty, candidate_final),
609                                (target_subty, target_final),
610                            )
611                        },
612                    ),
613                );
614                if params_check == CanConformResult::Rejected {
615                    return CanConformResult::Rejected;
616                }
617                let captured_types_check = CanConformResult::fold(
618                    zip_eq(candidate.captured_types.clone(), target.captured_types.clone()).map(
619                        |(candidate_subty, target_subty)| {
620                            self.can_conform_ty(
621                                (candidate_subty, candidate_final),
622                                (target_subty, target_final),
623                            )
624                        },
625                    ),
626                );
627                if captured_types_check == CanConformResult::Rejected {
628                    return CanConformResult::Rejected;
629                }
630                let return_type_check = self.can_conform_ty(
631                    (candidate.ret_ty, candidate_final),
632                    (target.ret_ty, target_final),
633                );
634                if return_type_check == CanConformResult::Rejected {
635                    return CanConformResult::Rejected;
636                }
637                if params_check == CanConformResult::InferenceRequired
638                    || captured_types_check == CanConformResult::InferenceRequired
639                    || return_type_check == CanConformResult::InferenceRequired
640                {
641                    return CanConformResult::InferenceRequired;
642                }
643                CanConformResult::Accepted
644            }
645            (TypeLongId::Closure(_), _) => CanConformResult::Rejected,
646            (
647                TypeLongId::FixedSizeArray { type_id: candidate_type_id, size: candidate_size },
648                TypeLongId::FixedSizeArray { type_id: target_type_id, size: target_size },
649            ) => CanConformResult::fold([
650                self.can_conform_const(
651                    (*candidate_size, candidate_final),
652                    (*target_size, target_final),
653                ),
654                self.can_conform_ty(
655                    (*candidate_type_id, candidate_final),
656                    (*target_type_id, target_final),
657                ),
658            ]),
659            (TypeLongId::FixedSizeArray { type_id: _, size: _ }, _) => CanConformResult::Rejected,
660            (TypeLongId::Snapshot(candidate_inner_ty), TypeLongId::Snapshot(target_inner_ty)) => {
661                self.can_conform_ty(
662                    (*candidate_inner_ty, candidate_final),
663                    (*target_inner_ty, target_final),
664                )
665            }
666            (TypeLongId::Snapshot(_), _) => CanConformResult::Rejected,
667            (TypeLongId::GenericParameter(param), _) => {
668                let mut res = CanConformResult::Accepted;
669                // if param not in substitution add it otherwise make sure it equal target_ty
670                match self.substitution.entry(*param) {
671                    Entry::Occupied(entry) => {
672                        if let GenericArgumentId::Type(existing_ty) = entry.get() {
673                            if *existing_ty != target_ty {
674                                res = CanConformResult::Rejected;
675                            }
676                            if !existing_ty.is_var_free(self.db) {
677                                return CanConformResult::InferenceRequired;
678                            }
679                        } else {
680                            res = CanConformResult::Rejected;
681                        }
682                    }
683                    Entry::Vacant(e) => {
684                        e.insert(GenericArgumentId::Type(target_ty));
685                    }
686                }
687
688                if target_ty.is_var_free(self.db) {
689                    res
690                } else {
691                    CanConformResult::InferenceRequired
692                }
693            }
694            (
695                TypeLongId::Var(_)
696                | TypeLongId::ImplType(_)
697                | TypeLongId::Missing(_)
698                | TypeLongId::Coupon(_),
699                _,
700            ) => CanConformResult::InferenceRequired,
701        }
702    }
703
704    /// Checks if a generic arg [ImplId] of the candidate could be conformed to a generic arg
705    /// [ImplId] of the trait.
706    fn can_conform_impl(
707        &mut self,
708        (candidate_impl, mut candidate_final): (ImplId<'db>, bool),
709        (target_impl, mut target_final): (ImplId<'db>, bool),
710    ) -> CanConformResult {
711        let long_impl_trait = target_impl.long(self.db);
712        if candidate_impl == target_impl {
713            return CanConformResult::Accepted;
714        }
715        candidate_final = candidate_final || candidate_impl.is_fully_concrete(self.db);
716        target_final = target_final || target_impl.is_var_free(self.db);
717        if candidate_final && target_final {
718            return CanConformResult::Rejected;
719        }
720        if let ImplLongId::ImplVar(_) = long_impl_trait {
721            return CanConformResult::InferenceRequired;
722        }
723        match (candidate_impl.long(self.db), long_impl_trait) {
724            (ImplLongId::Concrete(candidate), ImplLongId::Concrete(target)) => {
725                let candidate = candidate.long(self.db);
726                let target = target.long(self.db);
727                if candidate.impl_def_id != target.impl_def_id {
728                    return CanConformResult::Rejected;
729                }
730                let candidate_args = candidate.generic_args.clone();
731                let target_args = target.generic_args.clone();
732                self.can_conform_generic_args(
733                    (&candidate_args, candidate_final),
734                    (&target_args, target_final),
735                )
736            }
737            (ImplLongId::Concrete(_), _) => CanConformResult::Rejected,
738            (ImplLongId::GenericParameter(param), _) => {
739                let mut res = CanConformResult::Accepted;
740                // if param not in substitution add it otherwise make sure it equal target_ty
741                match self.substitution.entry(*param) {
742                    Entry::Occupied(entry) => {
743                        if let GenericArgumentId::Impl(existing_impl) = entry.get() {
744                            if *existing_impl != target_impl {
745                                res = CanConformResult::Rejected;
746                            }
747                            if !existing_impl.is_var_free(self.db) {
748                                return CanConformResult::InferenceRequired;
749                            }
750                        } else {
751                            res = CanConformResult::Rejected;
752                        }
753                    }
754                    Entry::Vacant(e) => {
755                        e.insert(GenericArgumentId::Impl(target_impl));
756                    }
757                }
758
759                if target_impl.is_var_free(self.db) {
760                    res
761                } else {
762                    CanConformResult::InferenceRequired
763                }
764            }
765            (
766                ImplLongId::ImplVar(_)
767                | ImplLongId::ImplImpl(_)
768                | ImplLongId::SelfImpl(_)
769                | ImplLongId::GeneratedImpl(_),
770                _,
771            ) => CanConformResult::InferenceRequired,
772        }
773    }
774
775    /// Checks if a generic arg [ConstValueId] of the candidate could be conformed to a generic arg
776    /// [ConstValueId] of the trait.
777    fn can_conform_const(
778        &mut self,
779        (candidate_id, mut candidate_final): (ConstValueId<'db>, bool),
780        (target_id, mut target_final): (ConstValueId<'db>, bool),
781    ) -> CanConformResult {
782        if candidate_id == target_id {
783            return CanConformResult::Accepted;
784        }
785        candidate_final = candidate_final || candidate_id.is_fully_concrete(self.db);
786        target_final = target_final || target_id.is_var_free(self.db);
787        if candidate_final && target_final {
788            return CanConformResult::Rejected;
789        }
790        let target_long_const = target_id.long(self.db);
791        if let ConstValue::Var(_, _) = target_long_const {
792            return CanConformResult::InferenceRequired;
793        }
794        match (candidate_id.long(self.db), target_long_const) {
795            (
796                ConstValue::Int(big_int, type_id),
797                ConstValue::Int(target_big_int, target_type_id),
798            ) => {
799                if big_int != target_big_int {
800                    return CanConformResult::Rejected;
801                }
802                self.can_conform_ty((*type_id, candidate_final), (*target_type_id, target_final))
803            }
804            (ConstValue::Int(_, _), _) => CanConformResult::Rejected,
805            (
806                ConstValue::Struct(const_values, type_id),
807                ConstValue::Struct(target_const_values, target_type_id),
808            ) => {
809                if const_values.len() != target_const_values.len() {
810                    return CanConformResult::Rejected;
811                };
812                CanConformResult::fold(chain!(
813                    [self.can_conform_ty(
814                        (*type_id, candidate_final),
815                        (*target_type_id, target_final)
816                    )],
817                    zip_eq(const_values, target_const_values).map(
818                        |(const_value, target_const_value)| {
819                            self.can_conform_const(
820                                (*const_value, candidate_final),
821                                (*target_const_value, target_final),
822                            )
823                        }
824                    )
825                ))
826            }
827            (ConstValue::Struct(_, _), _) => CanConformResult::Rejected,
828
829            (
830                ConstValue::Enum(concrete_variant, const_value),
831                ConstValue::Enum(target_concrete_variant, target_const_value),
832            ) => CanConformResult::fold([
833                self.can_conform_ty(
834                    (concrete_variant.ty, candidate_final),
835                    (target_concrete_variant.ty, target_final),
836                ),
837                self.can_conform_const(
838                    (*const_value, candidate_final),
839                    (*target_const_value, target_final),
840                ),
841            ]),
842            (ConstValue::Enum(_, _), _) => CanConformResult::Rejected,
843            (ConstValue::NonZero(const_value), ConstValue::NonZero(target_const_value)) => self
844                .can_conform_const(
845                    (*const_value, candidate_final),
846                    (*target_const_value, target_final),
847                ),
848            (ConstValue::NonZero(_), _) => CanConformResult::Rejected,
849            (ConstValue::Generic(param), _) => {
850                let mut res = CanConformResult::Accepted;
851                match self.substitution.entry(*param) {
852                    Entry::Occupied(entry) => {
853                        if let GenericArgumentId::Constant(existing_const) = entry.get() {
854                            if *existing_const != target_id {
855                                res = CanConformResult::Rejected;
856                            }
857
858                            if !existing_const.is_var_free(self.db) {
859                                return CanConformResult::InferenceRequired;
860                            }
861                        } else {
862                            res = CanConformResult::Rejected;
863                        }
864                    }
865                    Entry::Vacant(e) => {
866                        e.insert(GenericArgumentId::Constant(target_id));
867                    }
868                }
869                if target_id.is_var_free(self.db) {
870                    res
871                } else {
872                    CanConformResult::InferenceRequired
873                }
874            }
875            (ConstValue::ImplConstant(_) | ConstValue::Var(_, _) | ConstValue::Missing(_), _) => {
876                CanConformResult::InferenceRequired
877            }
878        }
879    }
880}
881
882/// Trait for solver-related semantic queries.
883pub trait SemanticSolver<'db>: Database {
884    /// Returns the solution set for a canonical trait.
885    fn canonic_trait_solutions(
886        &'db self,
887        canonical_trait: CanonicalTrait<'db>,
888        lookup_context: ImplLookupContextId<'db>,
889        impl_type_bounds: BTreeMap<ImplTypeById<'db>, TypeId<'db>>,
890    ) -> Result<SolutionSet<'db, CanonicalImpl<'db>>, InferenceError<'db>> {
891        canonic_trait_solutions_tracked(
892            self.as_dyn_database(),
893            canonical_trait,
894            lookup_context,
895            impl_type_bounds,
896        )
897    }
898}
899impl<'db, T: Database + ?Sized> SemanticSolver<'db> for T {}