use std::collections::HashMap;
use bumpalo::Bump;
/// Groups a list of declarations with a list of clauses.
///
/// Both lists are borrowed slices of borrowed nodes that live for `'a`.
#[derive(Debug)]
pub struct Unit<'a> {
/// Declarations held by this unit.
pub decls: &'a [&'a Decl<'a>],
/// Clauses held by this unit.
pub clauses: &'a [&'a Clause<'a>],
}
/// A declaration: an atom plus descriptor atoms and optional bounds and
/// constraints.
///
/// NOTE(review): the precise meaning of `descr`, `bounds` and `constraints` is
/// not visible in this file; the field comments below describe shapes only.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Decl<'a> {
/// The atom this declaration is about (presumably the declared predicate
/// applied to its arguments — confirm against the parser).
pub atom: &'a Atom<'a>,
/// Descriptor atoms attached to the declaration.
pub descr: &'a [&'a Atom<'a>],
/// Optional list of bound declarations; `None` when absent.
pub bounds: Option<&'a [&'a BoundDecl<'a>]>,
/// Optional constraints; `None` when absent.
pub constraints: Option<&'a Constraints<'a>>,
}
/// A bound declaration, represented as a list of base terms.
///
/// NOTE(review): presumably one base term per argument of the declared atom —
/// confirm against the code that builds these.
#[derive(Debug, PartialEq)]
pub struct BoundDecl<'a> {
/// The base terms making up this bound declaration.
pub base_terms: &'a [&'a BaseTerm<'a>],
}
/// Constraints attached to a declaration: a flat list of atoms plus a list of
/// atom lists.
///
/// NOTE(review): how `consequences` and `alternatives` are interpreted is not
/// shown in this file — confirm against the consumer.
#[derive(Debug, Clone, PartialEq)]
pub struct Constraints<'a> {
/// Flat list of atoms.
pub consequences: &'a [&'a Atom<'a>],
/// List of atom lists (each inner slice is one group of atoms).
pub alternatives: &'a [&'a [&'a Atom<'a>]],
}
/// A clause: a head atom, a list of premise terms and a list of transform
/// statements.
///
/// NOTE(review): presumably a rule whose head is derived from the premises,
/// with the transform applied afterwards — confirm against the evaluator.
#[derive(Debug)]
pub struct Clause<'a> {
/// The head atom of the clause.
pub head: &'a Atom<'a>,
/// The premise terms of the clause; may be empty.
pub premises: &'a [&'a Term<'a>],
/// The transform statements of the clause; may be empty.
pub transform: &'a [&'a TransformStmt<'a>],
}
/// One statement of a clause's transform: an optional variable name paired
/// with a base term.
#[derive(Debug)]
pub struct TransformStmt<'a> {
/// Optional variable name (presumably the variable that receives the value
/// of `app` — confirm); `None` when the statement names no variable.
pub var: Option<&'a str>,
/// The base term of the statement (presumably a function application, going
/// by the field name — the type permits any `BaseTerm`).
pub app: &'a BaseTerm<'a>,
}
/// A term that can appear among the premises of a [`Clause`].
///
/// Every variant only holds references, so the type is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Term<'a> {
/// An atom.
Atom(&'a Atom<'a>),
/// An atom tagged as negated.
NegAtom(&'a Atom<'a>),
/// A pair of base terms tagged `Eq` (presumably an equality between the
/// left and right term — confirm against the evaluator).
Eq(&'a BaseTerm<'a>, &'a BaseTerm<'a>),
/// A pair of base terms tagged `Ineq` (presumably an inequality between the
/// left and right term — confirm against the evaluator).
Ineq(&'a BaseTerm<'a>, &'a BaseTerm<'a>),
}
impl<'a> Term<'a> {
    /// Builds a copy of this term inside `bump` with `subst` applied.
    ///
    /// The variant is preserved; the substitution is delegated to
    /// `Atom::apply_subst` for the atom variants and to
    /// `BaseTerm::apply_subst` (left operand first, then right) for `Eq` and
    /// `Ineq`.
    ///
    /// * `bump` - arena that receives the new term and everything it points to.
    /// * `subst` - maps variable names to the base terms that replace them.
    ///
    /// Returns a reference to the newly allocated term, valid for the arena
    /// lifetime `'b`.
    pub fn apply_subst<'b>(
        &'a self,
        bump: &'b Bump,
        subst: &HashMap<&'a str, &'a BaseTerm<'a>>,
    ) -> &'b Term<'b> {
        // `Term` is `Copy`, so matching on the value hands out the inner
        // `&'a` references directly.
        let rewritten = match *self {
            Term::Atom(atom) => Term::Atom(atom.apply_subst(bump, subst)),
            Term::NegAtom(atom) => Term::NegAtom(atom.apply_subst(bump, subst)),
            Term::Eq(lhs, rhs) => {
                let lhs = lhs.apply_subst(bump, subst);
                let rhs = rhs.apply_subst(bump, subst);
                Term::Eq(lhs, rhs)
            }
            Term::Ineq(lhs, rhs) => {
                let lhs = lhs.apply_subst(bump, subst);
                let rhs = rhs.apply_subst(bump, subst);
                Term::Ineq(lhs, rhs)
            }
        };
        bump.alloc(rewritten)
    }
}
/// A base term: a constant, a variable, or a function applied to base terms.
///
/// Every variant is `Copy`; nested terms are held by reference.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BaseTerm<'a> {
/// A constant value.
Const(Const<'a>),
/// A variable, identified by its name.
Variable(&'a str),
/// A function symbol together with its argument terms.
ApplyFn(FunctionSym<'a>, &'a [&'a BaseTerm<'a>]),
}
impl<'a> BaseTerm<'a> {
pub fn apply_subst<'b>(
&'a self,
bump: &'b Bump,
subst: &HashMap<&'a str, &'a BaseTerm<'a>>,
) -> &'b BaseTerm<'b> {
match self {
BaseTerm::Const(_) => copy_base_term(bump, self),
BaseTerm::Variable(v) => subst
.get(v)
.map_or(copy_base_term(bump, self), |b| copy_base_term(bump, b)),
BaseTerm::ApplyFn(fun, args) => {
let args: Vec<&'b BaseTerm<'b>> = args
.iter()
.map(|arg| arg.apply_subst(bump, subst))
.collect();
copy_base_term(bump, &BaseTerm::ApplyFn(*fun, &args))
}
}
}
}
/// Textual rendering of a base term.
///
/// Constants and variables print as their own `Display` / name; a function
/// application prints as `name(arg1,arg2,...)` with a bare comma between
/// arguments.
impl<'a> std::fmt::Display for BaseTerm<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BaseTerm::Const(value) => write!(f, "{value}"),
            BaseTerm::Variable(name) => write!(f, "{name}"),
            BaseTerm::ApplyFn(fun, args) => {
                write!(f, "{}(", fun.name)?;
                for (idx, arg) in args.iter().enumerate() {
                    // Separator goes between arguments only.
                    if idx > 0 {
                        f.write_str(",")?;
                    }
                    write!(f, "{arg}")?;
                }
                f.write_str(")")
            }
        }
    }
}
/// A constant value.
///
/// All payloads are either plain scalars or borrowed data, so the type is
/// `Copy`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Const<'a> {
    /// A name constant, stored as its text.
    Name(&'a str),
    /// A boolean.
    Bool(bool),
    /// A signed 64-bit integer.
    Number(i64),
    /// A 64-bit float.
    Float(f64),
    /// A string.
    String(&'a str),
    /// A byte string.
    Bytes(&'a [u8]),
    /// A list of constants.
    List(&'a [&'a Const<'a>]),
    /// A map, stored as separate key and value slices (presumably parallel:
    /// `keys[i]` maps to `values[i]` — confirm against the producer).
    Map {
        keys: &'a [&'a Const<'a>],
        values: &'a [&'a Const<'a>],
    },
    /// A struct, stored as separate field-name and value slices (presumably
    /// parallel: `fields[i]` names `values[i]` — confirm against the producer).
    Struct {
        fields: &'a [&'a str],
        values: &'a [&'a Const<'a>],
    },
}

// `Eq` is asserted by hand because `f64` does not implement it. The derived
// `PartialEq` still compares `Float` payloads with `f64`'s `==`, so a NaN
// float constant is not equal to itself.
impl<'a> Eq for Const<'a> {}

/// Textual rendering of a constant.
///
/// Scalars, names and strings print via their own `Display` (strings are not
/// quoted), bytes print via `Debug`, lists print as `[a, b, c]`, and maps and
/// structs both print as the literal placeholder `{...}`.
impl<'a> std::fmt::Display for Const<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Const::Name(text) | Const::String(text) => write!(f, "{text}"),
            Const::Bool(flag) => write!(f, "{flag}"),
            Const::Number(number) => write!(f, "{number}"),
            Const::Float(number) => write!(f, "{number}"),
            Const::Bytes(bytes) => write!(f, "{bytes:?}"),
            Const::List(items) => {
                f.write_str("[")?;
                for (idx, item) in items.iter().enumerate() {
                    // Separator goes between elements only.
                    if idx > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{item}")?;
                }
                f.write_str("]")
            }
            // Contents of maps and structs are intentionally elided.
            Const::Map { .. } | Const::Struct { .. } => f.write_str("{...}"),
        }
    }
}
/// A predicate symbol: a name with an optional arity.
///
/// Two symbols are equal (and hash alike) only when both name and arity match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PredicateSym<'a> {
/// The predicate's name.
pub name: &'a str,
/// The number of arguments, when recorded (what `None` stands for is not
/// shown in this file — confirm against the producer).
pub arity: Option<u8>,
}
/// A function symbol: a name with an optional arity.
///
/// Two symbols are equal (and hash alike) only when both name and arity match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FunctionSym<'a> {
/// The function's name.
pub name: &'a str,
/// The number of arguments, when recorded (what `None` stands for is not
/// shown in this file — confirm against the producer).
pub arity: Option<u8>,
}
/// An atom: a predicate symbol applied to a list of base terms.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Atom<'a> {
/// The predicate symbol.
pub sym: PredicateSym<'a>,
/// The argument terms, in order.
pub args: &'a [&'a BaseTerm<'a>],
}
impl<'a> Atom<'a> {
    /// Checks this atom's arguments against the constant positions of a query.
    ///
    /// Arguments are paired positionally. Wherever `query_args` holds a
    /// `BaseTerm::Const`, this atom's argument at the same position must be
    /// equal to it; positions where the query holds a variable or a function
    /// application are not compared at all.
    ///
    /// Pairing stops at the shorter of the two argument lists, so surplus
    /// arguments on either side are ignored, and the predicate symbol is not
    /// consulted.
    ///
    /// Returns `true` when no constant position disagrees.
    pub fn matches(&'a self, query_args: &[&BaseTerm]) -> bool {
        self.args
            .iter()
            .zip(query_args)
            .all(|(fact_arg, query_arg)| match query_arg {
                BaseTerm::Const(_) => fact_arg == query_arg,
                BaseTerm::Variable(_) | BaseTerm::ApplyFn(..) => true,
            })
    }

    /// Builds a copy of this atom inside `bump` with `subst` applied to every
    /// argument.
    ///
    /// * `bump` - arena that receives the new atom and everything it points to.
    /// * `subst` - maps variable names to the base terms that replace them.
    ///
    /// Returns a reference to the newly allocated atom, valid for the arena
    /// lifetime `'b`. The predicate symbol is copied into the arena as well.
    pub fn apply_subst<'b>(
        &'a self,
        bump: &'b Bump,
        subst: &HashMap<&'a str, &'a BaseTerm<'a>>,
    ) -> &'b Atom<'b> {
        let mut rewritten = Vec::with_capacity(self.args.len());
        for arg in self.args {
            rewritten.push(arg.apply_subst(bump, subst));
        }
        let args: &'b [&'b BaseTerm<'b>] = bump.alloc_slice_copy(&rewritten);
        let sym = copy_predicate_sym(bump, self.sym);
        bump.alloc(Atom { sym, args })
    }
}
/// Textual rendering of an atom as `name(arg1,arg2,...)`.
///
/// Arguments are separated by a bare comma, the same separator that
/// `BaseTerm::ApplyFn` uses for its arguments. An atom with no arguments
/// prints as `name()`.
impl<'a> std::fmt::Display for Atom<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}(", self.sym.name)?;
        for (idx, arg) in self.args.iter().enumerate() {
            // Without a separator, adjacent arguments run together
            // (`p(XY)` instead of `p(X,Y)`).
            if idx > 0 {
                f.write_str(",")?;
            }
            write!(f, "{arg}")?;
        }
        f.write_str(")")
    }
}
/// Copies a predicate symbol so that its name lives in `bump`.
///
/// * `bump` - arena that receives a copy of the name.
/// * `p` - the symbol to copy.
///
/// Returns a symbol with the same arity whose name borrows from `bump`.
pub fn copy_predicate_sym<'dest>(bump: &'dest Bump, p: PredicateSym) -> PredicateSym<'dest> {
    let PredicateSym { name, arity } = p;
    let name: &'dest str = bump.alloc_str(name);
    PredicateSym { name, arity }
}
pub fn copy_atom<'dest, 'src>(bump: &'dest Bump, atom: &'src Atom<'src>) -> &'dest Atom<'dest> {
let args: Vec<_> = atom
.args
.iter()
.map(|arg| copy_base_term(bump, arg))
.collect();
let args = &*bump.alloc_slice_copy(&args);
bump.alloc(Atom {
sym: copy_predicate_sym(bump, atom.sym),
args,
})
}
/// Deep-copies a base term into `bump`.
///
/// Constants go through `copy_const`, variable and function names are copied
/// as strings, and function arguments are copied recursively, so the result
/// borrows nothing from `b`.
///
/// * `bump` - arena that receives the copy.
/// * `b` - the base term to copy.
///
/// Returns a reference to the newly allocated term.
pub fn copy_base_term<'dest, 'src>(
    bump: &'dest Bump,
    b: &'src BaseTerm<'src>,
) -> &'dest BaseTerm<'dest> {
    let copied = match *b {
        BaseTerm::Const(ref value) => BaseTerm::Const(*copy_const(bump, value)),
        BaseTerm::Variable(name) => BaseTerm::Variable(bump.alloc_str(name)),
        BaseTerm::ApplyFn(FunctionSym { name, arity }, args) => {
            let name = &*bump.alloc_str(name);
            let mut copied_args = Vec::with_capacity(args.len());
            for arg in args {
                copied_args.push(copy_base_term(bump, arg));
            }
            let args = &*bump.alloc_slice_copy(&copied_args);
            BaseTerm::ApplyFn(FunctionSym { name, arity }, args)
        }
    };
    bump.alloc(copied)
}
/// Deep-copies every constant of `items` into `bump` and returns the copies
/// as an arena-allocated slice, preserving order.
fn copy_const_slice<'dest>(
    bump: &'dest Bump,
    items: &[&Const<'_>],
) -> &'dest [&'dest Const<'dest>] {
    let copied: Vec<_> = items.iter().map(|item| copy_const(bump, item)).collect();
    bump.alloc_slice_copy(&copied)
}

/// Deep-copies a constant into `bump`.
///
/// Scalars are copied by value; names, strings and bytes are copied into the
/// arena; lists, maps and structs are copied element by element, so the
/// result borrows nothing from `c`.
///
/// * `bump` - arena that receives the copy.
/// * `c` - the constant to copy.
///
/// Returns a reference to the newly allocated constant.
pub fn copy_const<'dest, 'src>(bump: &'dest Bump, c: &'src Const<'src>) -> &'dest Const<'dest> {
    let copied = match *c {
        Const::Name(name) => Const::Name(bump.alloc_str(name)),
        Const::Bool(flag) => Const::Bool(flag),
        Const::Number(number) => Const::Number(number),
        Const::Float(number) => Const::Float(number),
        Const::String(text) => Const::String(bump.alloc_str(text)),
        Const::Bytes(bytes) => Const::Bytes(bump.alloc_slice_copy(bytes)),
        Const::List(items) => Const::List(copy_const_slice(bump, items)),
        Const::Map { keys, values } => {
            let keys = copy_const_slice(bump, keys);
            let values = copy_const_slice(bump, values);
            Const::Map { keys, values }
        }
        Const::Struct { fields, values } => {
            let names: Vec<&'dest str> = fields
                .iter()
                .map(|field| &*bump.alloc_str(field))
                .collect();
            let fields = &*bump.alloc_slice_copy(&names);
            let values = copy_const_slice(bump, values);
            Const::Struct { fields, values }
        }
    };
    bump.alloc(copied)
}
/// Deep-copies a transform statement into `bump`.
///
/// The optional variable name is copied as a string and the base term with
/// `copy_base_term`.
///
/// * `bump` - arena that receives the copy.
/// * `stmt` - the statement to copy.
///
/// Returns a reference to the newly allocated statement.
pub fn copy_transform<'dest, 'src>(
    bump: &'dest Bump,
    stmt: &'src TransformStmt<'src>,
) -> &'dest TransformStmt<'dest> {
    let var: Option<&'dest str> = stmt.var.map(|name| &*bump.alloc_str(name));
    let app = copy_base_term(bump, stmt.app);
    bump.alloc(TransformStmt { var, app })
}
/// Deep-copies a clause into `bump`.
///
/// Premises are copied with `copy_term`, transform statements with
/// `copy_transform` and the head with `copy_atom`, so the result borrows
/// nothing from `clause`.
///
/// * `bump` - arena that receives the copy.
/// * `clause` - the clause to copy.
///
/// Returns a reference to the newly allocated clause.
pub fn copy_clause<'dest, 'src>(
    bump: &'dest Bump,
    clause: &'src Clause<'src>,
) -> &'dest Clause<'dest> {
    let mut copied_premises = Vec::with_capacity(clause.premises.len());
    for premise in clause.premises {
        copied_premises.push(copy_term(bump, premise));
    }

    let mut copied_transform = Vec::with_capacity(clause.transform.len());
    for stmt in clause.transform {
        copied_transform.push(copy_transform(bump, stmt));
    }

    let head = copy_atom(bump, clause.head);
    let premises = &*bump.alloc_slice_copy(&copied_premises);
    let transform = &*bump.alloc_slice_copy(&copied_transform);
    bump.alloc(Clause {
        head,
        premises,
        transform,
    })
}
/// Deep-copies a term into `bump`.
///
/// The variant is preserved; atoms are copied with `copy_atom` and base terms
/// with `copy_base_term` (left operand first, then right).
///
/// * `bump` - arena that receives the copy.
/// * `term` - the term to copy.
///
/// Returns a reference to the newly allocated term.
fn copy_term<'dest, 'src>(bump: &'dest Bump, term: &'src Term<'src>) -> &'dest Term<'dest> {
    let copied = match *term {
        Term::Atom(atom) => Term::Atom(copy_atom(bump, atom)),
        Term::NegAtom(atom) => Term::NegAtom(copy_atom(bump, atom)),
        Term::Eq(lhs, rhs) => Term::Eq(copy_base_term(bump, lhs), copy_base_term(bump, rhs)),
        Term::Ineq(lhs, rhs) => {
            Term::Ineq(copy_base_term(bump, lhs), copy_base_term(bump, rhs))
        }
    };
    bump.alloc(copied)
}
#[cfg(test)]
mod tests {
    use super::*;
    use bumpalo::Bump;

    /// An arena-allocated atom with one name-constant argument renders as
    /// `bar(/foo)`.
    #[test]
    fn it_works() {
        let arena = Bump::new();

        let foo: &BaseTerm = arena.alloc(BaseTerm::Const(Const::Name("/foo")));
        let sym: &PredicateSym = arena.alloc(PredicateSym {
            name: "bar",
            arity: Some(1),
        });
        let args: &[&BaseTerm] = arena.alloc_slice_copy(&[foo]);
        let head: &Atom = arena.alloc(Atom { sym: *sym, args });

        assert_eq!("bar(/foo)", head.to_string());
    }
}