use crate::ast::{self, Sym, Var};
/// Handle for a term stored in a [`TermArena`].
///
/// The wrapped value is the term's index in the arena that allocated it, so an
/// id is only meaningful relative to that arena.
#[repr(transparent)]
#[derive(Debug, Hash, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub struct TermId(usize);
/// Position of one slot in a [`TermArena`]'s argument store.
///
/// Values of this type are produced by iterating an [`ArgRange`] and are
/// resolved to the argument's [`TermId`] with `TermArena::get_arg`.
#[repr(transparent)]
#[derive(Debug, Hash, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub struct ArgId(usize);
/// Flat, index-addressed storage for terms.
///
/// Terms refer to each other through [`TermId`] indices instead of pointers.
/// The arguments of every application term occupy one contiguous run of the
/// shared `args` vector.
#[derive(Clone, Debug)]
pub struct TermArena {
    /// All terms, addressed by `TermId`.
    terms: Vec<Term>,
    /// Argument slots of application terms, addressed by `ArgId`; each slot
    /// holds the id of the argument term.
    args: Vec<TermId>,
}
impl TermArena {
    /// Creates an arena containing no terms and no argument slots.
    pub fn new() -> Self {
        Self {
            terms: Vec::new(),
            args: Vec::new(),
        }
    }

    /// Stores `term` and returns its id.
    ///
    /// The id is the index the term lands at. Every constructor below goes
    /// through this helper so id allocation is defined in exactly one place.
    #[inline]
    fn push_term(&mut self, term: Term) -> TermId {
        let id = TermId(self.terms.len());
        self.terms.push(term);
        id
    }

    /// Allocates a variable term and returns its id.
    pub fn var(&mut self, var: Var) -> TermId {
        self.push_term(Term::Var(var))
    }

    /// Allocates the application term `functor(args...)` and returns its id.
    ///
    /// The argument ids are copied to the end of the shared argument store.
    /// The new term records the contiguous range they occupy there.
    pub fn app(&mut self, functor: Sym, args: &[TermId]) -> TermId {
        let start = self.args.len();
        self.args.extend_from_slice(args);
        let end = self.args.len();
        self.push_term(Term::App(AppTerm(functor, ArgRange { start, end })))
    }

    /// Allocates an integer term and returns its id.
    pub fn int(&mut self, int: i64) -> TermId {
        self.push_term(Term::Int(int))
    }

    /// Allocates a `Cut` term and returns its id.
    pub fn cut(&mut self) -> TermId {
        self.push_term(Term::Cut)
    }

    /// Appends a copy of every term and argument slot of `blueprint` to this arena.
    ///
    /// The copy is rebased onto the current end of `self`:
    /// - argument ranges are shifted by the previous length of the argument store;
    /// - argument term ids are shifted by the previous number of terms;
    /// - every variable is shifted with `Var::offset(var_offset)`.
    ///
    /// Returns a mapping from a `TermId` of `blueprint` to the id of its copy in `self`.
    /// The mapping does not borrow the arena.
    pub fn instantiate_blueprint(
        &mut self,
        blueprint: &TermArena,
        var_offset: usize,
    ) -> impl Fn(TermId) -> TermId {
        let base = self.checkpoint();
        let (term_base, arg_base) = (base.terms, base.args);
        self.terms.extend(blueprint.terms.iter().map(|term| match term {
            Term::Var(var) => Term::Var(var.offset(var_offset)),
            Term::App(AppTerm(func, args)) => Term::App(AppTerm(
                *func,
                ArgRange {
                    start: args.start + arg_base,
                    end: args.end + arg_base,
                },
            )),
            Term::Int(i) => Term::Int(*i),
            Term::Cut => Term::Cut,
        }));
        self.args.extend(
            blueprint
                .args
                .iter()
                .map(|&TermId(id)| TermId(id + term_base)),
        );
        move |TermId(old)| TermId(old + term_base)
    }

    /// Recursively copies an AST term into the arena and returns the id of its root.
    ///
    /// `scratch` is a caller-supplied buffer used to collect argument ids while
    /// building application terms. Reusing it avoids a fresh allocation per
    /// application. Its length is the same on return as on entry.
    pub fn insert_ast_term(&mut self, scratch: &mut Vec<TermId>, term: &ast::Term) -> TermId {
        match term {
            ast::Term::Var(v) => self.var(*v),
            ast::Term::App(app) => self.insert_ast_appterm(scratch, app),
            ast::Term::Int(i) => self.int(*i),
            ast::Term::Cut => self.cut(),
        }
    }

    /// Recursively copies an AST application term into the arena and returns its id.
    ///
    /// Argument ids are pushed onto the end of `scratch` as each argument is
    /// inserted. Nested applications use the space above this frame's
    /// arguments and truncate back, so they never disturb them. This frame's
    /// arguments are then truncated away too, leaving `scratch` at its entry length.
    pub fn insert_ast_appterm(&mut self, scratch: &mut Vec<TermId>, app: &ast::AppTerm) -> TermId {
        let args_start = scratch.len();
        for arg in &app.args {
            let arg_term = self.insert_ast_term(scratch, arg);
            scratch.push(arg_term);
        }
        let out = self.app(app.functor, &scratch[args_start..]);
        scratch.truncate(args_start);
        out
    }

    /// Returns the term id stored in argument slot `arg_id`.
    ///
    /// # Panics
    /// Panics if `arg_id` is out of bounds for this arena.
    #[inline(always)]
    pub fn get_arg(&self, arg_id: ArgId) -> TermId {
        self.args[arg_id.0]
    }

    /// Iterates, in order, over the term ids stored in `arg_range`.
    ///
    /// # Panics
    /// Panics if `arg_range` is out of bounds for this arena.
    #[inline(always)]
    pub fn get_args(&self, arg_range: ArgRange) -> impl Iterator<Item = TermId> + '_ {
        self.args[arg_range.start..arg_range.end].iter().copied()
    }

    /// Returns the arguments in `args` as an array when there are exactly `N`
    /// of them, and `None` otherwise.
    ///
    /// # Panics
    /// Panics if the length matches but `args` is out of bounds for this arena.
    #[inline(always)]
    pub fn get_args_fixed<const N: usize>(&self, args: ArgRange) -> Option<[TermId; N]> {
        // Check the arity before touching the store, so a mismatched range
        // yields `None` without indexing.
        if args.len() != N {
            return None;
        }
        let slice = &self.args[args.start..args.end];
        Some(std::array::from_fn(|i| slice[i]))
    }

    /// Returns a copy of the term with id `term_id`.
    ///
    /// # Panics
    /// Panics if `term_id` is out of bounds for this arena.
    #[inline(always)]
    pub fn get_term(&self, term_id: TermId) -> Term {
        self.terms[term_id.0]
    }

    /// Records the current sizes of the term and argument stores.
    ///
    /// Pass the result to [`TermArena::release`] to roll back later allocations.
    pub fn checkpoint(&self) -> Checkpoint {
        Checkpoint {
            terms: self.terms.len(),
            args: self.args.len(),
        }
    }

    /// Truncates both stores back to the sizes recorded in `checkpoint`.
    ///
    /// Afterwards, ids allocated after the checkpoint no longer refer to
    /// stored entries. In debug builds, asserts that the checkpoint does not
    /// exceed the current sizes.
    pub fn release(&mut self, checkpoint: &Checkpoint) {
        debug_assert!(checkpoint.args <= self.args.len() && checkpoint.terms <= self.terms.len());
        self.args.truncate(checkpoint.args);
        self.terms.truncate(checkpoint.terms);
    }
}
impl Default for TermArena {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of a [`TermArena`]'s store sizes.
///
/// Produced by `TermArena::checkpoint` and consumed by `TermArena::release`.
/// The derived ordering compares `terms` first, then `args`.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct Checkpoint {
    /// Number of terms at the time of the snapshot.
    terms: usize,
    /// Number of argument slots at the time of the snapshot.
    args: usize,
}
/// Half-open range `start..end` of slots in a [`TermArena`]'s argument store.
///
/// Each application term uses one of these to locate its arguments.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ArgRange {
    /// First slot covered (inclusive).
    start: usize,
    /// One past the last slot covered (exclusive).
    end: usize,
}
/// Yields the [`ArgId`] of every slot in the range, front to back.
///
/// Iteration consumes the range by advancing `start`.
impl Iterator for ArgRange {
    type Item = ArgId;

    fn next(&mut self) -> Option<Self::Item> {
        let start = self.start;
        if start == self.end {
            None
        } else {
            self.start += 1;
            Some(ArgId(start))
        }
    }

    /// Short-circuiting `any`, specialised to a plain integer range.
    ///
    /// This honours the `Iterator::any` contract. After a match, `self`
    /// resumes just past the matching element. After a `false` result, `self`
    /// is exhausted. The previous version scanned a temporary range and left
    /// `self` unchanged, so later `next()` calls re-yielded consumed ids.
    #[inline]
    fn any<F>(&mut self, mut f: F) -> bool
    where
        Self: Sized,
        F: FnMut(Self::Item) -> bool,
    {
        let mut remaining = self.start..self.end;
        let found = remaining.any(|index| f(ArgId(index)));
        // Write the advanced position back so the consumption is visible to the caller.
        self.start = remaining.start;
        found
    }

    /// The bounds are exact: both equal the number of remaining slots.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.len();
        (len, Some(len))
    }
}
impl ArgRange {
    /// Number of slots still covered (shrinks as the range is iterated).
    #[inline]
    pub fn len(&self) -> usize {
        let Self { start, end } = *self;
        end - start
    }

    /// Whether the range covers no slots at all.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.end == self.start
    }
}
/// One node stored in a [`TermArena`].
///
/// The type is `Copy`, which lets `TermArena::get_term` return terms by value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Term {
    /// A variable.
    Var(Var),
    /// A functor applied to the arguments located by an [`ArgRange`].
    App(AppTerm),
    /// An integer value.
    Int(i64),
    /// The payload-free `Cut` term.
    Cut,
}
/// Application term.
///
/// Field `0` is the functor symbol. Field `1` is the range of argument slots
/// holding its arguments in the owning [`TermArena`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct AppTerm(pub Sym, pub ArgRange);