use std::sync::Arc;
use crate::{
cast::Upcast,
collections::Set,
language::{CoreParameter, HasKind, Language},
variable::CoreVariable,
visit::CoreVisit,
};
/// Callback type taken by [`CoreFold::substitute`]: it is handed a variable and
/// returns `Some(parameter)` to use in its place, or `None`.
///
/// NOTE(review): rustc does not enforce trait bounds written on a type alias
/// (see the `type_alias_bounds` lint), so `L: Language` here is informational
/// only; the bound is actually enforced where the alias is used.
pub type SubstitutionFn<'a, L: Language> =
    &'a mut dyn FnMut(CoreVariable<L>) -> Option<CoreParameter<L>>;
/// Types whose contents can be rebuilt by substituting parameters for variables.
///
/// Implementors only need to supply [`CoreFold::substitute`]; `shift_in` and
/// `replace_free_var` are provided in terms of it.
pub trait CoreFold<L: Language>: Sized + CoreVisit<L> {
    /// Returns a rebuilt copy of `self`, invoking `substitution_fn` on variables.
    ///
    /// A `Some(parameter)` result supplies the replacement for that variable.
    /// `replace_free_var` answers `None` for every variable it does not target,
    /// so `None` presumably means "leave the variable unchanged" — confirm
    /// against the leaf implementations.
    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self;

    /// Returns a copy of `self` in which every variable `v` handed to the
    /// substitution callback is replaced by `v.shift_in()`, upcast to a parameter.
    fn shift_in(&self) -> Self {
        let mut shift_var = |var: CoreVariable<L>| -> Option<CoreParameter<L>> {
            Some(var.shift_in().upcast())
        };
        self.substitute(&mut shift_var)
    }

    /// Returns a copy of `self` with the free variable `v` replaced by `p`.
    ///
    /// # Panics
    ///
    /// Panics if `v` is not free, or if `v` and `p` have different kinds.
    fn replace_free_var(
        &self,
        v: impl Upcast<CoreVariable<L>>,
        p: impl Upcast<CoreParameter<L>>,
    ) -> Self {
        let target: CoreVariable<L> = v.upcast();
        let replacement: CoreParameter<L> = p.upcast();
        assert!(target.is_free());
        assert!(target.kind() == replacement.kind());
        // Only the matching variable gets a replacement; the clone happens lazily.
        self.substitute(&mut |candidate| (target == candidate).then(|| replacement.clone()))
    }
}
impl<L: Language, T: CoreFold<L>> CoreFold<L> for Vec<T> {
    /// Substitutes into each element in order, producing a vector of the same length.
    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
        let mut folded = Vec::with_capacity(self.len());
        for element in self {
            folded.push(element.substitute(substitution_fn));
        }
        folded
    }
}
impl<L: Language, T: CoreFold<L> + Ord> CoreFold<L> for Set<T> {
    /// Substitutes into every member (in the set's iteration order) and
    /// collects the results into a new set.
    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
        self.iter()
            .map(|member| member.substitute(&mut *substitution_fn))
            .collect::<Self>()
    }
}
impl<L: Language, T: CoreFold<L>> CoreFold<L> for Option<T> {
    /// `None` stays `None`; a present value is substituted into.
    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
        let inner = self.as_ref()?;
        Some(inner.substitute(substitution_fn))
    }
}
impl<L: Language, T: CoreFold<L>> CoreFold<L> for Arc<T> {
    /// Substitutes into the shared value and wraps the result in a fresh `Arc`.
    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
        // Deref to `&T` first so the call dispatches to `T`'s impl, not this one.
        let inner: &T = self;
        Arc::new(inner.substitute(substitution_fn))
    }
}
impl<L: Language> CoreFold<L> for usize {
    /// A `usize` holds no variables; substitution just copies the value.
    fn substitute(&self, _substitution_fn: SubstitutionFn<'_, L>) -> Self {
        let &value = self;
        value
    }
}
impl<L: Language> CoreFold<L> for u32 {
    /// A `u32` holds no variables; substitution just copies the value.
    fn substitute(&self, _substitution_fn: SubstitutionFn<'_, L>) -> Self {
        let &value = self;
        value
    }
}
impl<L: Language> CoreFold<L> for () {
    /// The unit value holds no variables; substitution returns it unchanged.
    fn substitute(&self, _substitution_fn: SubstitutionFn<'_, L>) -> Self {
        *self
    }
}
impl<L: Language, A: CoreFold<L>, B: CoreFold<L>> CoreFold<L> for (A, B) {
    /// Substitutes into both components, first then second.
    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
        (
            self.0.substitute(substitution_fn),
            self.1.substitute(substitution_fn),
        )
    }
}
impl<L: Language, A: CoreFold<L>, B: CoreFold<L>, C: CoreFold<L>> CoreFold<L> for (A, B, C) {
    /// Substitutes into all three components, left to right.
    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
        let first = self.0.substitute(substitution_fn);
        let second = self.1.substitute(substitution_fn);
        let third = self.2.substitute(substitution_fn);
        (first, second, third)
    }
}