use std::sync::atomic::{AtomicUsize, Ordering};
use anyhow::bail;
use lazy_static::lazy_static;
use crate::{
cast::{Downcast, DowncastFrom, DowncastTo, To, Upcast, UpcastFrom},
fold::CoreFold,
fold::SubstitutionFn,
language::{CoreKind, CoreParameter, HasKind, Language},
substitution::CoreSubstitution,
variable::{CoreBoundVar, CoreVariable, DebruijnIndex, VarIndex},
visit::CoreVisit,
Fallible,
};
/// A term of type `T` together with a list of variables bound inside it.
///
/// Only the *kinds* of the bound variables are stored. Inside `term`, the
/// bound variables are represented the way `CoreBinder::new` writes them:
/// as `CoreBoundVar`s with `debruijn: Some(DebruijnIndex::INNERMOST)` and a
/// `var_index` equal to the variable's position in `kinds`.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct CoreBinder<L: Language, T> {
/// Kind of each bound variable; entry `i` belongs to `VarIndex { index: i }`.
kinds: Vec<CoreKind<L>>,
/// The bound term, with bound variables in their innermost-debruijn form.
term: T,
}
impl<L: Language, T: CoreFold<L>> CoreBinder<L, T> {
/// Opens the binder, replacing its bound variables with fresh ones.
///
/// For every kind in the binder a new variable is created with
/// `CoreBoundVar::fresh`, and the innermost bound variable at the same
/// position is replaced by it in a copy of the term.
///
/// Returns the fresh variables (in binder order) and the rewritten term.
pub fn open(&self) -> (Vec<CoreBoundVar<L>>, T) {
    // Each item is `(fresh, (bound, fresh))`: the first half is collected
    // into the returned list, the second half into the substitution.
    let pairs = self.kinds.iter().enumerate().map(|(index, &kind)| {
        let bound = CoreBoundVar {
            debruijn: Some(DebruijnIndex::INNERMOST),
            var_index: VarIndex { index },
            kind,
        };
        let fresh = CoreBoundVar::fresh(kind);
        (fresh, (bound, fresh))
    });
    let (fresh_vars, opening): (Vec<CoreBoundVar<L>>, CoreSubstitution<L>) = pairs.unzip();

    let opened_term = opening.apply(&self.term);
    (fresh_vars, opened_term)
}
/// Wraps `term` in a binder that binds no variables.
pub fn dummy(term: T) -> Self {
    Self::new(Vec::<CoreVariable<L>>::new(), term)
}
/// Builds a binder over `term` that binds the given free `variables`.
///
/// Each variable is replaced in `term` by a bound variable with
/// `DebruijnIndex::INNERMOST` whose index is the variable's position in
/// `variables`; the binder records each variable's kind at that position.
///
/// # Panics
///
/// Panics if any of `variables` reports `is_free() == false`.
pub fn new(variables: impl Upcast<Vec<CoreVariable<L>>>, term: T) -> Self {
    let free_vars: Vec<CoreVariable<L>> = variables.upcast();

    // Each item is `(kind, (free variable, bound replacement))`.
    let bindings = free_vars.iter().enumerate().map(|(index, var)| {
        let free_var: CoreVariable<L> = var.upcast();
        assert!(free_var.is_free());
        let kind = free_var.kind();
        let bound: CoreParameter<L> = CoreBoundVar {
            debruijn: Some(DebruijnIndex::INNERMOST),
            var_index: VarIndex { index },
            kind,
        }
        .upcast();
        (kind, (free_var, bound))
    });
    let (kinds, closing): (Vec<CoreKind<L>>, CoreSubstitution<L>) = bindings.unzip();

    CoreBinder {
        kinds,
        term: closing.apply(&term),
    }
}
pub fn mentioned(variables: impl Upcast<Vec<CoreVariable<L>>>, term: T) -> Self {
let mut variables: Vec<CoreVariable<L>> = variables.upcast();
let fv = term.free_variables();
variables.retain(|v| fv.contains(v));
let variables: Vec<CoreVariable<L>> = variables.into_iter().collect();
CoreBinder::new(variables, term)
}
/// Converts the bound term to `U` via `Into`, keeping the same bound
/// variable kinds.
pub fn into<U>(self) -> CoreBinder<L, U>
where
    T: Into<U>,
{
    let CoreBinder { kinds, term } = self;
    let term: U = term.into();
    CoreBinder { kinds, term }
}
/// Number of variables bound by this binder.
pub fn len(&self) -> usize {
    self.kinds().len()
}
/// True when this binder binds no variables.
pub fn is_empty(&self) -> bool {
    self.kinds().is_empty()
}
/// Instantiates the binder with explicit `parameters`, one per bound
/// variable, and returns the resulting term.
///
/// # Errors
///
/// Fails if `parameters` does not have exactly one entry per bound
/// variable, or if any parameter's kind differs from the kind recorded
/// for the bound variable at the same position.
pub fn instantiate_with(&self, parameters: &[impl Upcast<CoreParameter<L>>]) -> Fallible<T> {
    if self.kinds.len() != parameters.len() {
        bail!("wrong number of parameters");
    }

    for (i, (p, k)) in parameters.iter().zip(&self.kinds).enumerate() {
        let p: CoreParameter<L> = p.upcast();
        let actual = p.kind();
        if actual != *k {
            bail!(
                "parameter {i} has kind {:?} but should have kind {:?}",
                actual,
                k
            );
        }
    }

    // Lengths were checked above, so every bound index has a parameter.
    let instantiated = self.instantiate(|_kind, index| parameters[index.index].to());
    Ok(instantiated)
}
/// Instantiates the binder, asking `op` for the replacement of each
/// bound variable.
///
/// `op` is called exactly once per bound variable, in binder order, with
/// that variable's kind and index. Every innermost bound variable in the
/// term is then replaced by the parameter produced for its index; all
/// other variables are left for the term's `substitute` to handle as
/// unmapped.
///
/// # Panics
///
/// Indexing panics if the term contains an innermost bound variable whose
/// index is not smaller than the number of bound kinds.
pub fn instantiate(&self, mut op: impl FnMut(CoreKind<L>, VarIndex) -> CoreParameter<L>) -> T {
    let mut replacements: Vec<CoreParameter<L>> = Vec::with_capacity(self.kinds.len());
    for (index, &kind) in self.kinds.iter().enumerate() {
        replacements.push(op(kind, VarIndex { index }));
    }

    self.term.substitute(&mut |var| {
        if let CoreVariable::BoundVar(CoreBoundVar {
            debruijn: Some(DebruijnIndex::INNERMOST),
            var_index,
            kind: _,
        }) = var
        {
            Some(replacements[var_index.index].clone())
        } else {
            None
        }
    })
}
pub fn peek(&self) -> &T {
&self.term
}
/// Kinds of the bound variables, in binder order.
pub fn kinds(&self) -> &[CoreKind<L>] {
    self.kinds.as_slice()
}
/// Applies `op` to the bound term and rebinds the result.
///
/// The binder is opened with fresh variables, `op` transforms the opened
/// term, and those same fresh variables are then bound again around the
/// value `op` returned.
pub fn map<U: CoreFold<L>>(&self, op: impl FnOnce(T) -> U) -> CoreBinder<L, U> {
    let (opened_vars, opened_term) = self.open();
    CoreBinder::new(opened_vars, op(opened_term))
}
}
impl<L: Language> CoreBoundVar<L> {
pub fn fresh(kind: CoreKind<L>) -> Self {
lazy_static! {
static ref COUNTER: AtomicUsize = AtomicUsize::new(0);
}
let index = COUNTER.fetch_add(1, Ordering::SeqCst);
let var_index = VarIndex { index };
CoreBoundVar {
debruijn: None,
var_index,
kind,
}
}
}
/// Visiting a binder simply forwards to the bound term; the list of kinds
/// contributes nothing to any of these queries.
impl<L: Language, T: CoreVisit<L>> CoreVisit<L> for CoreBinder<L, T> {
    /// Free variables reported by the bound term.
    fn free_variables(&self) -> Vec<CoreVariable<L>> {
        <T as CoreVisit<L>>::free_variables(&self.term)
    }

    /// Size reported by the bound term.
    fn size(&self) -> usize {
        <T as CoreVisit<L>>::size(&self.term)
    }

    /// Runs the bound term's validity assertions.
    fn assert_valid(&self) {
        <T as CoreVisit<L>>::assert_valid(&self.term);
    }
}
impl<L: Language, T: CoreFold<L>> CoreFold<L> for CoreBinder<L, T> {
    /// Substitutes inside the bound term, adjusting for this binder.
    ///
    /// Each variable seen inside the term is first shifted out; if that
    /// yields nothing, the variable is left alone. Otherwise the shifted
    /// variable is passed to `substitution_fn`, and any replacement it
    /// returns is shifted back in before being placed in the term.
    fn substitute(&self, substitution_fn: SubstitutionFn<'_, L>) -> Self {
        let term = self.term.substitute(&mut |inner_var| {
            inner_var
                .shift_out()
                .and_then(|outer_var| substitution_fn(outer_var))
                .map(|replacement| replacement.shift_in())
        });
        let kinds = self.kinds.clone();
        Self { kinds, term }
    }

    /// Shifts the bound term in, keeping the same bound variable kinds.
    fn shift_in(&self) -> Self {
        let term = self.term.shift_in();
        let kinds = self.kinds.clone();
        Self { kinds, term }
    }
}
/// A binder over `T` upcasts to a binder over `U` whenever `T` upcasts to
/// `U`; the bound variable kinds are carried over unchanged.
impl<L: Language, T, U> UpcastFrom<CoreBinder<L, T>> for CoreBinder<L, U>
where
    T: Clone,
    U: Clone,
    T: Upcast<U>,
{
    fn upcast_from(binder: CoreBinder<L, T>) -> Self {
        Self {
            kinds: binder.kinds,
            term: binder.term.upcast(),
        }
    }
}
/// A binder over `U` downcasts to a binder over `T` when its bound term
/// downcasts to `T`; the kinds are cloned into the result.
impl<L: Language, T, U> DowncastTo<CoreBinder<L, T>> for CoreBinder<L, U>
where
    T: DowncastFrom<U>,
{
    /// Returns `None` if the bound term cannot be downcast.
    fn downcast_to(&self) -> Option<CoreBinder<L, T>> {
        let CoreBinder { kinds, term } = self;
        term.downcast().map(|term| CoreBinder {
            kinds: kinds.clone(),
            term,
        })
    }
}
/// Formats as `<K0, K1, ...> term`: the bound kinds, comma-separated inside
/// angle brackets, then a space and the bound term, all using `{:?}`.
impl<L: Language, T> std::fmt::Debug for CoreBinder<L, T>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("<")?;

        // Print the first kind bare and prefix every later one with ", ".
        let mut remaining = self.kinds.iter();
        if let Some(first) = remaining.next() {
            write!(f, "{:?}", first)?;
            for kind in remaining {
                write!(f, ", {:?}", kind)?;
            }
        }

        write!(f, "> {:?}", &self.term)
    }
}