1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
use std::sync::Arc;

use crate::{
    cast::Upcast,
    collections::Set,
    language::{CoreParameter, HasKind, Language},
    variable::CoreVariable,
    visit::CoreVisit,
};

/// Callback consulted by [`CoreFold::substitute`] for each variable found while folding.
///
/// Variables bound by binders that the fold traverses are ignored (never handed to the
/// callback). The callback takes a single argument:
///
/// * `CoreVariable<L>` -- the variable that was encountered.
///
/// It returns `Some(parameter)` to replace that variable with `parameter`, or `None` to
/// leave the variable as it is.
///
/// Because this is a `&mut dyn FnMut`, the callback may carry mutable state. The composite
/// impls in this module (e.g. `Vec`, tuples) pass the same callback to each child in turn.
pub type SubstitutionFn<'a, L: Language> =
    &'a mut dyn FnMut(CoreVariable<L>) -> Option<CoreParameter<L>>;

pub trait CoreFold<L: Language>: Sized + CoreVisit<L> {
    /// Rebuild `self`, asking `substitution_fn` what to do with every variable encountered.
    ///
    /// Wherever the callback yields `Some(param)`, `param` takes the variable's place;
    /// wherever it yields `None`, the variable is kept unchanged.
    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self;

    /// Return a copy of this term where every free de Bruijn index is bumped up by one.
    ///
    /// Implemented by mapping each variable seen by the substitution callback through
    /// its own `shift_in`.
    fn shift_in(&self) -> Self {
        let mut bump = |var: CoreVariable<L>| -> Option<CoreParameter<L>> {
            Some(var.shift_in().upcast())
        };
        self.substitute(&mut bump)
    }

    /// Substitute `p` for every occurrence of the free variable `v`.
    ///
    /// # Panics
    ///
    /// Panics if `v` is not free, or if `v` and `p` have different kinds.
    fn replace_free_var(
        &self,
        v: impl Upcast<CoreVariable<L>>,
        p: impl Upcast<CoreParameter<L>>,
    ) -> Self {
        let v: CoreVariable<L> = v.upcast();
        let p: CoreParameter<L> = p.upcast();
        assert!(v.is_free());
        assert!(v.kind() == p.kind());
        // Only the exact target variable is rewritten; all others are left alone (`None`).
        self.substitute(&mut |candidate: CoreVariable<L>| (v == candidate).then(|| p.clone()))
    }
}

impl<L: Language, T: CoreFold<L>> CoreFold<L> for Vec<T> {
    /// Substitute into each element front-to-back, producing a vector of the same length.
    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
        let mut folded = Vec::with_capacity(self.len());
        for element in self {
            folded.push(element.substitute(substitution_fn));
        }
        folded
    }
}

impl<L: Language, T: CoreFold<L> + Ord> CoreFold<L> for Set<T> {
    /// Substitute into every member and gather the results into a new set.
    ///
    /// NOTE(review): `Set` is presumably an ordered set (given the `Ord` bound), in which
    /// case members that become equal after substitution collapse into one — confirm.
    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
        let substituted = self
            .iter()
            .map(|member| <T as CoreFold<L>>::substitute(member, substitution_fn));
        substituted.collect()
    }
}

impl<L: Language, T: CoreFold<L>> CoreFold<L> for Option<T> {
    /// `None` stays `None`; a present value is substituted into.
    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
        let inner = self.as_ref()?;
        Some(inner.substitute(substitution_fn))
    }
}

impl<L: Language, T: CoreFold<L>> CoreFold<L> for Arc<T> {
    /// Substitute into the shared value and wrap the result in a freshly allocated `Arc`.
    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
        // Deref to the pointee explicitly and call `T`'s impl by path, so this does not
        // recurse back into the `Arc<T>` impl.
        let inner: &T = self;
        Arc::new(<T as CoreFold<L>>::substitute(inner, substitution_fn))
    }
}

impl<L: Language> CoreFold<L> for usize {
    /// Integers contain no variables, so substitution simply copies the value.
    fn substitute(&self, _: SubstitutionFn<'_, L>) -> usize {
        *self
    }
}

impl<L: Language> CoreFold<L> for u32 {
    /// Integers contain no variables, so substitution simply copies the value.
    fn substitute(&self, _: SubstitutionFn<'_, L>) -> u32 {
        *self
    }
}

impl<L: Language> CoreFold<L> for () {
    /// The unit value has no variables; substitution hands it back unchanged.
    fn substitute(&self, _: SubstitutionFn<'_, L>) -> Self {
        *self
    }
}

impl<L: Language, A: CoreFold<L>, B: CoreFold<L>> CoreFold<L> for (A, B) {
    /// Substitute into each component, first then second.
    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
        // Left-to-right so a stateful callback sees components in declaration order.
        let first = self.0.substitute(substitution_fn);
        let second = self.1.substitute(substitution_fn);
        (first, second)
    }
}

impl<L: Language, A: CoreFold<L>, B: CoreFold<L>, C: CoreFold<L>> CoreFold<L> for (A, B, C) {
    /// Substitute into each component, in declaration order.
    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
        // Left-to-right so a stateful callback sees components in declaration order.
        let first = self.0.substitute(substitution_fn);
        let second = self.1.substitute(substitution_fn);
        let third = self.2.substitute(substitution_fn);
        (first, second, third)
    }
}