formality_core/
fold.rs

1use std::sync::Arc;
2
3use crate::{
4    cast::Upcast,
5    collections::Set,
6    language::{CoreParameter, HasKind, Language},
7    variable::CoreVariable,
8    visit::CoreVisit,
9};
10
11/// Invoked for each variable that we find when folding, ignoring variables bound by binders
12/// that we traverse. The arguments are as follows:
13///
14/// * ParameterKind -- the kind of term in which the variable appeared (type vs lifetime, etc)
15/// * Variable -- the variable we encountered
16pub type SubstitutionFn<'a, L: Language> =
17    &'a mut dyn FnMut(CoreVariable<L>) -> Option<CoreParameter<L>>;
18
19pub trait CoreFold<L: Language>: Sized + CoreVisit<L> {
20    /// Replace uses of variables with values from the substitution.
21    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self;
22
23    /// Produce a version of this term where any debruijn indices which appear free are incremented by one.
24    fn shift_in(&self) -> Self {
25        self.substitute(&mut |v| Some(v.shift_in().upcast()))
26    }
27
28    /// Replace all appearances of free variable `v` with `p`.
29    fn replace_free_var(
30        &self,
31        v: impl Upcast<CoreVariable<L>>,
32        p: impl Upcast<CoreParameter<L>>,
33    ) -> Self {
34        let v: CoreVariable<L> = v.upcast();
35        let p: CoreParameter<L> = p.upcast();
36        assert!(v.is_free());
37        assert!(v.kind() == p.kind());
38        self.substitute(&mut |v1| if v == v1 { Some(p.clone()) } else { None })
39    }
40}
41
42impl<L: Language, T: CoreFold<L>> CoreFold<L> for Vec<T> {
43    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
44        self.iter().map(|e| e.substitute(substitution_fn)).collect()
45    }
46}
47
48impl<L: Language, T: CoreFold<L> + Ord> CoreFold<L> for Set<T> {
49    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
50        self.iter().map(|e| e.substitute(substitution_fn)).collect()
51    }
52}
53
54impl<L: Language, T: CoreFold<L>> CoreFold<L> for Option<T> {
55    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
56        self.as_ref().map(|e| e.substitute(substitution_fn))
57    }
58}
59
60impl<L: Language, T: CoreFold<L>> CoreFold<L> for Arc<T> {
61    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
62        let data = T::substitute(self, substitution_fn);
63        Arc::new(data)
64    }
65}
66
67impl<L: Language> CoreFold<L> for usize {
68    fn substitute(&self, _substitution_fn: SubstitutionFn<'_, L>) -> Self {
69        *self
70    }
71}
72
73impl<L: Language> CoreFold<L> for u32 {
74    fn substitute(&self, _substitution_fn: SubstitutionFn<'_, L>) -> Self {
75        *self
76    }
77}
78
79impl<L: Language> CoreFold<L> for () {
80    fn substitute(&self, _substitution_fn: SubstitutionFn<'_, L>) -> Self {}
81}
82
83impl<L: Language, A: CoreFold<L>, B: CoreFold<L>> CoreFold<L> for (A, B) {
84    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
85        let (a, b) = self;
86        (a.substitute(substitution_fn), b.substitute(substitution_fn))
87    }
88}
89
90impl<L: Language, A: CoreFold<L>, B: CoreFold<L>, C: CoreFold<L>> CoreFold<L> for (A, B, C) {
91    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
92        let (a, b, c) = self;
93        (
94            a.substitute(substitution_fn),
95            b.substitute(substitution_fn),
96            c.substitute(substitution_fn),
97        )
98    }
99}