extern crate proc_macro;
use proc_macro::TokenStream;
use syn::parse::Parser;
use syn::punctuated::Punctuated;
use syn::{
Block, Expr, ExprLit, FnArg, ImplItem, ImplItemMethod, Item, ItemFn, ItemImpl, Lit, ReturnType,
};
use syn::{ItemTrait, Pat, Token, TraitItem, TraitItemMethod};
/// How the conditions of a contract attribute are turned into code.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum ContractMode {
    /// Checked with `assert!`.
    Always,
    /// No checking code is generated at all.
    Disabled,
    /// Checked with `debug_assert!`.
    Debug,
    /// Checked with `assert!` inside a `#[cfg(test)]` block.
    Test,
    /// A violation is reported via `log::error!` instead of panicking.
    LogOnly,
}

impl ContractMode {
    /// Attribute-name prefix for this mode (`""`, `"debug_"` or `"test_"`),
    /// or `None` for the modes that have no attribute of their own.
    fn name(self) -> Option<&'static str> {
        let prefix = match self {
            ContractMode::Always => "",
            ContractMode::Debug => "debug_",
            ContractMode::Test => "test_",
            ContractMode::Disabled | ContractMode::LogOnly => return None,
        };
        Some(prefix)
    }
}
/// Applies the crate-level feature overrides to the mode requested by an
/// attribute.
///
/// `Disabled` and `Test` are never overridden. For every other mode the
/// features are consulted in priority order:
/// `disable_contracts` > `override_debug` > `override_log`.
/// Under `override_debug`, `LogOnly` stays `LogOnly`; everything else
/// becomes `Debug`.
fn final_mode(mode: ContractMode) -> ContractMode {
    match mode {
        ContractMode::Disabled | ContractMode::Test => mode,
        _ if cfg!(feature = "disable_contracts") => ContractMode::Disabled,
        ContractMode::LogOnly if cfg!(feature = "override_debug") => ContractMode::LogOnly,
        _ if cfg!(feature = "override_debug") => ContractMode::Debug,
        _ if cfg!(feature = "override_log") => ContractMode::LogOnly,
        requested => requested,
    }
}
/// Checks the given conditions (comma-separated boolean expressions,
/// optionally followed by a description string literal) before the body of
/// the annotated function runs, using `assert!`.
///
/// The mode can be changed by the crate features handled in `final_mode`.
#[proc_macro_attribute]
pub fn pre(attr: TokenStream, toks: TokenStream) -> TokenStream {
    impl_pre(final_mode(ContractMode::Always), attr, toks)
}
/// Like `pre`, but the conditions are checked with `debug_assert!`.
///
/// The mode can be changed by the crate features handled in `final_mode`.
#[proc_macro_attribute]
pub fn debug_pre(attr: TokenStream, toks: TokenStream) -> TokenStream {
    impl_pre(final_mode(ContractMode::Debug), attr, toks)
}
/// Like `pre`, but the checks are only compiled under `#[cfg(test)]`.
#[proc_macro_attribute]
pub fn test_pre(attr: TokenStream, toks: TokenStream) -> TokenStream {
    impl_pre(final_mode(ContractMode::Test), attr, toks)
}
fn impl_pre(mode: ContractMode, attr: TokenStream, toks: TokenStream) -> TokenStream {
let (conds, desc) = parse_attributes(attr);
let item: ItemFn = syn::parse_macro_input!(toks as ItemFn);
let fn_name = item.ident.to_string();
let desc = if let Some(desc) = desc {
format!("Pre-condition of {} violated - {:?}", fn_name, desc)
} else {
format!("Pre-condition of {} violated", fn_name)
};
let pre = attributes_to_asserts(mode, conds, desc);
let post = quote::quote! {};
impl_fn_checks(item, pre, post)
}
/// Checks the given conditions (comma-separated boolean expressions,
/// optionally followed by a description string literal) after the body of
/// the annotated function has produced its value, using `assert!`. The
/// conditions can refer to that value as `ret`.
///
/// The mode can be changed by the crate features handled in `final_mode`.
#[proc_macro_attribute]
pub fn post(attr: TokenStream, toks: TokenStream) -> TokenStream {
    impl_post(final_mode(ContractMode::Always), attr, toks)
}
/// Like `post`, but the conditions are checked with `debug_assert!`.
///
/// The mode can be changed by the crate features handled in `final_mode`.
#[proc_macro_attribute]
pub fn debug_post(attr: TokenStream, toks: TokenStream) -> TokenStream {
    impl_post(final_mode(ContractMode::Debug), attr, toks)
}
/// Like `post`, but the checks are only compiled under `#[cfg(test)]`.
#[proc_macro_attribute]
pub fn test_post(attr: TokenStream, toks: TokenStream) -> TokenStream {
    impl_post(final_mode(ContractMode::Test), attr, toks)
}
fn impl_post(mode: ContractMode, attr: TokenStream, toks: TokenStream) -> TokenStream {
let (conds, desc) = parse_attributes(attr);
let item: ItemFn = syn::parse_macro_input!(toks as ItemFn);
let fn_name = item.ident.to_string();
let desc = if let Some(desc) = desc {
format!("Post-condition of {} violated - {:?}", fn_name, desc)
} else {
format!("Post-condition of {} violated", fn_name)
};
let pre = quote::quote! {};
let post = attributes_to_asserts(mode, conds, desc);
impl_fn_checks(item, pre, post)
}
/// Checks the given conditions both before and after the body of the
/// annotated function, using `assert!`. On an `impl` block, the attribute
/// is applied to every method whose first parameter is `self`.
///
/// `final_mode` is not applied here: the `impl`-block path needs the
/// unmodified mode's attribute name. The override is applied later,
/// per function, in `impl_invariant_fn`.
#[proc_macro_attribute]
pub fn invariant(attr: TokenStream, toks: TokenStream) -> TokenStream {
    impl_invariant(ContractMode::Always, attr, toks)
}
/// Like `invariant`, but the conditions are checked with `debug_assert!`.
#[proc_macro_attribute]
pub fn debug_invariant(attr: TokenStream, toks: TokenStream) -> TokenStream {
    impl_invariant(ContractMode::Debug, attr, toks)
}
/// Like `invariant`, but the checks are only compiled under `#[cfg(test)]`.
#[proc_macro_attribute]
pub fn test_invariant(attr: TokenStream, toks: TokenStream) -> TokenStream {
    impl_invariant(ContractMode::Test, attr, toks)
}
/// Dispatches an `*invariant` attribute to the function or `impl`-block
/// implementation.
///
/// # Panics
/// Panics if `mode` has no attribute name (`Disabled`/`LogOnly`; callers in
/// this file only pass `Always`, `Debug` or `Test`), or if the item is
/// neither a function nor an `impl` block.
fn impl_invariant(mode: ContractMode, attr: TokenStream, toks: TokenStream) -> TokenStream {
    let item: Item = syn::parse_macro_input!(toks as Item);
    // Only used in the error message, but computed up front as before.
    let attr_name = format!("{}invariant", mode.name().unwrap());

    match item {
        Item::Impl(impl_block) => impl_impl_invariant(mode, attr, impl_block),
        Item::Fn(func) => impl_invariant_fn(mode, attr, func),
        _ => unimplemented!(
            "The #[{}] attribute only works on functions and impl-blocks.",
            attr_name
        ),
    }
}
/// Applies an invariant to a single function: the same checks are inserted
/// both before and after the body. The feature overrides (`final_mode`) are
/// applied here.
fn impl_invariant_fn(mode: ContractMode, attr: TokenStream, fn_: ItemFn) -> TokenStream {
    let effective_mode = final_mode(mode);
    let (conditions, description) = parse_attributes(attr);

    let message = match description {
        Some(text) => format!("Invariant of {} violated - {:?}", fn_.ident, text),
        None => format!("Invariant of {} violated", fn_.ident),
    };

    let checks = attributes_to_asserts(effective_mode, conditions, message);
    impl_fn_checks(fn_, checks.clone(), checks)
}
/// Enables contract attributes on trait methods.
///
/// On a `trait`, each method is split into a hidden implementation method
/// and a provided wrapper that carries the original attributes. On an
/// `impl Trait for Type` block, the methods are renamed to the hidden
/// implementation names.
///
/// # Panics
/// Panics on any other item, including inherent `impl` blocks.
#[proc_macro_attribute]
pub fn contract_trait(attrs: TokenStream, toks: TokenStream) -> TokenStream {
    match syn::parse_macro_input!(toks as Item) {
        Item::Trait(trait_def) => contract_trait_item_trait(attrs, trait_def),
        Item::Impl(impl_def) => {
            assert!(
                impl_def.trait_.is_some(),
                "#[contract_trait] can only be applied to `trait` and `impl ... for` items"
            );
            contract_trait_item_impl(attrs, impl_def)
        }
        _ => panic!("#[contract_trait] can only be applied to `trait` and `impl ... for` items"),
    }
}
/// Name of the hidden method that carries the real implementation of the
/// trait method `name` under `#[contract_trait]`.
fn contract_method_impl_name(name: &str) -> String {
    let mut mangled = String::from("__contracts_impl_");
    mangled.push_str(name);
    mangled
}
fn contract_trait_item_trait(_attrs: TokenStream, mut trait_: ItemTrait) -> TokenStream {
fn create_method_rename(method: &TraitItemMethod) -> TraitItemMethod {
let mut m: TraitItemMethod = (*method).clone();
{
let name = m.sig.ident.to_string();
let new_name = contract_method_impl_name(&name);
m.attrs.clear();
m.sig.ident = syn::Ident::new(&new_name, m.sig.ident.span());
}
m
}
fn create_method_wrapper(method: &TraitItemMethod) -> TraitItemMethod {
struct ArgInfo {
call_toks: proc_macro2::TokenStream,
}
fn arg_pat_info(pat: &Pat) -> ArgInfo {
match pat {
Pat::Ident(ident) => {
let toks = quote::quote! {
#ident
};
ArgInfo { call_toks: toks }
}
Pat::Tuple(tup) => {
let infos = tup.front.iter().map(arg_pat_info);
let toks = {
let mut toks = proc_macro2::TokenStream::new();
for info in infos {
toks.extend(info.call_toks);
toks.extend(quote::quote!(,));
}
toks
};
ArgInfo {
call_toks: quote::quote!((#toks)),
}
}
Pat::TupleStruct(_tup) => unimplemented!(),
p => panic!("Unsupported pattern type: {:?}", p),
}
}
let mut m: TraitItemMethod = (*method).clone();
let argument_data = m
.sig
.decl
.inputs
.clone()
.into_iter()
.map(|t: FnArg| match &t {
FnArg::SelfRef(_) | FnArg::SelfValue(_) => quote::quote!(self),
FnArg::Captured(c) => {
let info = arg_pat_info(&c.pat);
info.call_toks
}
FnArg::Inferred(inf) => unimplemented!("Inferred pattern: {:?}", inf),
FnArg::Ignored(_ty) => {
unimplemented!("Ignored patterns are not allowed in contract trait methods");
}
})
.collect::<Vec<_>>();
let arguments = {
let mut toks = proc_macro2::TokenStream::new();
for arg in argument_data {
toks.extend(arg);
toks.extend(quote::quote!(,));
}
toks
};
let body: TokenStream = {
let name = contract_method_impl_name(&m.sig.ident.to_string());
let name = syn::Ident::new(&name, m.sig.ident.span());
let toks = quote::quote! {
{
Self::#name(#arguments)
}
};
toks.into()
};
{
let block: syn::Block = syn::parse_macro_input::parse(body).unwrap();
m.default = Some(block);
m.semi_token = None;
}
m
}
let funcs = trait_
.items
.iter()
.filter_map(|item| {
if let TraitItem::Method(m) = item {
let rename = create_method_rename(m);
let wrapper = create_method_wrapper(m);
Some(vec![TraitItem::Method(rename), TraitItem::Method(wrapper)])
} else {
None
}
})
.flatten()
.collect::<Vec<_>>();
trait_.items = trait_
.items
.into_iter()
.filter(|item| {
if let TraitItem::Method(_) = item {
false
} else {
true
}
})
.collect();
trait_.items.extend(funcs);
let toks = quote::quote! {
#trait_
};
toks.into()
}
/// Rewrites an `impl Trait for Type` block annotated with `#[contract_trait]`.
///
/// Every method is renamed to its hidden implementation name
/// (`contract_method_impl_name`), so it implements the renamed trait method.
/// The public wrapper with the contracts is then provided by the trait itself.
fn contract_trait_item_impl(_attrs: TokenStream, impl_: ItemImpl) -> TokenStream {
    let mut renamed = impl_;
    for item in renamed.items.iter_mut() {
        if let ImplItem::Method(method) = item {
            let old_ident = &method.sig.ident;
            let new_ident = syn::Ident::new(
                &contract_method_impl_name(&old_ident.to_string()),
                old_ident.span(),
            );
            method.sig.ident = new_ident;
        }
    }
    quote::quote!(#renamed).into()
}
/// Parses the arguments of a contract attribute.
///
/// The arguments are a comma-separated list of expressions; an empty list
/// and a trailing comma are accepted. If the last expression is a string
/// literal, it is taken as the description and removed from the conditions.
///
/// Returns `(conditions, description)`.
///
/// # Panics
/// Panics if the arguments are not a comma-separated list of expressions.
fn parse_attributes(attrs: TokenStream) -> (Vec<Expr>, Option<String>) {
    // `parse_terminated` accepts everything `parse_separated_nonempty` does,
    // with the same result, plus the empty list and a trailing comma, so a
    // single parse pass is sufficient.
    let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
    let mut conds = parser
        .parse(attrs)
        .unwrap_or_else(|err| panic!("failed to parse contract attribute: {}", err));

    let desc = conds.last().and_then(|pair| match *pair.value() {
        Expr::Lit(ExprLit {
            lit: Lit::Str(s), ..
        }) => Some(s.value()),
        _ => None,
    });

    // The description is not a condition; drop it from the list.
    if desc.is_some() {
        conds.pop();
    }

    let exprs: Vec<Expr> = conds.into_iter().collect();
    (exprs, desc)
}
/// Generates one check per condition in `exprs`, according to `mode`.
///
/// The failure message is `"<desc>: <stringified condition>"`. It is always
/// passed as an argument to a `"{}"` format string, never as the format
/// string itself. Otherwise any `{` or `}` in the condition (closures,
/// struct literals) or in the description would make `log::error!` (and
/// `assert!` under edition 2021) fail with "invalid format string".
///
/// * `Always`   → `assert!`
/// * `Debug`    → `debug_assert!`
/// * `Test`     → `assert!` inside a `#[cfg(test)]` block
/// * `LogOnly`  → `log::error!` when the condition is false (no panic)
/// * `Disabled` → nothing (the condition is not even type-checked)
fn attributes_to_asserts(
    mode: ContractMode,
    exprs: Vec<Expr>,
    desc: String,
) -> proc_macro2::TokenStream {
    let mut stream = proc_macro2::TokenStream::new();
    let generate = |expr: &Expr, desc: &str| {
        let message = quote::quote! {
            concat!(concat!(#desc, ": "), stringify!(#expr))
        };
        match mode {
            ContractMode::Always => quote::quote! {
                assert!(#expr, "{}", #message);
            },
            ContractMode::Disabled => quote::quote! {},
            ContractMode::Debug => quote::quote! {
                debug_assert!(#expr, "{}", #message);
            },
            ContractMode::Test => quote::quote! {
                #[cfg(test)]
                {
                    assert!(#expr, "{}", #message);
                }
            },
            ContractMode::LogOnly => quote::quote! {
                if !(#expr) {
                    log::error!("{}", #message);
                }
            },
        }
    };
    for expr in &exprs {
        stream.extend(generate(expr, &desc));
    }
    stream
}
/// Wraps the body of `fn_def` so that `pre` runs before it and `post` after.
///
/// The original body is evaluated into a local named `ret`, which is what
/// post-condition expressions refer to, and `ret` is then returned.
///
/// NOTE(review): the original body is inlined, so an early `return` inside
/// it leaves the function directly and skips `post`.
fn impl_fn_checks(
    mut fn_def: ItemFn,
    pre: proc_macro2::TokenStream,
    post: proc_macro2::TokenStream,
) -> TokenStream {
    let block = fn_def.block.clone();

    // Annotate `ret` with the declared return type (or `()`), presumably so
    // the body is type-checked against it as before. `impl Trait` is not
    // allowed in a `let` type annotation, so for such functions the type of
    // `ret` is left to inference from the body.
    let ret_binding = match &fn_def.decl.output {
        ReturnType::Type(_, ty) => match **ty {
            syn::Type::ImplTrait(_) => quote::quote! { ret },
            _ => quote::quote! { ret: #ty },
        },
        ReturnType::Default => quote::quote! { ret: () },
    };

    let new_block = quote::quote! {
        {
            #pre
            let #ret_binding = {
                #block
            };
            #post
            ret
        }
    }
    .into();
    fn_def.block = Box::new(syn::parse_macro_input!(new_block as Block));

    quote::quote!(#fn_def).into()
}
/// Applies an invariant to an `impl` block by annotating every method whose
/// first parameter is `self` (by value or by reference) with the matching
/// `#[<prefix>invariant(...)]` attribute. Associated functions without
/// `self` are left untouched.
///
/// If `mode` has no attribute name, the block is returned unchanged.
fn impl_impl_invariant(
    mode: ContractMode,
    invariant: TokenStream,
    mut impl_def: ItemImpl,
) -> TokenStream {
    let prefix = match mode.name() {
        Some(prefix) => prefix,
        None => return quote::quote!( #impl_def ).into(),
    };
    let attr_ident = syn::Ident::new(
        &format!("{}invariant", prefix),
        proc_macro2::Span::call_site(),
    );
    let conditions = proc_macro2::TokenStream::from(invariant);

    for item in impl_def.items.iter_mut() {
        let method = match item {
            ImplItem::Method(method) => method,
            _ => continue,
        };

        let takes_self = match method.sig.decl.inputs.iter().next() {
            Some(FnArg::SelfValue(_)) | Some(FnArg::SelfRef(_)) => true,
            _ => false,
        };
        if !takes_self {
            continue;
        }

        // Re-parse the method with the invariant attribute attached, so the
        // attribute is expanded later like a user-written one.
        let annotated: TokenStream = quote::quote! {
            #[#attr_ident(#conditions)]
            #method
        }
        .into();
        *method = syn::parse_macro_input!(annotated as ImplItemMethod);
    }

    quote::quote!(#impl_def).into()
}