arithmetic_typing/arith/substitutions/
mod.rs

1//! Substitutions type and dependencies.
2
3use std::{
4    cmp::Ordering,
5    collections::{HashMap, HashSet},
6    iter, ops, ptr,
7};
8
9use crate::{
10    arith::{CompleteConstraints, Constraint},
11    error::{ErrorKind, ErrorLocation, OpErrors, TupleContext},
12    visit::{self, Visit, VisitMut},
13    Function, Object, PrimitiveType, Tuple, TupleLen, Type, TypeVar, UnknownLen,
14};
15
16mod fns;
17use self::fns::{MonoTypeTransformer, ParamMapping};
18
19#[cfg(test)]
20mod tests;
21
22#[derive(Debug, Clone, Copy)]
23enum LenErrorKind {
24    UnresolvedParam,
25    Mismatch,
26    Dynamic(TupleLen),
27}
28
29/// Set of equations and constraints on type variables.
30#[derive(Debug, Clone)]
31pub struct Substitutions<Prim: PrimitiveType> {
32    /// Number of type variables.
33    type_var_count: usize,
34    /// Type variable equations, encoded as `type_var[key] = value`.
35    eqs: HashMap<usize, Type<Prim>>,
36    /// Constraints on type variables.
37    constraints: HashMap<usize, CompleteConstraints<Prim>>,
38    /// Number of length variables.
39    len_var_count: usize,
40    /// Length variable equations.
41    length_eqs: HashMap<usize, TupleLen>,
42    /// Lengths that have static restriction.
43    static_lengths: HashSet<usize>,
44}
45
46impl<Prim: PrimitiveType> Default for Substitutions<Prim> {
47    fn default() -> Self {
48        Self {
49            type_var_count: 0,
50            eqs: HashMap::new(),
51            constraints: HashMap::new(),
52            len_var_count: 0,
53            length_eqs: HashMap::new(),
54            static_lengths: HashSet::new(),
55        }
56    }
57}
58
59impl<Prim: PrimitiveType> Substitutions<Prim> {
60    /// Inserts `constraints` for a type var with the specified index and all vars
61    /// it is equivalent to.
62    pub fn insert_constraint<C>(
63        &mut self,
64        var_idx: usize,
65        constraint: &C,
66        mut errors: OpErrors<'_, Prim>,
67    ) where
68        C: Constraint<Prim> + Clone,
69    {
70        for idx in self.equivalent_vars(var_idx) {
71            let mut current_constraints = self.constraints.remove(&idx).unwrap_or_default();
72            current_constraints.insert(constraint.clone(), self, errors.by_ref());
73            self.constraints.insert(idx, current_constraints);
74        }
75    }
76
77    /// Returns an object constraint associated with the specified type var. The returned type
78    /// is resolved.
79    pub(crate) fn object_constraint(&self, var: TypeVar) -> Option<Object<Prim>> {
80        if var.is_free() {
81            let mut ty = self.constraints.get(&var.index())?.object.clone()?;
82            self.resolver().visit_object_mut(&mut ty);
83            Some(ty)
84        } else {
85            None
86        }
87    }
88
89    /// Inserts an object constraint for a type var with the specified index.
90    pub(crate) fn insert_obj_constraint(
91        &mut self,
92        var_idx: usize,
93        constraint: &Object<Prim>,
94        mut errors: OpErrors<'_, Prim>,
95    ) {
96        // Check whether the constraint is recursive.
97        let mut checker = OccurrenceChecker::new(self, self.equivalent_vars(var_idx));
98        checker.visit_object(constraint);
99        if let Some(var) = checker.recursive_var {
100            self.handle_recursive_type(Type::Object(constraint.clone()), var, &mut errors);
101            return;
102        }
103
104        for idx in self.equivalent_vars(var_idx) {
105            let mut current_constraints = self.constraints.remove(&idx).unwrap_or_default();
106            current_constraints.insert_obj_constraint(constraint.clone(), self, errors.by_ref());
107            self.constraints.insert(idx, current_constraints);
108        }
109    }
110
111    fn handle_recursive_type(
112        &self,
113        ty: Type<Prim>,
114        recursive_var: usize,
115        errors: &mut OpErrors<'_, Prim>,
116    ) {
117        let mut resolved_ty = ty;
118        self.resolver().visit_type_mut(&mut resolved_ty);
119        TypeSanitizer::new(recursive_var).visit_type_mut(&mut resolved_ty);
120        errors.push(ErrorKind::RecursiveType(resolved_ty));
121    }
122
123    /// Returns type var indexes that are equivalent to the provided `var_idx`,
124    /// including `var_idx` itself.
125    fn equivalent_vars(&self, var_idx: usize) -> Vec<usize> {
126        let ty = Type::free_var(var_idx);
127        let mut ty = &ty;
128        let mut equivalent_vars = vec![];
129
130        while let Type::Var(var) = ty {
131            debug_assert!(var.is_free());
132            equivalent_vars.push(var.index());
133            if let Some(resolved) = self.eqs.get(&var.index()) {
134                ty = resolved;
135            } else {
136                break;
137            }
138        }
139        equivalent_vars
140    }
141
142    /// Marks `len` as static, i.e., not containing [`UnknownLen::Dynamic`] components.
143    #[allow(clippy::missing_panics_doc)]
144    pub fn apply_static_len(&mut self, len: TupleLen) -> Result<(), ErrorKind<Prim>> {
145        let resolved = self.resolve_len(len);
146        self.apply_static_len_inner(resolved)
147            .map_err(|err| match err {
148                LenErrorKind::UnresolvedParam => ErrorKind::UnresolvedParam,
149                LenErrorKind::Dynamic(len) => ErrorKind::DynamicLen(len),
150                LenErrorKind::Mismatch => unreachable!(),
151            })
152    }
153
154    // Assumes that `len` is resolved.
155    fn apply_static_len_inner(&mut self, len: TupleLen) -> Result<(), LenErrorKind> {
156        match len.components().0 {
157            None => Ok(()),
158            Some(UnknownLen::Dynamic) => Err(LenErrorKind::Dynamic(len)),
159            Some(UnknownLen::Var(var)) => {
160                if var.is_free() {
161                    self.static_lengths.insert(var.index());
162                    Ok(())
163                } else {
164                    Err(LenErrorKind::UnresolvedParam)
165                }
166            }
167        }
168    }
169
170    /// Resolves the type by following established equality links between type variables.
171    pub fn fast_resolve<'a>(&'a self, mut ty: &'a Type<Prim>) -> &'a Type<Prim> {
172        while let Type::Var(var) = ty {
173            if !var.is_free() {
174                // Bound variables cannot be resolved further.
175                break;
176            }
177
178            if let Some(resolved) = self.eqs.get(&var.index()) {
179                ty = resolved;
180            } else {
181                break;
182            }
183        }
184        ty
185    }
186
187    /// Returns a visitor that resolves the type using equality relations in these `Substitutions`.
188    pub fn resolver(&self) -> impl VisitMut<Prim> + '_ {
189        TypeResolver {
190            substitutions: self,
191        }
192    }
193
194    /// Resolves the provided `len` given length equations in this instance.
195    pub(crate) fn resolve_len(&self, len: TupleLen) -> TupleLen {
196        let mut resolved = len;
197        while let (Some(UnknownLen::Var(var)), exact) = resolved.components() {
198            if !var.is_free() {
199                break;
200            }
201
202            if let Some(eq_rhs) = self.length_eqs.get(&var.index()) {
203                resolved = *eq_rhs + exact;
204            } else {
205                break;
206            }
207        }
208        resolved
209    }
210
211    /// Creates and returns a new type variable.
212    pub fn new_type_var(&mut self) -> Type<Prim> {
213        let new_type = Type::free_var(self.type_var_count);
214        self.type_var_count += 1;
215        new_type
216    }
217
218    /// Creates and returns a new length variable.
219    pub(crate) fn new_len_var(&mut self) -> UnknownLen {
220        let new_length = UnknownLen::free_var(self.len_var_count);
221        self.len_var_count += 1;
222        new_length
223    }
224
225    /// Unifies types in `lhs` and `rhs`.
226    ///
227    /// - LHS corresponds to the lvalue in assignments and to called function signature in fn calls.
228    /// - RHS corresponds to the rvalue in assignments and to the type of the called function.
229    ///
230    /// If unification is impossible, the corresponding error(s) will be put into `errors`.
231    pub fn unify(&mut self, lhs: &Type<Prim>, rhs: &Type<Prim>, mut errors: OpErrors<'_, Prim>) {
232        let resolved_lhs = self.fast_resolve(lhs).clone();
233        let resolved_rhs = self.fast_resolve(rhs).clone();
234
235        // **NB.** LHS and RHS should never switch sides; the side is important for
236        // accuracy of error reporting, and for some cases of type inference (e.g.,
237        // instantiation of parametric functions).
238        match (&resolved_lhs, &resolved_rhs) {
239            // Variables should be assigned *before* the equality check and dealing with `Any`
240            // to account for `Var <- Any` assignment.
241            (Type::Var(var), ty) => {
242                if var.is_free() {
243                    self.unify_var(var.index(), ty, true, errors);
244                } else {
245                    errors.push(ErrorKind::UnresolvedParam);
246                }
247            }
248
249            // This takes care of `Any` types because they are equal to anything.
250            (ty, other_ty) if ty == other_ty => {
251                // We already know that types are equal.
252            }
253
254            (Type::Dyn(constraints), ty) => {
255                constraints.inner.apply_all(ty, self, errors);
256            }
257
258            (ty, Type::Var(var)) => {
259                if var.is_free() {
260                    self.unify_var(var.index(), ty, false, errors);
261                } else {
262                    errors.push(ErrorKind::UnresolvedParam);
263                }
264            }
265
266            (Type::Tuple(lhs_tuple), Type::Tuple(rhs_tuple)) => {
267                self.unify_tuples(lhs_tuple, rhs_tuple, TupleContext::Generic, errors);
268            }
269            (Type::Object(lhs_obj), Type::Object(rhs_obj)) => {
270                self.unify_objects(lhs_obj, rhs_obj, errors);
271            }
272
273            (Type::Function(lhs_fn), Type::Function(rhs_fn)) => {
274                self.unify_fn_types(lhs_fn, rhs_fn, errors);
275            }
276
277            (ty, other_ty) => {
278                let mut resolver = self.resolver();
279                let mut ty = ty.clone();
280                resolver.visit_type_mut(&mut ty);
281                let mut other_ty = other_ty.clone();
282                resolver.visit_type_mut(&mut other_ty);
283                errors.push(ErrorKind::TypeMismatch(ty, other_ty));
284            }
285        }
286    }
287
288    fn unify_tuples(
289        &mut self,
290        lhs: &Tuple<Prim>,
291        rhs: &Tuple<Prim>,
292        context: TupleContext,
293        mut errors: OpErrors<'_, Prim>,
294    ) {
295        let resolved_len = self.unify_lengths(lhs.len(), rhs.len(), context);
296        let resolved_len = match resolved_len {
297            Ok(len) => len,
298            Err(err) => {
299                self.unify_tuples_after_error(lhs, rhs, &err, context, errors.by_ref());
300                errors.push(err);
301                return;
302            }
303        };
304
305        if let (None, exact) = resolved_len.components() {
306            self.unify_tuple_elements(lhs.iter(exact), rhs.iter(exact), context, errors);
307        } else {
308            // TODO: is this always applicable?
309            for (lhs_elem, rhs_elem) in lhs.equal_elements_dyn(rhs) {
310                let elem_errors = errors.with_location(match context {
311                    TupleContext::Generic => ErrorLocation::TupleElement(None),
312                    TupleContext::FnArgs => ErrorLocation::FnArg(None),
313                });
314                self.unify(lhs_elem, rhs_elem, elem_errors);
315            }
316        }
317    }
318
319    #[inline]
320    fn unify_tuple_elements<'it>(
321        &mut self,
322        lhs_elements: impl Iterator<Item = &'it Type<Prim>>,
323        rhs_elements: impl Iterator<Item = &'it Type<Prim>>,
324        context: TupleContext,
325        mut errors: OpErrors<'_, Prim>,
326    ) {
327        for (i, (lhs_elem, rhs_elem)) in lhs_elements.zip(rhs_elements).enumerate() {
328            let location = context.element(i);
329            self.unify(lhs_elem, rhs_elem, errors.with_location(location));
330        }
331    }
332
333    /// Tries to unify tuple elements after an error has occurred when unifying their lengths.
334    fn unify_tuples_after_error(
335        &mut self,
336        lhs: &Tuple<Prim>,
337        rhs: &Tuple<Prim>,
338        err: &ErrorKind<Prim>,
339        context: TupleContext,
340        errors: OpErrors<'_, Prim>,
341    ) {
342        let (lhs_len, rhs_len) = match err {
343            ErrorKind::TupleLenMismatch {
344                lhs: lhs_len,
345                rhs: rhs_len,
346                ..
347            } => (*lhs_len, *rhs_len),
348            _ => return,
349        };
350        let (lhs_var, lhs_exact) = lhs_len.components();
351        let (rhs_var, rhs_exact) = rhs_len.components();
352
353        match (lhs_var, rhs_var) {
354            (None, None) => {
355                // We've attempted to unify tuples with different known lengths.
356                // Iterate over common elements and unify them.
357                debug_assert_ne!(lhs_exact, rhs_exact);
358                self.unify_tuple_elements(
359                    lhs.iter(lhs_exact),
360                    rhs.iter(rhs_exact),
361                    context,
362                    errors,
363                );
364            }
365
366            (None, Some(UnknownLen::Dynamic)) => {
367                // We've attempted to unify static LHS with a dynamic RHS
368                // e.g., `(x, y) = filter(...)`.
369                self.unify_tuple_elements(
370                    lhs.iter(lhs_exact),
371                    rhs.iter(rhs_exact),
372                    context,
373                    errors,
374                );
375            }
376
377            _ => { /* Do nothing. */ }
378        }
379    }
380
381    /// Returns the resolved length that `lhs` and `rhs` are equal to.
382    fn unify_lengths(
383        &mut self,
384        lhs: TupleLen,
385        rhs: TupleLen,
386        context: TupleContext,
387    ) -> Result<TupleLen, ErrorKind<Prim>> {
388        let resolved_lhs = self.resolve_len(lhs);
389        let resolved_rhs = self.resolve_len(rhs);
390
391        self.unify_lengths_inner(resolved_lhs, resolved_rhs)
392            .map_err(|err| match err {
393                LenErrorKind::UnresolvedParam => ErrorKind::UnresolvedParam,
394                LenErrorKind::Mismatch => ErrorKind::TupleLenMismatch {
395                    lhs: resolved_lhs,
396                    rhs: resolved_rhs,
397                    context,
398                },
399                LenErrorKind::Dynamic(len) => ErrorKind::DynamicLen(len),
400            })
401    }
402
403    fn unify_lengths_inner(
404        &mut self,
405        resolved_lhs: TupleLen,
406        resolved_rhs: TupleLen,
407    ) -> Result<TupleLen, LenErrorKind> {
408        let (lhs_var, lhs_exact) = resolved_lhs.components();
409        let (rhs_var, rhs_exact) = resolved_rhs.components();
410
411        // First, consider a case when at least one of resolved lengths is exact.
412        let (lhs_var, rhs_var) = match (lhs_var, rhs_var) {
413            (Some(lhs_var), Some(rhs_var)) => (lhs_var, rhs_var),
414
415            (Some(lhs_var), None) if rhs_exact >= lhs_exact => {
416                return self
417                    .unify_simple_length(lhs_var, TupleLen::from(rhs_exact - lhs_exact), true)
418                    .map(|len| len + lhs_exact);
419            }
420            (None, Some(rhs_var)) if lhs_exact >= rhs_exact => {
421                return self
422                    .unify_simple_length(rhs_var, TupleLen::from(lhs_exact - rhs_exact), false)
423                    .map(|len| len + rhs_exact);
424            }
425
426            (None, None) if lhs_exact == rhs_exact => return Ok(TupleLen::from(lhs_exact)),
427
428            _ => return Err(LenErrorKind::Mismatch),
429        };
430
431        match lhs_exact.cmp(&rhs_exact) {
432            Ordering::Equal => self.unify_simple_length(lhs_var, TupleLen::from(rhs_var), true),
433            Ordering::Greater => {
434                let reduced = lhs_var + (lhs_exact - rhs_exact);
435                self.unify_simple_length(rhs_var, reduced, false)
436                    .map(|len| len + rhs_exact)
437            }
438            Ordering::Less => {
439                let reduced = rhs_var + (rhs_exact - lhs_exact);
440                self.unify_simple_length(lhs_var, reduced, true)
441                    .map(|len| len + lhs_exact)
442            }
443        }
444    }
445
446    fn unify_simple_length(
447        &mut self,
448        simple_len: UnknownLen,
449        source: TupleLen,
450        is_lhs: bool,
451    ) -> Result<TupleLen, LenErrorKind> {
452        match simple_len {
453            UnknownLen::Var(var) if var.is_free() => self.unify_var_length(var.index(), source),
454            UnknownLen::Dynamic => self.unify_dyn_length(source, is_lhs),
455            _ => Err(LenErrorKind::UnresolvedParam),
456        }
457    }
458
459    #[inline]
460    fn unify_var_length(
461        &mut self,
462        var_idx: usize,
463        source: TupleLen,
464    ) -> Result<TupleLen, LenErrorKind> {
465        // Check that the source is valid.
466        match source.components() {
467            (Some(UnknownLen::Var(var)), _) if !var.is_free() => Err(LenErrorKind::UnresolvedParam),
468
469            // Special case is uniting a var with self.
470            (Some(UnknownLen::Var(var)), offset) if var.index() == var_idx => {
471                if offset == 0 {
472                    Ok(source)
473                } else {
474                    Err(LenErrorKind::Mismatch)
475                }
476            }
477
478            _ => {
479                if self.static_lengths.contains(&var_idx) {
480                    self.apply_static_len_inner(source)?;
481                }
482                self.length_eqs.insert(var_idx, source);
483                Ok(source)
484            }
485        }
486    }
487
488    #[inline]
489    fn unify_dyn_length(
490        &mut self,
491        source: TupleLen,
492        is_lhs: bool,
493    ) -> Result<TupleLen, LenErrorKind> {
494        if is_lhs {
495            Ok(source) // assignment to dyn length always succeeds
496        } else {
497            let source_var_idx = match source.components() {
498                (Some(UnknownLen::Var(var)), 0) if var.is_free() => var.index(),
499                (Some(UnknownLen::Dynamic), 0) => return Ok(source),
500                _ => return Err(LenErrorKind::Mismatch),
501            };
502            self.unify_var_length(source_var_idx, UnknownLen::Dynamic.into())
503        }
504    }
505
506    fn unify_objects(
507        &mut self,
508        lhs: &Object<Prim>,
509        rhs: &Object<Prim>,
510        mut errors: OpErrors<'_, Prim>,
511    ) {
512        let lhs_fields: HashSet<_> = lhs.field_names().collect();
513        let rhs_fields: HashSet<_> = rhs.field_names().collect();
514
515        if lhs_fields == rhs_fields {
516            for (field_name, ty) in lhs.iter() {
517                let rhs_ty = rhs.field(field_name).unwrap();
518                self.unify(ty, rhs_ty, errors.with_location(field_name));
519            }
520        } else {
521            errors.push(ErrorKind::FieldsMismatch {
522                lhs_fields: lhs_fields.into_iter().map(String::from).collect(),
523                rhs_fields: rhs_fields.into_iter().map(String::from).collect(),
524            });
525        }
526    }
527
528    fn unify_fn_types(
529        &mut self,
530        lhs: &Function<Prim>,
531        rhs: &Function<Prim>,
532        mut errors: OpErrors<'_, Prim>,
533    ) {
534        if lhs.is_parametric() {
535            errors.push(ErrorKind::UnsupportedParam);
536            return;
537        }
538
539        let instantiated_lhs = self.instantiate_function(lhs);
540        let instantiated_rhs = self.instantiate_function(rhs);
541
542        // Swapping args is intentional. To see why, consider a function
543        // `fn(T, U) -> V` called as `fn(A, B) -> C` (`T`, ... `C` are types).
544        // In this case, the first arg of actual type `A` will be assigned to type `T`
545        // (i.e., `T` is LHS and `A` is RHS); same with `U` and `B`. In contrast,
546        // after function execution the return value of type `V` will be assigned
547        // to type `C`. (I.e., unification of return values is not swapped.)
548        self.unify_tuples(
549            &instantiated_rhs.args,
550            &instantiated_lhs.args,
551            TupleContext::FnArgs,
552            errors.by_ref(),
553        );
554
555        self.unify(
556            &instantiated_lhs.return_type,
557            &instantiated_rhs.return_type,
558            errors.with_location(ErrorLocation::FnReturnType),
559        );
560    }
561
562    /// Instantiates a functional type by replacing all type arguments with new type vars.
563    fn instantiate_function(&mut self, fn_type: &Function<Prim>) -> Function<Prim> {
564        if !fn_type.is_parametric() {
565            // Fast path: just clone the function type.
566            return fn_type.clone();
567        }
568        let fn_params = fn_type.params.as_ref().expect("fn with params");
569
570        // Map type vars in the function into newly created type vars.
571        let mapping = ParamMapping {
572            types: fn_params
573                .type_params
574                .iter()
575                .enumerate()
576                .map(|(i, (var_idx, _))| (*var_idx, self.type_var_count + i))
577                .collect(),
578            lengths: fn_params
579                .len_params
580                .iter()
581                .enumerate()
582                .map(|(i, (var_idx, _))| (*var_idx, self.len_var_count + i))
583                .collect(),
584        };
585        self.type_var_count += fn_params.type_params.len();
586        self.len_var_count += fn_params.len_params.len();
587
588        let mut instantiated_fn_type = fn_type.clone();
589        MonoTypeTransformer::transform(&mapping, &mut instantiated_fn_type);
590
591        // Copy constraints on the newly generated const and type vars from the function definition.
592        for (original_idx, is_static) in &fn_params.len_params {
593            if *is_static {
594                let new_idx = mapping.lengths[original_idx];
595                self.static_lengths.insert(new_idx);
596            }
597        }
598        for (original_idx, constraints) in &fn_params.type_params {
599            let new_idx = mapping.types[original_idx];
600            let mono_constraints =
601                MonoTypeTransformer::transform_constraints(&mapping, constraints);
602            self.constraints.insert(new_idx, mono_constraints);
603        }
604
605        instantiated_fn_type
606    }
607
608    /// Unifies a type variable with the specified index and the specified type.
609    fn unify_var(
610        &mut self,
611        var_idx: usize,
612        ty: &Type<Prim>,
613        is_lhs: bool,
614        mut errors: OpErrors<'_, Prim>,
615    ) {
616        // Variables should be resolved in `unify`.
617        debug_assert!(is_lhs || !matches!(ty, Type::Any | Type::Dyn(_)));
618        debug_assert!(!self.eqs.contains_key(&var_idx));
619        debug_assert!(if let Type::Var(var) = ty {
620            !self.eqs.contains_key(&var.index())
621        } else {
622            true
623        });
624
625        if let Type::Var(var) = ty {
626            if !var.is_free() {
627                errors.push(ErrorKind::UnresolvedParam);
628                return;
629            } else if var.index() == var_idx {
630                return;
631            }
632        }
633
634        let mut checker = OccurrenceChecker::new(self, iter::once(var_idx));
635        checker.visit_type(ty);
636
637        if let Some(var) = checker.recursive_var {
638            self.handle_recursive_type(ty.clone(), var, &mut errors);
639        } else {
640            if let Some(constraints) = self.constraints.get(&var_idx).cloned() {
641                constraints.apply_all(ty, self, errors);
642            }
643
644            let mut ty = ty.clone();
645            if !is_lhs {
646                // We need to swap `any` types / lengths with new vars so that this type
647                // can be specified further.
648                TypeSpecifier::new(self).visit_type_mut(&mut ty);
649            }
650            self.eqs.insert(var_idx, ty);
651        }
652    }
653}
654
655/// Checks if a type variable with the specified index is present in `ty`. This method
656/// is used to check that types are not recursive.
657#[derive(Debug)]
658struct OccurrenceChecker<'a, Prim: PrimitiveType> {
659    substitutions: &'a Substitutions<Prim>,
660    var_indexes: HashSet<usize>,
661    recursive_var: Option<usize>,
662}
663
664impl<'a, Prim: PrimitiveType> OccurrenceChecker<'a, Prim> {
665    fn new(
666        substitutions: &'a Substitutions<Prim>,
667        var_indexes: impl IntoIterator<Item = usize>,
668    ) -> Self {
669        Self {
670            substitutions,
671            var_indexes: var_indexes.into_iter().collect(),
672            recursive_var: None,
673        }
674    }
675}
676
677impl<Prim: PrimitiveType> Visit<Prim> for OccurrenceChecker<'_, Prim> {
678    fn visit_type(&mut self, ty: &Type<Prim>) {
679        if self.recursive_var.is_some() {
680            // Skip recursion; we already have our answer at this point.
681        } else {
682            visit::visit_type(self, ty);
683        }
684    }
685
686    fn visit_var(&mut self, var: TypeVar) {
687        let var_idx = var.index();
688        if self.var_indexes.contains(&var_idx) {
689            self.recursive_var = Some(var_idx);
690        } else if let Some(ty) = self.substitutions.eqs.get(&var_idx) {
691            self.visit_type(ty);
692        }
693    }
694}
695
/// Visitor that replaces the (free) type var with index `fixed_idx` by a type param,
/// so that a recursive type can be displayed in an error without exposing internal
/// var indexes.
#[derive(Debug)]
struct TypeSanitizer {
    /// Index of the free type var to be replaced.
    fixed_idx: usize,
}
702
703impl TypeSanitizer {
704    fn new(fixed_idx: usize) -> Self {
705        Self { fixed_idx }
706    }
707}
708
709impl<Prim: PrimitiveType> VisitMut<Prim> for TypeSanitizer {
710    fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
711        match ty {
712            Type::Var(var) if var.index() == self.fixed_idx => {
713                *ty = Type::param(0);
714            }
715            _ => visit::visit_type_mut(self, ty),
716        }
717    }
718}
719
720/// Visitor that performs type resolution based on `Substitutions`.
721#[derive(Debug, Clone, Copy)]
722struct TypeResolver<'a, Prim: PrimitiveType> {
723    substitutions: &'a Substitutions<Prim>,
724}
725
726impl<Prim: PrimitiveType> VisitMut<Prim> for TypeResolver<'_, Prim> {
727    fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
728        let fast_resolved = self.substitutions.fast_resolve(ty);
729        if !ptr::eq(ty, fast_resolved) {
730            *ty = fast_resolved.clone();
731        }
732        visit::visit_type_mut(self, ty);
733    }
734
735    fn visit_middle_len_mut(&mut self, len: &mut TupleLen) {
736        *len = self.substitutions.resolve_len(*len);
737    }
738}
739
/// Variance of the position currently visited by `TypeSpecifier`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Variance {
    /// Covariant position (e.g., a function return type).
    Co,
    /// Contravariant position (e.g., a function argument).
    Contra,
}
745
746impl ops::Not for Variance {
747    type Output = Self;
748
749    fn not(self) -> Self {
750        match self {
751            Self::Co => Self::Contra,
752            Self::Contra => Self::Co,
753        }
754    }
755}
756
757/// Visitor that swaps `any` types / lengths with new vars, but only if they are in a covariant
758/// position (return types, args of arg functions, etc.).
759///
760/// This is used when assigning to a type containing `any`.
761#[derive(Debug)]
762struct TypeSpecifier<'a, Prim: PrimitiveType> {
763    substitutions: &'a mut Substitutions<Prim>,
764    variance: Variance,
765}
766
767impl<'a, Prim: PrimitiveType> TypeSpecifier<'a, Prim> {
768    fn new(substitutions: &'a mut Substitutions<Prim>) -> Self {
769        Self {
770            substitutions,
771            variance: Variance::Co,
772        }
773    }
774}
775
776impl<Prim: PrimitiveType> VisitMut<Prim> for TypeSpecifier<'_, Prim> {
777    fn visit_type_mut(&mut self, ty: &mut Type<Prim>) {
778        match ty {
779            Type::Any if self.variance == Variance::Co => {
780                *ty = self.substitutions.new_type_var();
781            }
782
783            Type::Dyn(constraints) if self.variance == Variance::Co => {
784                let var_idx = self.substitutions.type_var_count;
785                self.substitutions
786                    .constraints
787                    .insert(var_idx, constraints.inner.clone());
788                *ty = Type::free_var(var_idx);
789                self.substitutions.type_var_count += 1;
790            }
791
792            _ => visit::visit_type_mut(self, ty),
793        }
794    }
795
796    fn visit_middle_len_mut(&mut self, len: &mut TupleLen) {
797        if self.variance != Variance::Co {
798            return;
799        }
800        if let (Some(var_len @ UnknownLen::Dynamic), _) = len.components_mut() {
801            *var_len = self.substitutions.new_len_var();
802        }
803    }
804
805    fn visit_function_mut(&mut self, function: &mut Function<Prim>) {
806        // Since the visiting order doesn't matter, we visit the return type (which preserves
807        // variance) first.
808        self.visit_type_mut(&mut function.return_type);
809
810        let old_variance = self.variance;
811        self.variance = !self.variance;
812        self.visit_tuple_mut(&mut function.args);
813        self.variance = old_variance;
814    }
815}