// cairo_lang_semantic/expr/inference/solver.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use cairo_lang_debug::DebugWithDb;
5use cairo_lang_defs::ids::LanguageElementId;
6use cairo_lang_proc_macros::SemanticObject;
7use cairo_lang_utils::Intern;
8use cairo_lang_utils::ordered_hash_map::Entry;
9use itertools::{Itertools, chain, zip_eq};
10use salsa::Database;
11
12use super::canonic::{CanonicalImpl, CanonicalTrait, MapperError, ResultNoErrEx};
13use super::conform::InferenceConform;
14use super::infers::InferenceEmbeddings;
15use super::{
16    ImplVarTraitItemMappings, InferenceData, InferenceError, InferenceId, InferenceResult,
17    InferenceVar, LocalImplVarId,
18};
19use crate::items::constant::{ConstValue, ConstValueId, ImplConstantId};
20use crate::items::imp::{
21    ImplId, ImplImplId, ImplLongId, ImplLookupContext, ImplLookupContextId, ImplSemantic,
22    UninferredImpl, UninferredImplById, find_candidates_at_context,
23    find_closure_generated_candidate,
24};
25use crate::items::trt::TraitSemantic;
26use crate::substitution::{GenericSubstitution, SemanticRewriter};
27use crate::types::{ImplTypeById, ImplTypeId};
28use crate::{
29    ConcreteImplLongId, ConcreteTraitId, GenericArgumentId, GenericParam, TypeId, TypeLongId,
30};
31
/// A generic solution set for an inference constraint system.
///
/// Returned by the solver functions in this module (e.g. [canonic_trait_solutions]) to report
/// whether a constraint has no solution, exactly one solution, or no unique answer.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum SolutionSet<'db, T> {
    /// No solution was found.
    None,
    /// Exactly one solution was found.
    Unique(T),
    /// No unique solution could be determined; the payload describes why.
    Ambiguous(Ambiguity<'db>),
}

// Somewhat taken from the salsa::Update derive macro.
// SAFETY: `maybe_update` only writes through `old_pointer` — either in place via the payload's
// own `UpdateDispatch::maybe_update` (when the variant is unchanged) or by overwriting the whole
// value (when the variant changed, in which case it reports `true`).
// NOTE(review): this hand-written impl mirrors the derive; re-check it if variants are added.
unsafe impl<'db, T: salsa::Update> salsa::Update for SolutionSet<'db, T> {
    unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
        // SAFETY: relies on the `salsa::Update::maybe_update` contract that `old_pointer` is a
        // valid, exclusively accessible pointer to a `Self`.
        let old_pointer = unsafe { &mut *old_pointer };
        match (old_pointer, new_value) {
            // Same unit variant: nothing can have changed.
            (SolutionSet::None, SolutionSet::None) => false,
            // Same variant carrying a payload: let the payload type decide whether it changed.
            (SolutionSet::Unique(u1), SolutionSet::Unique(u2)) => unsafe {
                salsa::plumbing::UpdateDispatch::<T>::maybe_update(u1, u2)
            },
            (SolutionSet::Ambiguous(ambiguity), SolutionSet::Ambiguous(ambiguity2)) => unsafe {
                salsa::plumbing::UpdateDispatch::<Ambiguity<'db>>::maybe_update(
                    ambiguity, ambiguity2,
                )
            },
            // The variant changed: replace the old value wholesale and report a change.
            (old_pointer, new_value) => {
                *old_pointer = new_value;
                true
            }
        }
    }
}
61
/// Describes the kinds of inference ambiguities.
#[derive(Clone, Debug, Eq, Hash, PartialEq, SemanticObject, salsa::Update)]
pub enum Ambiguity<'db> {
    /// More than one distinct impl satisfies `concrete_trait_id`; `impls` lists the conflicting
    /// impls that were found.
    MultipleImplsFound {
        concrete_trait_id: ConcreteTraitId<'db>,
        impls: Vec<ImplId<'db>>,
    },
    /// The solved candidate impl `impl_id` still refers to the inference variable `var`, so it
    /// could not be canonicalized (reported to users as an unused generic parameter).
    FreeVariable {
        impl_id: ImplId<'db>,
        // Excluded from the `SemanticObject` rewrite (presumably because `InferenceVar` is not a
        // rewritable semantic object — confirm).
        #[dont_rewrite]
        var: InferenceVar,
    },
    /// The trait cannot be inferred because its first generic argument is not yet known.
    WillNotInfer(ConcreteTraitId<'db>),
    /// A negative impl has a generic argument that is not supported.
    NegativeImplWithUnsupportedExtractedArgs(GenericArgumentId<'db>),
}
77impl<'db> Ambiguity<'db> {
78    pub fn format(&self, db: &dyn Database) -> String {
79        match self {
80            Ambiguity::MultipleImplsFound { concrete_trait_id, impls } => {
81                let impls_str = impls.iter().map(|imp| format!("`{}`", imp.format(db))).join(", ");
82                format!(
83                    "Trait `{:?}` has multiple implementations, in: {impls_str}",
84                    concrete_trait_id.debug(db)
85                )
86            }
87            Ambiguity::FreeVariable { impl_id, var: _ } => {
88                format!("Candidate impl {:?} has an unused generic parameter.", impl_id.debug(db),)
89            }
90            Ambiguity::WillNotInfer(concrete_trait_id) => {
91                format!(
92                    "Cannot infer trait {:?}. First generic argument must be known.",
93                    concrete_trait_id.debug(db)
94                )
95            }
96            Ambiguity::NegativeImplWithUnsupportedExtractedArgs(garg) => {
97                format!("Negative impl has an unsupported generic argument {:?}.", garg.debug(db),)
98            }
99        }
100    }
101}
102
103/// Implementation of [SemanticSolver::canonic_trait_solutions].
104/// Assumes the lookup context is already enriched by [enrich_lookup_context].
105pub fn canonic_trait_solutions<'db>(
106    db: &'db dyn Database,
107    canonical_trait: CanonicalTrait<'db>,
108    lookup_context: ImplLookupContextId<'db>,
109    impl_type_bounds: BTreeMap<ImplTypeById<'db>, TypeId<'db>>,
110) -> Result<SolutionSet<'db, CanonicalImpl<'db>>, InferenceError<'db>> {
111    let mut concrete_trait_id = canonical_trait.id;
112    let impl_type_bounds = Arc::new(impl_type_bounds);
113    // If the trait is not fully concrete, we might be able to use the trait's items to find a
114    // more concrete trait.
115    if !concrete_trait_id.is_fully_concrete(db) && !canonical_trait.mappings.is_empty() {
116        match solve_canonical_trait(db, canonical_trait, lookup_context, impl_type_bounds.clone()) {
117            SolutionSet::None => {}
118            SolutionSet::Unique(imp) => {
119                concrete_trait_id =
120                    imp.0.concrete_trait(db).expect("A solved impl must have a concrete trait");
121            }
122            SolutionSet::Ambiguous(ambiguity) => {
123                return Ok(SolutionSet::Ambiguous(ambiguity));
124            }
125        }
126    }
127    // Solve the trait without the trait items, so we'd be able to find conflicting impls.
128    Ok(solve_canonical_trait(
129        db,
130        CanonicalTrait { id: concrete_trait_id, mappings: ImplVarTraitItemMappings::default() },
131        lookup_context,
132        impl_type_bounds,
133    ))
134}
135
/// Query implementation of [SemanticSolver::canonic_trait_solutions].
/// Assumes the lookup context is already enriched by [enrich_lookup_context].
///
/// This salsa-tracked entry point only delegates to [canonic_trait_solutions]. If salsa detects a
/// cycle through this query, the value produced by [canonic_trait_solutions_cycle] is used.
#[salsa::tracked(cycle_result=canonic_trait_solutions_cycle)]
pub fn canonic_trait_solutions_tracked<'db>(
    db: &'db dyn Database,
    canonical_trait: CanonicalTrait<'db>,
    lookup_context: ImplLookupContextId<'db>,
    impl_type_bounds: BTreeMap<ImplTypeById<'db>, TypeId<'db>>,
) -> Result<SolutionSet<'db, CanonicalImpl<'db>>, InferenceError<'db>> {
    canonic_trait_solutions(db, canonical_trait, lookup_context, impl_type_bounds)
}
147
/// Cycle handling for [canonic_trait_solutions].
///
/// Registered as the `cycle_result` of [canonic_trait_solutions_tracked]. It ignores all of its
/// inputs and always reports [InferenceError::Cycle]. The carried `LocalImplVarId(0)` is only a
/// placeholder: no specific inference variable is known here.
pub fn canonic_trait_solutions_cycle<'db>(
    _db: &dyn Database,
    _canonical_trait: CanonicalTrait<'db>,
    _lookup_context: ImplLookupContextId<'db>,
    _impl_type_bounds: BTreeMap<ImplTypeById<'db>, TypeId<'db>>,
) -> Result<SolutionSet<'db, CanonicalImpl<'db>>, InferenceError<'db>> {
    Err(InferenceError::Cycle(InferenceVar::Impl(LocalImplVarId(0))))
}
157
158/// Adds the defining module of the trait and the generic arguments to the lookup context.
159pub fn enrich_lookup_context<'db>(
160    db: &'db dyn Database,
161    concrete_trait_id: ConcreteTraitId<'db>,
162    lookup_context: &mut ImplLookupContext<'db>,
163) {
164    lookup_context.insert_module(concrete_trait_id.trait_id(db).parent_module(db), db);
165    let generic_args = concrete_trait_id.generic_args(db);
166    // Add the defining module of the generic args to the lookup.
167    for generic_arg in generic_args {
168        if let GenericArgumentId::Type(ty) = generic_arg {
169            enrich_lookup_context_with_ty(db, *ty, lookup_context);
170        }
171    }
172    lookup_context.strip_for_trait_id(db, concrete_trait_id.trait_id(db));
173}
174
175/// Adds the defining module of the type to the lookup context.
176pub fn enrich_lookup_context_with_ty<'db>(
177    db: &'db dyn Database,
178    ty: TypeId<'db>,
179    lookup_context: &mut ImplLookupContext<'db>,
180) {
181    match ty.long(db) {
182        TypeLongId::ImplType(impl_type_id) => {
183            lookup_context.insert_impl(impl_type_id.impl_id(), db);
184        }
185        long_ty => {
186            if let Some(module_id) = long_ty.module_id(db) {
187                lookup_context.insert_module(module_id, db);
188            }
189        }
190    }
191}
192
193/// Attempts to solve a `canonical_trait`. Will try to find candidates in the given
194/// `lookup_context`.
195fn solve_canonical_trait<'db>(
196    db: &'db dyn Database,
197    canonical_trait: CanonicalTrait<'db>,
198    lookup_context: ImplLookupContextId<'db>,
199    impl_type_bounds: Arc<BTreeMap<ImplTypeById<'db>, TypeId<'db>>>,
200) -> SolutionSet<'db, CanonicalImpl<'db>> {
201    let filter = canonical_trait.id.filter(db);
202    let mut candidates = find_candidates_at_context(db, lookup_context, filter).unwrap_or_default();
203    find_closure_generated_candidate(db, canonical_trait.id)
204        .map(|candidate| candidates.insert(UninferredImplById(candidate)));
205
206    let mut unique_solution: Option<CanonicalImpl<'_>> = None;
207    for candidate in candidates.into_iter() {
208        let Ok(candidate_solution_set) = solve_candidate(
209            db,
210            &canonical_trait,
211            candidate.0,
212            lookup_context,
213            impl_type_bounds.clone(),
214        ) else {
215            continue;
216        };
217
218        let candidate_solution = match candidate_solution_set {
219            SolutionSet::None => continue,
220            SolutionSet::Unique(candidate_solution) => candidate_solution,
221            SolutionSet::Ambiguous(ambiguity) => return SolutionSet::Ambiguous(ambiguity),
222        };
223        if let Some(unique_solution) = unique_solution {
224            // There might be multiple unique solutions from different candidates that are
225            // solved to the same impl id (e.g. finding it near the trait, and
226            // through an impl alias). This is valid.
227            if unique_solution.0 != candidate_solution.0 {
228                return SolutionSet::Ambiguous(Ambiguity::MultipleImplsFound {
229                    concrete_trait_id: canonical_trait.id,
230                    impls: vec![unique_solution.0, candidate_solution.0],
231                });
232            }
233        }
234        unique_solution = Some(candidate_solution);
235    }
236    unique_solution.map(SolutionSet::Unique).unwrap_or(SolutionSet::None)
237}
238
239/// Attempts to solve `candidate` as the requested `canonical_trait`.
240fn solve_candidate<'db>(
241    db: &'db dyn Database,
242    canonical_trait: &CanonicalTrait<'db>,
243    candidate: UninferredImpl<'db>,
244    lookup_context: ImplLookupContextId<'db>,
245    impl_type_bounds: Arc<BTreeMap<ImplTypeById<'db>, TypeId<'db>>>,
246) -> InferenceResult<SolutionSet<'db, CanonicalImpl<'db>>> {
247    let Ok(candidate_concrete_trait) = candidate.concrete_trait(db) else {
248        return Err(super::ErrorSet);
249    };
250    // If the candidate is fully concrete, or it's a generic which is var free, there is nothing
251    // to substitute. A generic param may not be var free, if it contains impl types.
252    let candidate_final = matches!(candidate, UninferredImpl::GenericParam(_))
253        && candidate_concrete_trait.is_var_free(db)
254        || candidate_concrete_trait.is_fully_concrete(db);
255    let target_final = canonical_trait.id.is_var_free(db);
256    let mut lite_inference = LiteInference::new(db);
257    if candidate_final && target_final && candidate_concrete_trait != canonical_trait.id {
258        return Err(super::ErrorSet);
259    }
260
261    let mut res = lite_inference.can_conform_generic_args(
262        (candidate_concrete_trait.generic_args(db), candidate_final),
263        (canonical_trait.id.generic_args(db), target_final),
264    );
265
266    // If the candidate is a generic param, its trait is final and not substituted.
267    if matches!(candidate, UninferredImpl::GenericParam(_))
268        && !lite_inference.substitution.is_empty()
269    {
270        return Err(super::ErrorSet);
271    }
272
273    // If the trait has trait types, we default to using inference.
274    if res == CanConformResult::Accepted {
275        let Ok(trait_types) = db.trait_types(canonical_trait.id.trait_id(db)) else {
276            return Err(super::ErrorSet);
277        };
278        if !trait_types.is_empty() && !canonical_trait.mappings.types.is_empty() {
279            res = CanConformResult::InferenceRequired;
280        }
281    }
282
283    // Add the defining module of the candidate to the lookup.
284    let mut lookup_context = lookup_context.long(db).clone();
285    lookup_context.insert_lookup_scope(db, &candidate);
286    let lookup_context = lookup_context.intern(db);
287    if res == CanConformResult::Rejected {
288        return Err(super::ErrorSet);
289    } else if CanConformResult::Accepted == res {
290        match candidate {
291            UninferredImpl::Def(impl_def_id) => {
292                let imp_generic_params =
293                    db.impl_def_generic_params(impl_def_id).map_err(|_| super::ErrorSet)?;
294
295                match lite_inference.infer_generic_assignment(
296                    imp_generic_params,
297                    lookup_context,
298                    impl_type_bounds.clone(),
299                ) {
300                    Ok(SolutionSet::None) => {
301                        return Ok(SolutionSet::None);
302                    }
303                    Ok(SolutionSet::Ambiguous(ambiguity)) => {
304                        return Ok(SolutionSet::Ambiguous(ambiguity));
305                    }
306                    Ok(SolutionSet::Unique(generic_args)) => {
307                        let concrete_impl =
308                            ConcreteImplLongId { impl_def_id, generic_args }.intern(db);
309                        let impl_id = ImplLongId::Concrete(concrete_impl).intern(db);
310                        return Ok(SolutionSet::Unique(CanonicalImpl(impl_id)));
311                    }
312                    _ => {}
313                }
314            }
315            UninferredImpl::GenericParam(generic_param_id) => {
316                let impl_id = ImplLongId::GenericParameter(generic_param_id).intern(db);
317                return Ok(SolutionSet::Unique(CanonicalImpl(impl_id)));
318            }
319            // TODO(TomerStarkware): Try to solve for impl alias without inference.
320            UninferredImpl::ImplAlias(_) => {}
321            UninferredImpl::ImplImpl(_) | UninferredImpl::GeneratedImpl(_) => {}
322        }
323    }
324
325    let mut inference_data: InferenceData<'_> = InferenceData::new(InferenceId::Canonical);
326    let mut inference = inference_data.inference(db);
327    inference.data.impl_type_bounds = impl_type_bounds;
328    let (canonical_trait, canonical_embedding) = canonical_trait.embed(&mut inference);
329
330    // If the closure params are not var free, we cannot infer the negative impl.
331    // We use the canonical trait to concretize the closure params.
332    if let UninferredImpl::GeneratedImpl(imp) = candidate {
333        inference.conform_traits(imp.long(db).concrete_trait, canonical_trait.id)?;
334    }
335
336    // Instantiate the candidate in the inference table.
337    let candidate_impl =
338        inference.infer_impl(candidate, canonical_trait.id, lookup_context, None)?;
339    for (trait_type, ty) in canonical_trait.mappings.types.iter() {
340        let mapped_ty =
341            inference.reduce_impl_ty(ImplTypeId::new(candidate_impl, *trait_type, db))?;
342
343        // Conform the candidate's type to the trait's type.
344        inference.conform_ty(mapped_ty, *ty)?;
345    }
346    for (trait_const, const_id) in canonical_trait.mappings.constants.iter() {
347        let mapped_const_id = inference.reduce_impl_constant(ImplConstantId::new(
348            candidate_impl,
349            *trait_const,
350            db,
351        ))?;
352        // Conform the candidate's constant to the trait's constant.
353        inference.conform_const(mapped_const_id, *const_id)?;
354    }
355
356    for (trait_impl, impl_id) in canonical_trait.mappings.impls.iter() {
357        let mapped_impl_id =
358            inference.reduce_impl_impl(ImplImplId::new(candidate_impl, *trait_impl, db))?;
359        // Conform the candidate's impl to the trait's impl.
360        inference.conform_impl(mapped_impl_id, *impl_id)?;
361    }
362
363    let mut inference = inference_data.inference(db);
364    let solution_set = inference.solution_set()?;
365    Ok(match solution_set {
366        SolutionSet::None => SolutionSet::None,
367        SolutionSet::Ambiguous(ambiguity) => SolutionSet::Ambiguous(ambiguity),
368        SolutionSet::Unique(_) => {
369            let candidate_impl = inference.rewrite(candidate_impl).no_err();
370            match CanonicalImpl::canonicalize(db, candidate_impl, &canonical_embedding) {
371                Ok(canonical_impl) => SolutionSet::Unique(canonical_impl),
372                Err(MapperError(var)) => {
373                    return Ok(SolutionSet::Ambiguous(Ambiguity::FreeVariable {
374                        impl_id: candidate_impl,
375                        var,
376                    }));
377                }
378            }
379        }
380    })
381}
382
/// Outcome of checking whether a candidate can be conformed to a target (the result of the
/// `can_conform*` checks).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CanConformResult {
    /// The candidate matches without needing inference.
    Accepted,
    /// The candidate might match, but only full inference can decide.
    InferenceRequired,
    /// The candidate cannot match.
    Rejected,
}

impl CanConformResult {
    /// Combines several results into one.
    ///
    /// `Rejected` wins and stops consuming `iter` immediately, so later (lazily computed) items
    /// are never produced. Otherwise the result is `InferenceRequired` if any item required
    /// inference, and `Accepted` if every item was accepted (including when `iter` is empty).
    fn fold(iter: impl IntoIterator<Item = CanConformResult>) -> CanConformResult {
        iter.into_iter()
            .try_fold(CanConformResult::Accepted, |combined, item| match item {
                CanConformResult::Rejected => Err(()),
                CanConformResult::InferenceRequired => Ok(CanConformResult::InferenceRequired),
                CanConformResult::Accepted => Ok(combined),
            })
            .unwrap_or(CanConformResult::Rejected)
    }
}
/// An inference without 'vars' that can be used to solve canonical traits which do not contain
/// 'vars' or associated items.
///
/// Its only state is `substitution`: the generic-param assignments discovered while structurally
/// matching a candidate's generic args against the requested ones.
struct LiteInference<'db> {
    db: &'db dyn Database,
    /// Assignments for the candidate's generic params collected so far.
    substitution: GenericSubstitution<'db>,
}
412
413impl<'db> LiteInference<'db> {
414    fn new(db: &'db dyn Database) -> Self {
415        LiteInference { db, substitution: GenericSubstitution::default() }
416    }
417
418    /// Tries to infer the generic arguments of the trait from the given params.
419    /// If the inference fails (i.e., requires full inference), returns an error.
420    fn infer_generic_assignment(
421        &mut self,
422        params: &[GenericParam<'db>],
423        lookup_context: ImplLookupContextId<'db>,
424        impl_type_bounds: Arc<BTreeMap<ImplTypeById<'db>, TypeId<'db>>>,
425    ) -> InferenceResult<SolutionSet<'db, Vec<GenericArgumentId<'db>>>> {
426        let mut generic_args = Vec::with_capacity(params.len());
427        for param in params {
428            match param {
429                GenericParam::Type(generic_param_type) => {
430                    if self.substitution.contains_key(&generic_param_type.id) {
431                        generic_args.push(*self.substitution.get(&generic_param_type.id).unwrap());
432                    } else {
433                        // If the type is not in the substitution, we cannot solve it without
434                        // inference.
435                        return Err(super::ErrorSet);
436                    }
437                }
438                GenericParam::Const(generic_param_const) => {
439                    if self.substitution.contains_key(&generic_param_const.id) {
440                        generic_args.push(*self.substitution.get(&generic_param_const.id).unwrap());
441                    } else {
442                        // If the const is not in the substitution, we cannot solve it without
443                        // inference.
444                        return Err(super::ErrorSet);
445                    }
446                }
447                GenericParam::Impl(generic_param_impl) => {
448                    if !generic_param_impl.type_constraints.is_empty() {
449                        return Err(super::ErrorSet);
450                    }
451                    if self.substitution.contains_key(&generic_param_impl.id) {
452                        generic_args.push(*self.substitution.get(&generic_param_impl.id).unwrap());
453                        continue;
454                    }
455
456                    let Ok(Ok(imp_concrete_trait_id)) =
457                        self.substitution.substitute(self.db, generic_param_impl.concrete_trait)
458                    else {
459                        return Err(super::ErrorSet);
460                    };
461                    let canonical_trait = CanonicalTrait {
462                        id: imp_concrete_trait_id,
463                        mappings: ImplVarTraitItemMappings::default(),
464                    };
465                    let mut inner_context = lookup_context.long(self.db).clone();
466                    enrich_lookup_context(self.db, imp_concrete_trait_id, &mut inner_context);
467                    let Ok(solution) = self.db.canonic_trait_solutions(
468                        canonical_trait,
469                        inner_context.intern(self.db),
470                        (*impl_type_bounds).clone(),
471                    ) else {
472                        return Err(super::ErrorSet);
473                    };
474                    match solution {
475                        SolutionSet::None => return Ok(SolutionSet::None),
476                        SolutionSet::Unique(imp) => {
477                            self.substitution
478                                .insert(generic_param_impl.id, GenericArgumentId::Impl(imp.0));
479                            generic_args.push(GenericArgumentId::Impl(imp.0));
480                        }
481                        SolutionSet::Ambiguous(ambiguity) => {
482                            return Ok(SolutionSet::Ambiguous(ambiguity));
483                        }
484                    }
485                }
486                GenericParam::NegImpl(_) => return Err(super::ErrorSet),
487            }
488        }
489        Ok(SolutionSet::Unique(generic_args))
490    }
491
492    /// Checks if the generic arguments of the candidate could be conformed to the generic args of
493    /// the trait and if the trait or the candidate contain vars (which would require solving
494    /// using inference).
495    fn can_conform_generic_args(
496        &mut self,
497        (candidate_args, candidate_final): (&[GenericArgumentId<'db>], bool),
498        (target_args, target_final): (&[GenericArgumentId<'db>], bool),
499    ) -> CanConformResult {
500        CanConformResult::fold(zip_eq(candidate_args, target_args).map(
501            |(candidate_arg, target_arg)| {
502                self.can_conform_generic_arg(
503                    (*candidate_arg, candidate_final),
504                    (*target_arg, target_final),
505                )
506            },
507        ))
508    }
509
510    /// Checks if a [GenericArgumentId] of the candidate could be conformed to a [GenericArgumentId]
511    /// of the trait.
512    fn can_conform_generic_arg(
513        &mut self,
514        (candidate_arg, mut candidate_final): (GenericArgumentId<'db>, bool),
515        (target_arg, mut target_final): (GenericArgumentId<'db>, bool),
516    ) -> CanConformResult {
517        if candidate_arg == target_arg {
518            return CanConformResult::Accepted;
519        }
520        candidate_final = candidate_final || candidate_arg.is_fully_concrete(self.db);
521        target_final = target_final || target_arg.is_var_free(self.db);
522        if candidate_final && target_final {
523            return CanConformResult::Rejected;
524        }
525        match (candidate_arg, target_arg) {
526            (GenericArgumentId::Type(candidate), GenericArgumentId::Type(target)) => {
527                self.can_conform_ty((candidate, candidate_final), (target, target_final))
528            }
529            (GenericArgumentId::Constant(candidate), GenericArgumentId::Constant(target)) => {
530                self.can_conform_const((candidate, candidate_final), (target, target_final))
531            }
532            (GenericArgumentId::Impl(candidate), GenericArgumentId::Impl(target)) => {
533                self.can_conform_impl((candidate, candidate_final), (target, target_final))
534            }
535            (GenericArgumentId::NegImpl(_), GenericArgumentId::NegImpl(_)) => {
536                CanConformResult::InferenceRequired
537            }
538            _ => CanConformResult::Rejected,
539        }
540    }
541
542    /// Checks if a generic arg [TypeId] of the candidate could be conformed to a generic arg
543    /// [TypeId] of the trait.
544    fn can_conform_ty(
545        &mut self,
546        (candidate_ty, mut candidate_final): (TypeId<'db>, bool),
547        (target_ty, mut target_final): (TypeId<'db>, bool),
548    ) -> CanConformResult {
549        if candidate_ty == target_ty {
550            return CanConformResult::Accepted;
551        }
552        candidate_final = candidate_final || candidate_ty.is_fully_concrete(self.db);
553        target_final = target_final || target_ty.is_var_free(self.db);
554        if candidate_final && target_final {
555            return CanConformResult::Rejected;
556        }
557        let target_long_ty = target_ty.long(self.db);
558
559        if let TypeLongId::Var(_) = target_long_ty {
560            return CanConformResult::InferenceRequired;
561        }
562
563        let long_ty_candidate = candidate_ty.long(self.db);
564
565        match (long_ty_candidate, target_long_ty) {
566            (TypeLongId::Concrete(candidate), TypeLongId::Concrete(target)) => {
567                if candidate.generic_type(self.db) != target.generic_type(self.db) {
568                    return CanConformResult::Rejected;
569                }
570
571                self.can_conform_generic_args(
572                    (&candidate.generic_args(self.db), candidate_final),
573                    (&target.generic_args(self.db), target_final),
574                )
575            }
576            (TypeLongId::Concrete(_), _) => CanConformResult::Rejected,
577            (TypeLongId::Tuple(candidate_tys), TypeLongId::Tuple(target_tys)) => {
578                if candidate_tys.len() != target_tys.len() {
579                    return CanConformResult::Rejected;
580                }
581
582                CanConformResult::fold(zip_eq(candidate_tys, target_tys).map(
583                    |(candidate_subty, target_subty)| {
584                        self.can_conform_ty(
585                            (*candidate_subty, candidate_final),
586                            (*target_subty, target_final),
587                        )
588                    },
589                ))
590            }
591            (TypeLongId::Tuple(_), _) => CanConformResult::Rejected,
592            (TypeLongId::Closure(candidate), TypeLongId::Closure(target)) => {
593                if candidate.params_location != target.params_location {
594                    return CanConformResult::Rejected;
595                }
596
597                let params_check = CanConformResult::fold(
598                    zip_eq(candidate.param_tys.clone(), target.param_tys.clone()).map(
599                        |(candidate_subty, target_subty)| {
600                            self.can_conform_ty(
601                                (candidate_subty, candidate_final),
602                                (target_subty, target_final),
603                            )
604                        },
605                    ),
606                );
607                if params_check == CanConformResult::Rejected {
608                    return CanConformResult::Rejected;
609                }
610                let captured_types_check = CanConformResult::fold(
611                    zip_eq(candidate.captured_types.clone(), target.captured_types.clone()).map(
612                        |(candidate_subty, target_subty)| {
613                            self.can_conform_ty(
614                                (candidate_subty, candidate_final),
615                                (target_subty, target_final),
616                            )
617                        },
618                    ),
619                );
620                if captured_types_check == CanConformResult::Rejected {
621                    return CanConformResult::Rejected;
622                }
623                let return_type_check = self.can_conform_ty(
624                    (candidate.ret_ty, candidate_final),
625                    (target.ret_ty, target_final),
626                );
627                if return_type_check == CanConformResult::Rejected {
628                    return CanConformResult::Rejected;
629                }
630                if params_check == CanConformResult::InferenceRequired
631                    || captured_types_check == CanConformResult::InferenceRequired
632                    || return_type_check == CanConformResult::InferenceRequired
633                {
634                    return CanConformResult::InferenceRequired;
635                }
636                CanConformResult::Accepted
637            }
638            (TypeLongId::Closure(_), _) => CanConformResult::Rejected,
639            (
640                TypeLongId::FixedSizeArray { type_id: candidate_type_id, size: candidate_size },
641                TypeLongId::FixedSizeArray { type_id: target_type_id, size: target_size },
642            ) => CanConformResult::fold([
643                self.can_conform_const(
644                    (*candidate_size, candidate_final),
645                    (*target_size, target_final),
646                ),
647                self.can_conform_ty(
648                    (*candidate_type_id, candidate_final),
649                    (*target_type_id, target_final),
650                ),
651            ]),
652            (TypeLongId::FixedSizeArray { type_id: _, size: _ }, _) => CanConformResult::Rejected,
653            (TypeLongId::Snapshot(candidate_inner_ty), TypeLongId::Snapshot(target_inner_ty)) => {
654                self.can_conform_ty(
655                    (*candidate_inner_ty, candidate_final),
656                    (*target_inner_ty, target_final),
657                )
658            }
659            (TypeLongId::Snapshot(_), _) => CanConformResult::Rejected,
660            (TypeLongId::GenericParameter(param), _) => {
661                let mut res = CanConformResult::Accepted;
662                // if param not in substitution add it otherwise make sure it equal target_ty
663                match self.substitution.entry(*param) {
664                    Entry::Occupied(entry) => {
665                        if let GenericArgumentId::Type(existing_ty) = entry.get() {
666                            if *existing_ty != target_ty {
667                                res = CanConformResult::Rejected;
668                            }
669                            if !existing_ty.is_var_free(self.db) {
670                                return CanConformResult::InferenceRequired;
671                            }
672                        } else {
673                            res = CanConformResult::Rejected;
674                        }
675                    }
676                    Entry::Vacant(e) => {
677                        e.insert(GenericArgumentId::Type(target_ty));
678                    }
679                }
680
681                if target_ty.is_var_free(self.db) {
682                    res
683                } else {
684                    CanConformResult::InferenceRequired
685                }
686            }
687            (
688                TypeLongId::Var(_)
689                | TypeLongId::ImplType(_)
690                | TypeLongId::Missing(_)
691                | TypeLongId::Coupon(_),
692                _,
693            ) => CanConformResult::InferenceRequired,
694        }
695    }
696
697    /// Checks if a generic arg [ImplId] of the candidate could be conformed to a generic arg
698    /// [ImplId] of the trait.
699    fn can_conform_impl(
700        &mut self,
701        (candidate_impl, mut candidate_final): (ImplId<'db>, bool),
702        (target_impl, mut target_final): (ImplId<'db>, bool),
703    ) -> CanConformResult {
704        let long_impl_trait = target_impl.long(self.db);
705        if candidate_impl == target_impl {
706            return CanConformResult::Accepted;
707        }
708        candidate_final = candidate_final || candidate_impl.is_fully_concrete(self.db);
709        target_final = target_final || target_impl.is_var_free(self.db);
710        if candidate_final && target_final {
711            return CanConformResult::Rejected;
712        }
713        if let ImplLongId::ImplVar(_) = long_impl_trait {
714            return CanConformResult::InferenceRequired;
715        }
716        match (candidate_impl.long(self.db), long_impl_trait) {
717            (ImplLongId::Concrete(candidate), ImplLongId::Concrete(target)) => {
718                let candidate = candidate.long(self.db);
719                let target = target.long(self.db);
720                if candidate.impl_def_id != target.impl_def_id {
721                    return CanConformResult::Rejected;
722                }
723                let candidate_args = candidate.generic_args.clone();
724                let target_args = target.generic_args.clone();
725                self.can_conform_generic_args(
726                    (&candidate_args, candidate_final),
727                    (&target_args, target_final),
728                )
729            }
730            (ImplLongId::Concrete(_), _) => CanConformResult::Rejected,
731            (ImplLongId::GenericParameter(param), _) => {
732                let mut res = CanConformResult::Accepted;
733                // if param not in substitution add it otherwise make sure it equal target_ty
734                match self.substitution.entry(*param) {
735                    Entry::Occupied(entry) => {
736                        if let GenericArgumentId::Impl(existing_impl) = entry.get() {
737                            if *existing_impl != target_impl {
738                                res = CanConformResult::Rejected;
739                            }
740                            if !existing_impl.is_var_free(self.db) {
741                                return CanConformResult::InferenceRequired;
742                            }
743                        } else {
744                            res = CanConformResult::Rejected;
745                        }
746                    }
747                    Entry::Vacant(e) => {
748                        e.insert(GenericArgumentId::Impl(target_impl));
749                    }
750                }
751
752                if target_impl.is_var_free(self.db) {
753                    res
754                } else {
755                    CanConformResult::InferenceRequired
756                }
757            }
758            (
759                ImplLongId::ImplVar(_)
760                | ImplLongId::ImplImpl(_)
761                | ImplLongId::SelfImpl(_)
762                | ImplLongId::GeneratedImpl(_),
763                _,
764            ) => CanConformResult::InferenceRequired,
765        }
766    }
767
768    /// Checks if a generic arg [ConstValueId] of the candidate could be conformed to a generic arg
769    /// [ConstValueId] of the trait.
770    fn can_conform_const(
771        &mut self,
772        (candidate_id, mut candidate_final): (ConstValueId<'db>, bool),
773        (target_id, mut target_final): (ConstValueId<'db>, bool),
774    ) -> CanConformResult {
775        if candidate_id == target_id {
776            return CanConformResult::Accepted;
777        }
778        candidate_final = candidate_final || candidate_id.is_fully_concrete(self.db);
779        target_final = target_final || target_id.is_var_free(self.db);
780        if candidate_final && target_final {
781            return CanConformResult::Rejected;
782        }
783        let target_long_const = target_id.long(self.db);
784        if let ConstValue::Var(_, _) = target_long_const {
785            return CanConformResult::InferenceRequired;
786        }
787        match (candidate_id.long(self.db), target_long_const) {
788            (
789                ConstValue::Int(big_int, type_id),
790                ConstValue::Int(target_big_int, target_type_id),
791            ) => {
792                if big_int != target_big_int {
793                    return CanConformResult::Rejected;
794                }
795                self.can_conform_ty((*type_id, candidate_final), (*target_type_id, target_final))
796            }
797            (ConstValue::Int(_, _), _) => CanConformResult::Rejected,
798            (
799                ConstValue::Struct(const_values, type_id),
800                ConstValue::Struct(target_const_values, target_type_id),
801            ) => {
802                if const_values.len() != target_const_values.len() {
803                    return CanConformResult::Rejected;
804                };
805                CanConformResult::fold(chain!(
806                    [self.can_conform_ty(
807                        (*type_id, candidate_final),
808                        (*target_type_id, target_final)
809                    )],
810                    zip_eq(const_values, target_const_values).map(
811                        |(const_value, target_const_value)| {
812                            self.can_conform_const(
813                                (*const_value, candidate_final),
814                                (*target_const_value, target_final),
815                            )
816                        }
817                    )
818                ))
819            }
820            (ConstValue::Struct(_, _), _) => CanConformResult::Rejected,
821
822            (
823                ConstValue::Enum(concrete_variant, const_value),
824                ConstValue::Enum(target_concrete_variant, target_const_value),
825            ) => CanConformResult::fold([
826                self.can_conform_ty(
827                    (concrete_variant.ty, candidate_final),
828                    (target_concrete_variant.ty, target_final),
829                ),
830                self.can_conform_const(
831                    (*const_value, candidate_final),
832                    (*target_const_value, target_final),
833                ),
834            ]),
835            (ConstValue::Enum(_, _), _) => CanConformResult::Rejected,
836            (ConstValue::NonZero(const_value), ConstValue::NonZero(target_const_value)) => self
837                .can_conform_const(
838                    (*const_value, candidate_final),
839                    (*target_const_value, target_final),
840                ),
841            (ConstValue::NonZero(_), _) => CanConformResult::Rejected,
842            (ConstValue::Generic(param), _) => {
843                let mut res = CanConformResult::Accepted;
844                match self.substitution.entry(*param) {
845                    Entry::Occupied(entry) => {
846                        if let GenericArgumentId::Constant(existing_const) = entry.get() {
847                            if *existing_const != target_id {
848                                res = CanConformResult::Rejected;
849                            }
850
851                            if !existing_const.is_var_free(self.db) {
852                                return CanConformResult::InferenceRequired;
853                            }
854                        } else {
855                            res = CanConformResult::Rejected;
856                        }
857                    }
858                    Entry::Vacant(e) => {
859                        e.insert(GenericArgumentId::Constant(target_id));
860                    }
861                }
862                if target_id.is_var_free(self.db) {
863                    res
864                } else {
865                    CanConformResult::InferenceRequired
866                }
867            }
868            (ConstValue::ImplConstant(_) | ConstValue::Var(_, _) | ConstValue::Missing(_), _) => {
869                CanConformResult::InferenceRequired
870            }
871        }
872    }
873}
874
875/// Trait for solver-related semantic queries.
876pub trait SemanticSolver<'db>: Database {
877    /// Returns the solution set for a canonical trait.
878    fn canonic_trait_solutions(
879        &'db self,
880        canonical_trait: CanonicalTrait<'db>,
881        lookup_context: ImplLookupContextId<'db>,
882        impl_type_bounds: BTreeMap<ImplTypeById<'db>, TypeId<'db>>,
883    ) -> Result<SolutionSet<'db, CanonicalImpl<'db>>, InferenceError<'db>> {
884        canonic_trait_solutions_tracked(
885            self.as_dyn_database(),
886            canonical_trait,
887            lookup_context,
888            impl_type_bounds,
889        )
890    }
891}
892impl<'db, T: Database + ?Sized> SemanticSolver<'db> for T {}