#![forbid(unsafe_code)]
#![deny(missing_docs)]
extern crate proc_macro;
mod parse;
mod strata;
use proc_macro::TokenStream;
use proc_macro_error::{abort, proc_macro_error};
use quote::{format_ident, quote, quote_spanned};
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Display, Formatter};
use syn::{parse_macro_input, Expr, Ident, Type};
use parse::{Clause, Fact, Program, Relation, Rule};
use strata::Strata;
/// Entry point of the `crepe!` procedural macro.
///
/// The macro input is parsed as a `Program`, validated and stratified by `Context::new`,
/// and then expanded into three pieces of generated code, emitted in this order:
///
/// 1. one struct declaration per declared relation,
/// 2. the `Crepe` runtime struct, holding one `Vec` per `@input` relation,
/// 3. the runtime's inherent `impl` (`new`, `run`) followed by its `Extend` implementations.
///
/// Parse failures are turned into compile errors by `parse_macro_input!`; semantic errors
/// found while building the context are reported through `abort!`, which is why the
/// function carries the `#[proc_macro_error]` attribute.
#[proc_macro]
#[proc_macro_error]
pub fn crepe(input: TokenStream) -> TokenStream {
    let program = parse_macro_input!(input as Program);
    let context = Context::new(program);

    let struct_decls = make_struct_decls(&context);
    let runtime_decl = make_runtime_decl(&context);
    let runtime_impl = make_runtime_impl(&context);

    let expanded = quote! {
        #struct_decls
        #runtime_decl
        #runtime_impl
    };
    expanded.into()
}
/// A deferred code-generation step: a function that receives the token stream of an inner
/// body and returns a token stream in which that body has been wrapped by some outer
/// construct (a loop, an `if`, ...). Used boxed, as `Box<QuoteWrapper>`.
type QuoteWrapper = dyn Fn(proc_macro2::TokenStream) -> proc_macro2::TokenStream;
/// Validated view of a parsed program, shared by all code generators in this file.
struct Context {
/// Relations declared with the `input` attribute, keyed by relation name.
rels_input: HashMap<String, Relation>,
/// Relations declared with the `output` attribute, keyed by relation name.
rels_output: HashMap<String, Relation>,
/// Names of the output relations in the order they were encountered in the program;
/// this fixes the element order of the tuple returned by the generated `run`.
output_order: Vec<Ident>,
/// Relations declared without an attribute, keyed by relation name.
rels_intermediate: HashMap<String, Relation>,
/// All rules of the program.
rules: Vec<Rule>,
/// Stratification of the relations computed from the rules' dependencies.
strata: Strata,
}
impl Context {
    /// Validates `program` and assembles the code-generation context.
    ///
    /// Relations are sorted into input / output / intermediate maps according to their
    /// attribute, every rule is checked against the declared relations, and the relations
    /// are stratified from the goal-to-body dependencies of the rules.
    ///
    /// Compilation is aborted with a spanned diagnostic when:
    ///
    /// * a relation name is declared twice, or two relation names differ only by letter
    ///   case (every generated variable and field name is derived from the lowercased
    ///   relation name, so such relations would collide in the generated code);
    /// * a relation carries an attribute other than `input` / `output`;
    /// * a rule mentions an undeclared relation, or uses one with the wrong arity;
    /// * a rule derives an `@input` relation, or has `_` in its goal;
    /// * a negated fact ends up in the same stratum as the goal of its rule.
    fn new(program: Program) -> Self {
        let mut rels_input = HashMap::new();
        let mut rels_output = HashMap::new();
        let mut rels_intermediate = HashMap::new();
        let mut rel_names = HashSet::new();
        let mut lowercase_names: HashSet<String> = HashSet::new();
        let mut output_order = Vec::new();

        for relation in program.relations {
            let name = relation.name.to_string();
            if !rel_names.insert(relation.name.clone()) {
                abort!(relation.name.span(), "Duplicate relation name: {}", name);
            }
            // Generated identifiers (`__name`, `__name_update`, the `Crepe` field, ...) are
            // all built from the lowercased name, so case-only differences must be rejected.
            if !lowercase_names.insert(name.to_lowercase()) {
                abort!(
                    relation.name.span(),
                    "Relation name '{}' differs from another relation only by letter case, \
                     so their generated identifiers would collide.",
                    name
                );
            }
            if let Some(ref attr) = relation.attribute {
                match attr.to_string().as_ref() {
                    "input" => {
                        rels_input.insert(name, relation);
                    }
                    "output" => {
                        output_order.push(relation.name.clone());
                        rels_output.insert(name, relation);
                    }
                    s => abort!(
                        attr.span(),
                        "Invalid attribute @{}, expected '@input' or '@output'",
                        s
                    ),
                }
            } else {
                rels_intermediate.insert(name, relation);
            }
        }

        // Checks that a fact refers to a declared relation and matches its arity.
        let check = |fact: &Fact| {
            let name = fact.relation.to_string();
            if !rel_names.contains(&fact.relation) {
                abort!(
                    fact.relation.span(),
                    "Relation name '{}' was not found. Did you misspell it?",
                    name
                );
            }
            let rel = rels_input
                .get(&name)
                .or_else(|| rels_output.get(&name))
                .or_else(|| rels_intermediate.get(&name))
                .expect("relation should exist");
            if rel.fields.len() != fact.fields.len() {
                abort!(
                    fact.relation.span(),
                    "Relation '{}' was declared with arity {}, but constructed with arity {} here.",
                    name,
                    rel.fields.len(),
                    fact.fields.len(),
                );
            }
        };

        // Pairs of (goal relation, relation used in the body), fed to the stratifier.
        let mut dependencies = HashSet::new();
        for rule in &program.rules {
            check(&rule.goal);
            if rels_input.contains_key(&rule.goal.relation.to_string()) {
                abort!(
                    rule.goal.relation.span(),
                    "Relations marked as @input cannot be derived from a rule."
                )
            }
            if rule.goal.fields.iter().any(Option::is_none) {
                abort!(
                    rule.goal.relation.span(),
                    "Cannot have _ in goal atom of rule."
                )
            }
            for clause in &rule.clauses {
                if let Clause::Fact(fact) = clause {
                    check(fact);
                    dependencies.insert((&rule.goal.relation, &fact.relation));
                }
            }
        }

        let strata = Strata::new(rel_names, dependencies);

        // A negated relation must be fully computed before the goal that uses it, i.e. it
        // has to live in a strictly earlier stratum.
        for rule in &program.rules {
            let goal_stratum = strata.find_relation(&rule.goal.relation);
            for clause in &rule.clauses {
                if let Clause::Fact(fact) = clause {
                    if fact.negate.is_some() {
                        let fact_stratum = strata.find_relation(&fact.relation);
                        if goal_stratum == fact_stratum {
                            abort!(
                                fact.relation.span(),
                                "Negation of relation '{}' creates a dependency cycle \
                                 and cannot be stratified.",
                                fact.relation
                            );
                        }
                        assert!(goal_stratum > fact_stratum);
                    }
                }
            }
        }

        let rules = program.rules;
        Self {
            rels_input,
            rels_output,
            output_order,
            rels_intermediate,
            rules,
            strata,
        }
    }

    /// Looks up a relation by name across the input, intermediate and output maps.
    fn get_relation(&self, name: &str) -> Option<&Relation> {
        self.rels_input
            .get(name)
            .or_else(|| self.rels_intermediate.get(name))
            .or_else(|| self.rels_output.get(name))
    }

    /// Iterates over every relation: inputs first, then intermediates, then outputs.
    /// The order inside each group is the (unspecified) iteration order of its `HashMap`.
    fn all_relations(&self) -> impl Iterator<Item = &Relation> {
        self.rels_input
            .values()
            .chain(self.rels_intermediate.values())
            .chain(self.rels_output.values())
    }
}
/// Role of a single relation column within an [`Index`].
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum IndexMode {
    /// The column is part of the index key.
    Bound,
    /// The column is not part of the index key.
    Free,
}
/// Describes one lookup index over a relation: which relation it covers and, per column,
/// whether that column belongs to the index key.
#[derive(Eq, PartialEq, Hash, Clone)]
struct Index {
/// Name of the indexed relation.
name: Ident,
/// One entry per column of the relation, in column order.
mode: Vec<IndexMode>,
}
impl Index {
    /// Identifier of the generated variable holding this index; its text is the `Display`
    /// rendering of the index and its span is the span of the relation name.
    fn to_ident(&self) -> Ident {
        let text = self.to_string();
        Ident::new(&text, self.name.span())
    }

    /// Types of the bound columns, in column order. These form the key tuple of the index.
    ///
    /// Panics if the indexed relation is not known to `context`.
    fn key_type<'a>(&self, context: &'a Context) -> Vec<&'a Type> {
        let relation_name = self.name.to_string();
        let relation = context
            .get_relation(&relation_name)
            .expect("could not find relation of index name");

        let mut key = Vec::new();
        for (mode, field) in self.mode.iter().zip(relation.fields.iter()) {
            if *mode == IndexMode::Bound {
                key.push(&field.ty);
            }
        }
        key
    }

    /// Tuple-field positions of the bound columns, in column order.
    fn bound_pos(&self) -> Vec<syn::Index> {
        self.mode
            .iter()
            .enumerate()
            .filter(|(_, mode)| **mode == IndexMode::Bound)
            .map(|(position, _)| syn::Index::from(position))
            .collect()
    }
}
impl Display for Index {
    /// Renders the index as `__<relation>_index<modes>`, where `<relation>` is the
    /// lowercased relation name and `<modes>` has one letter per column: `b` for a bound
    /// column and `f` for a free one.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "__{}_index", to_lowercase(&self.name))?;
        for mode in &self.mode {
            let letter = match mode {
                IndexMode::Bound => "b",
                IndexMode::Free => "f",
            };
            f.write_str(letter)?;
        }
        Ok(())
    }
}
fn make_struct_decls(context: &Context) -> proc_macro2::TokenStream {
context
.all_relations()
.map(|relation| {
let attrs = &relation.attrs;
let struct_token = &relation.struct_token;
let vis = &relation.visibility;
let name = &relation.name;
let semi_token = &relation.semi_token;
let fields = &relation.fields;
quote_spanned! {name.span()=>
#[derive(
::core::marker::Copy,
::core::clone::Clone,
::core::cmp::Eq,
::core::cmp::PartialEq,
::core::hash::Hash,
)]
#(#attrs)*
#vis #struct_token #name(#fields)#semi_token
}
})
.collect()
}
/// Generates the `Crepe` runtime struct: one `Vec` field per input relation, named after
/// the lowercased relation, with `Default` derived.
fn make_runtime_decl(context: &Context) -> proc_macro2::TokenStream {
    let mut fields = proc_macro2::TokenStream::new();
    for relation in context.rels_input.values() {
        let name = &relation.name;
        let lowercase_name = to_lowercase(name);
        fields.extend(quote! {
            #lowercase_name: ::std::vec::Vec<#name>,
        });
    }
    quote! {
        #[derive(::core::default::Default)]
        struct Crepe {
            #fields
        }
    }
}
/// Generates the inherent `impl Crepe` (a `new` constructor plus the `run` method) followed
/// by the `Extend` implementations for the input relations.
fn make_runtime_impl(context: &Context) -> proc_macro2::TokenStream {
    let run = make_run(context);
    let builders = make_extend(context);
    quote! {
        impl Crepe {
            fn new() -> Self {
                ::core::default::Default::default()
            }
            #run
        }
        #builders
    }
}
/// Generates, for every input relation, `Extend` implementations on `Crepe` for both owned
/// facts and references to facts; the by-reference version copies each fact and delegates
/// to the by-value version.
fn make_extend(context: &Context) -> proc_macro2::TokenStream {
    let mut impls = proc_macro2::TokenStream::new();
    for relation in context.rels_input.values() {
        let name = &relation.name;
        let lower = to_lowercase(name);
        impls.extend(quote! {
            impl ::core::iter::Extend<#name> for Crepe {
                fn extend<T>(&mut self, iter: T)
                where
                    T: ::core::iter::IntoIterator<Item = #name>,
                {
                    self.#lower.extend(iter);
                }
            }
            impl<'a> ::core::iter::Extend<&'a #name> for Crepe {
                fn extend<T>(&mut self, iter: T)
                where
                    T: ::core::iter::IntoIterator<Item = &'a #name>,
                {
                    self.extend(iter.into_iter().copied());
                }
            }
        });
    }
    impls
}
fn make_run(context: &Context) -> proc_macro2::TokenStream {
let mut indices: HashSet<Index> = HashSet::new();
let main_loops = {
let mut loop_wrappers = Vec::new();
for stratum in context.strata.iter() {
loop_wrappers.push(make_stratum(context, stratum, &mut indices));
}
loop_wrappers
.iter()
.zip(context.strata.iter())
.map(|(f, stratum)| f(make_updates(context, stratum, &indices)))
.collect::<proc_macro2::TokenStream>()
};
let initialize = {
let init_rels = context.all_relations().map(|rel| {
let name = &rel.name;
let lower = to_lowercase(name);
let var = format_ident!("__{}", lower);
let var_update = format_ident!("__{}_update", lower);
quote! {
let mut #var: ::std::collections::HashSet<#name> =
::std::collections::HashSet::new();
let mut #var_update: ::std::collections::HashSet<#name> =
::std::collections::HashSet::new();
}
});
let init_indices = indices.iter().map(|index| {
let rel = context
.get_relation(&index.name.to_string())
.expect("index relation should be found in context");
let rel_name = &rel.name;
let index_name = index.to_ident();
let key_type = index.key_type(context);
quote! {
let mut #index_name:
::std::collections::HashMap<(#(#key_type,)*), ::std::vec::Vec<#rel_name>> =
::std::collections::HashMap::new();
}
});
let load_inputs = context.rels_input.values().map(|rel| {
let lower = to_lowercase(&rel.name);
let var_update = format_ident!("__{}_update", lower);
quote! {
#var_update.extend(self.#lower);
}
});
init_rels
.chain(init_indices)
.chain(load_inputs)
.collect::<proc_macro2::TokenStream>()
};
let output = {
let output_fields = context.output_order.iter().map(|name| {
let lower = to_lowercase(name);
format_ident!("__{}", lower)
});
quote! {
(#(#output_fields,)*)
}
};
let output_ty = make_output_ty(&context);
quote! {
fn run(self) -> #output_ty {
#initialize
#main_loops
#output
}
}
}
/// Builds the evaluation loop of one stratum as a wrapper that still needs its "updates"
/// code (see `make_updates`) plugged in.
///
/// The generated loop runs at least once and then repeats while any relation of the
/// stratum still has a non-empty update set. Each pass runs the supplied updates, declares
/// fresh `__<rel>_new` sets, evaluates every rule whose goal is in the stratum, and finally
/// moves each `_new` set into the matching `_update` set.
///
/// Indices required by the stratum's rules are recorded in `indices`.
fn make_stratum(
    context: &Context,
    stratum: &[Ident],
    indices: &mut HashSet<Index>,
) -> Box<QuoteWrapper> {
    let members: HashSet<&Ident> = stratum.iter().collect();

    let mut empty_cond = proc_macro2::TokenStream::new();
    let mut new_decls = proc_macro2::TokenStream::new();
    let mut set_update_to_new = proc_macro2::TokenStream::new();
    for member in members.iter() {
        let rel = context
            .get_relation(&member.to_string())
            .expect("cannot find relation from stratum");
        let name = &rel.name;
        let lower = to_lowercase(name);
        let rel_update = format_ident!("__{}_update", lower);
        let rel_new = format_ident!("__{}_new", lower);

        empty_cond.extend(quote! {
            #rel_update.is_empty() &&
        });
        new_decls.extend(quote! {
            let mut #rel_new: ::std::collections::HashSet<#name> =
                ::std::collections::HashSet::new();
        });
        set_update_to_new.extend(quote! {
            #rel_update = #rel_new;
        });
    }
    // Terminates the `a && b && ...` chain (and covers the case of no relations at all).
    empty_cond.extend(quote! {true});

    let mut rules = proc_macro2::TokenStream::new();
    for rule in context.rules.iter() {
        if members.contains(&rule.goal.relation) {
            rules.extend(make_rule(rule, &members, indices));
        }
    }

    Box::new(move |updates| {
        quote! {
            let mut __crepe_first_iteration = true;
            while __crepe_first_iteration || !(#empty_cond) {
                #updates
                #new_decls
                #rules
                #set_update_to_new
                __crepe_first_iteration = false;
            }
        }
    })
}
/// Generates the code run at the start of every pass of a stratum's loop.
///
/// For each relation of the stratum, the update set is merged into the full fact set.
/// For each index over a relation of the stratum, a fresh `<index>_update` map is declared
/// and every fact of the update set is appended to both the full index and that update
/// index, keyed by the fact's bound columns.
fn make_updates(
    context: &Context,
    stratum: &[Ident],
    indices: &HashSet<Index>,
) -> proc_macro2::TokenStream {
    let mut updates = proc_macro2::TokenStream::new();

    for name in stratum {
        let lower = to_lowercase(name);
        let rel = format_ident!("__{}", lower);
        let rel_update = format_ident!("__{}_update", lower);
        updates.extend(quote! {
            #rel.extend(&#rel_update);
        });
    }

    for index in indices {
        // Indices over relations of other strata are maintained by those strata's loops.
        if !stratum.contains(&index.name) {
            continue;
        }
        let rel = context
            .get_relation(&index.name.to_string())
            .expect("index relation should be found in context");
        let rel_name = &rel.name;
        let rel_update = format_ident!("__{}_update", to_lowercase(rel_name));
        let index_name = index.to_ident();
        let index_name_update = format_ident!("{}_update", index_name);
        let key_type = index.key_type(context);
        let bound_pos = index.bound_pos();
        updates.extend(quote! {
            let mut #index_name_update:
                ::std::collections::HashMap<(#(#key_type,)*), ::std::vec::Vec<#rel_name>> =
                ::std::collections::HashMap::new();
            for &__crepe_var in #rel_update.iter() {
                #index_name
                    .entry((#(__crepe_var.#bound_pos,)*))
                    .or_default()
                    .push(__crepe_var);
                #index_name_update
                    .entry((#(__crepe_var.#bound_pos,)*))
                    .or_default()
                    .push(__crepe_var);
            }
        });
    }

    updates
}
/// Generates the evaluation code for one rule inside its goal's stratum.
///
/// If no body fact belongs to the current stratum, the rule body is emitted once, guarded
/// so that it only runs on the first pass of the stratum's loop. Otherwise one variant of
/// the body is emitted per body fact that belongs to the stratum; in each variant that one
/// fact is read from its relation's update data while all other clauses are evaluated
/// normally.
fn make_rule(
    rule: &Rule,
    stratum: &HashSet<&Ident>,
    indices: &mut HashSet<Index>,
) -> proc_macro2::TokenStream {
    // Innermost code: construct the goal fact and record it if it is not known yet.
    let goal = {
        let relation = &rule.goal.relation;
        let fields = &rule.goal.fields;
        let name = format_ident!("__{}", to_lowercase(relation));
        let name_new = format_ident!("__{}_new", to_lowercase(relation));
        quote! {
            let __crepe_goal = #relation(#fields);
            if !#name.contains(&__crepe_goal) {
                #name_new.insert(__crepe_goal);
            }
        }
    };

    // Positions of the body facts whose relation is in the current stratum.
    let mut fact_positions = Vec::new();
    for (position, clause) in rule.clauses.iter().enumerate() {
        if let Clause::Fact(fact) = clause {
            if stratum.contains(&fact.relation) {
                fact_positions.push(position);
            }
        }
    }

    if fact_positions.is_empty() {
        let eval_loop = make_rule_variant(rule, &goal, None, indices);
        quote! {
            if __crepe_first_iteration {
                #eval_loop
            }
        }
    } else {
        let mut variants = proc_macro2::TokenStream::new();
        for update_position in fact_positions {
            variants.extend(make_rule_variant(rule, &goal, Some(update_position), indices));
        }
        variants
    }
}

/// Expands the body of `rule` around `goal`, reading the clause at `update_position`
/// (if any) from update data. Clauses are processed front to back so variables are bound
/// in source order, then nested back to front so the first clause is outermost.
fn make_rule_variant(
    rule: &Rule,
    goal: &proc_macro2::TokenStream,
    update_position: Option<usize>,
    indices: &mut HashSet<Index>,
) -> proc_macro2::TokenStream {
    let mut datalog_vars: HashSet<String> = HashSet::new();
    let mut wrappers = Vec::new();
    for (position, clause) in rule.clauses.iter().enumerate() {
        wrappers.push(make_clause(
            clause.clone(),
            update_position == Some(position),
            &mut datalog_vars,
            indices,
        ));
    }
    wrappers
        .into_iter()
        .rev()
        .fold(goal.clone(), |body, wrap| wrap(body))
}
/// Turns one body clause of a rule into a wrapper that nests an inner body inside the code
/// evaluating that clause.
///
/// * `only_update` - read the fact from the relation's update data (`__<rel>_update` or
///   `<index>_update`) instead of the full data. Only valid for positive facts.
/// * `datalog_vars` - names of the Datalog variables bound by earlier clauses of the rule;
///   variables first bound by this clause are added to it.
/// * `indices` - receives every index the generated code relies on.
///
/// Generated shapes:
///
/// * negated fact: `if !index.contains_key(&(bound fields..)) { body }`, where every
///   non-`_` field is part of the key;
/// * positive fact with no bound column: a `for` loop over the whole relation;
/// * positive fact with bound columns: an index lookup on those columns followed by a
///   `for` loop over the matching facts;
/// * expression clause: `if expr { body }`;
/// * `let` clause: `if let .. { body }`.
///
/// A variable that occurs more than once inside the same positive fact and is first bound
/// by that fact (e.g. `Edge(x, x)` with `x` not yet bound) cannot be used as an index key,
/// because it only receives its value inside the loop. Its later occurrences are therefore
/// treated as free columns and checked with an equality filter inside the loop.
///
/// Panics if `only_update` is set for anything but a positive fact.
fn make_clause(
    clause: Clause,
    only_update: bool,
    datalog_vars: &mut HashSet<String>,
    indices: &mut HashSet<Index>,
) -> Box<QuoteWrapper> {
    match clause {
        Clause::Fact(fact) => {
            let name = &fact.relation;
            let span = fact.relation.span();

            if fact.negate.is_some() {
                assert!(!only_update);
                let to_mode = |f: &Option<_>| {
                    f.as_ref()
                        .map(|_| IndexMode::Bound)
                        .unwrap_or(IndexMode::Free)
                };
                let index = Index {
                    name: name.clone(),
                    mode: fact.fields.iter().map(to_mode).collect(),
                };
                let index_name = index.to_ident();
                indices.insert(index);
                let bound_fields: Vec<_> = fact.fields.iter().flatten().cloned().collect();
                return Box::new(move |body| {
                    quote_spanned! {span=>
                        if !#index_name.contains_key(&(#(#bound_fields,)*)) {
                            #body
                        }
                    }
                });
            }

            let mut setters = Vec::new();
            let mut index_mode = Vec::new();
            // Variables that receive their first binding inside this very fact.
            let mut local_vars: HashSet<String> = HashSet::new();
            for (i, field) in fact.fields.iter().enumerate() {
                let idx = syn::Index::from(i);
                let expr = match field {
                    Some(expr) => expr,
                    None => {
                        index_mode.push(IndexMode::Free);
                        continue;
                    }
                };
                match is_datalog_var(expr) {
                    Some(var) => {
                        let var_name = var.to_string();
                        if local_vars.contains(&var_name) {
                            // Bound earlier in this same fact: the value is only known
                            // inside the loop, so filter there instead of keying on it.
                            index_mode.push(IndexMode::Free);
                            setters.push(quote! {
                                if __crepe_var.#idx != #expr {
                                    continue;
                                }
                            });
                        } else if datalog_vars.contains(&var_name) {
                            index_mode.push(IndexMode::Bound);
                        } else {
                            index_mode.push(IndexMode::Free);
                            datalog_vars.insert(var_name.clone());
                            local_vars.insert(var_name);
                            setters.push(quote! {
                                let #expr = __crepe_var.#idx;
                            });
                        }
                    }
                    // Any other expression is a value to match against.
                    None => index_mode.push(IndexMode::Bound),
                }
            }
            let setters: proc_macro2::TokenStream = setters.into_iter().collect();

            if !index_mode.contains(&IndexMode::Bound) {
                // Nothing to key on: scan the whole relation (or its update set).
                let mut rel = format_ident!("__{}", &to_lowercase(name));
                if only_update {
                    rel = format_ident!("{}_update", rel);
                }
                Box::new(move |body| {
                    quote_spanned! {span=>
                        for &__crepe_var in #rel.iter() {
                            #setters
                            #body
                        }
                    }
                })
            } else {
                let bound_fields: Vec<_> = index_mode
                    .iter()
                    .zip(fact.fields.iter())
                    .filter_map(|(mode, field)| match mode {
                        IndexMode::Bound => Some(field.clone()),
                        IndexMode::Free => None,
                    })
                    .collect();
                let index = Index {
                    name: name.clone(),
                    mode: index_mode,
                };
                let mut index_name = index.to_ident();
                if only_update {
                    index_name = format_ident!("{}_update", index_name);
                }
                indices.insert(index);
                Box::new(move |body| {
                    quote_spanned! {span=>
                        if let Some(__crepe_iter) = #index_name.get(&(#(#bound_fields,)*)) {
                            for &__crepe_var in __crepe_iter.iter() {
                                #setters
                                #body
                            }
                        }
                    }
                })
            }
        }
        Clause::Expr(expr) => {
            assert!(!only_update);
            Box::new(move |body| {
                quote! {
                    #[allow(unused_parens)]
                    if #expr { #body }
                }
            })
        }
        Clause::Let(guard) => {
            assert!(!only_update);
            Box::new(move |body| {
                quote! {
                    #[allow(irrefutable_let_patterns)]
                    if #guard { #body }
                }
            })
        }
    }
}
/// Generates the return type of `run`: a tuple with one `HashSet` per output relation, in
/// `output_order`.
fn make_output_ty(context: &Context) -> proc_macro2::TokenStream {
    let relation_names = context.output_order.iter();
    quote! {
        (#(::std::collections::HashSet<#relation_names>,)*)
    }
}
fn is_datalog_var(expr: &Expr) -> Option<Ident> {
use syn::{ExprPath, Path, PathArguments};
match expr {
Expr::Path(ExprPath {
attrs,
qself: None,
path:
Path {
leading_colon: None,
segments,
},
}) => {
if attrs.is_empty() && segments.len() == 1 {
let segment = segments.iter().next()?;
if let PathArguments::None = segment.arguments {
let ident = segment.ident.clone();
match ident.to_string().chars().next() {
Some('a'..='z') | Some('_') => return Some(ident),
_ => (),
}
}
}
None
}
_ => None,
}
}
/// Returns a new identifier whose text is the lowercased text of `name`, keeping `name`'s
/// span.
fn to_lowercase(name: &Ident) -> Ident {
    let lowered = name.to_string().to_lowercase();
    Ident::new(lowered.as_str(), name.span())
}