use proc_macro2::TokenStream;
use quote::ToTokens;
use syn::{FnArg, ImplItem, ItemImpl, ItemTrait, Pat, TraitItem, TraitItemFn};
use crate::implementation::ContractType;
/// Builds the name of the hidden implementation method that backs a contract-trait method.
///
/// The result is `name` prefixed with `__contracts_impl_`.
fn contract_method_impl_name(name: &str) -> String {
    const PREFIX: &str = "__contracts_impl_";

    let mut impl_name = String::with_capacity(PREFIX.len() + name.len());
    impl_name.push_str(PREFIX);
    impl_name.push_str(name);
    impl_name
}
pub(crate) fn contract_trait_item_trait(_attrs: TokenStream, mut trait_: ItemTrait) -> TokenStream {
fn create_method_rename(method: &TraitItemFn) -> TraitItemFn {
let mut m = method.clone();
{
let name = m.sig.ident.to_string();
let new_name = contract_method_impl_name(&name);
let mut new_attrs = vec![];
new_attrs.push(syn::parse_quote!(#[doc(hidden)]));
new_attrs.push(syn::parse_quote!(#[doc = "This is an internal function that is not meant to be used directly!"]));
new_attrs.push(syn::parse_quote!(#[doc = "See the documentation of the `#[contract_trait]` attribute."]));
new_attrs.extend(
m.attrs
.iter()
.filter(|attr| {
let name = attr.path().segments.last().unwrap().ident.to_string();
ContractType::contract_type_and_mode(&name).is_none()
})
.cloned(),
);
m.attrs = new_attrs;
m.sig.ident = syn::Ident::new(&new_name, m.sig.ident.span());
}
m
}
fn create_method_wrapper(method: &TraitItemFn) -> TraitItemFn {
struct ArgInfo {
call_toks: proc_macro2::TokenStream,
}
fn arg_pat_info(pat: &Pat) -> ArgInfo {
match pat {
Pat::Ident(ident) => {
let toks = quote::quote! {
#ident
};
ArgInfo { call_toks: toks }
}
Pat::Tuple(tup) => {
let infos = tup.elems.iter().map(arg_pat_info);
let toks = {
let mut toks = proc_macro2::TokenStream::new();
for info in infos {
toks.extend(info.call_toks);
toks.extend(quote::quote!(,));
}
toks
};
ArgInfo {
call_toks: quote::quote!((#toks)),
}
}
Pat::TupleStruct(_tup) => unimplemented!(),
p => panic!("Unsupported pattern type: {:?}", p),
}
}
let mut m = method.clone();
let argument_data = m
.sig
.inputs
.clone()
.into_iter()
.map(|t: FnArg| match &t {
FnArg::Receiver(_) => quote::quote!(self),
FnArg::Typed(p) => {
let info = arg_pat_info(&p.pat);
info.call_toks
}
})
.collect::<Vec<_>>();
let arguments = {
let mut toks = proc_macro2::TokenStream::new();
for arg in argument_data {
toks.extend(arg);
toks.extend(quote::quote!(,));
}
toks
};
let body: TokenStream = {
let name = contract_method_impl_name(&m.sig.ident.to_string());
let name = syn::Ident::new(&name, m.sig.ident.span());
quote::quote! {
{
Self::#name(#arguments)
}
}
};
let mut attrs = vec![];
attrs.extend(
m.attrs
.iter()
.filter(|a| {
let name = a.path().segments.last().unwrap().ident.to_string();
if name == "doc" {
return true;
}
ContractType::contract_type_and_mode(&name).is_some()
})
.cloned(),
);
attrs.push(syn::parse_quote!(#[inline(always)]));
m.attrs = attrs;
{
let block: syn::Block = syn::parse2(body).unwrap();
m.default = Some(block);
m.semi_token = None;
}
m
}
let funcs = trait_
.items
.iter()
.filter_map(|item| {
if let TraitItem::Fn(m) = item {
let rename = create_method_rename(m);
let wrapper = create_method_wrapper(m);
Some(vec![TraitItem::Fn(rename), TraitItem::Fn(wrapper)])
} else {
None
}
})
.flatten()
.collect::<Vec<_>>();
trait_
.items
.retain(|item| !matches!(item, TraitItem::Fn(_)));
trait_.items.extend(funcs);
trait_.into_token_stream()
}
/// Rewrites an `impl` block for the `#[contract_trait]` attribute.
///
/// Every method `foo` in the block is renamed to `__contracts_impl_foo` (see
/// `contract_method_impl_name`), keeping the span of the original identifier. Nothing else in
/// the block is modified. The `_attrs` argument is ignored.
pub(crate) fn contract_trait_item_impl(_attrs: TokenStream, impl_: ItemImpl) -> TokenStream {
    let mut rewritten = impl_;

    for item in rewritten.items.iter_mut() {
        if let ImplItem::Fn(method) = item {
            let span = method.sig.ident.span();
            let impl_name = contract_method_impl_name(&method.sig.ident.to_string());
            method.sig.ident = syn::Ident::new(&impl_name, span);
        }
    }

    rewritten.into_token_stream()
}
#[cfg(test)]
mod tests {
    // On a trait definition, the contract attribute (`ensures`) must end up on the public
    // wrapper, while the unrelated attribute (`aaa`) stays on the hidden implementation method.
    #[test]
    fn attributes_stay_on_trait_def() {
        let input: syn::ItemTrait = syn::parse_quote! {
            trait Random {
                #[aaa]
                #[ensures((min..max).contains(ret))]
                fn random_number(min: u8, max: u8) -> u8;
            }
        };

        let want = quote::quote! {
            trait Random {
                #[doc(hidden)]
                #[doc = "This is an internal function that is not meant to be used directly!"]
                #[doc = "See the documentation of the `#[contract_trait]` attribute."]
                #[aaa]
                fn __contracts_impl_random_number(min: u8, max: u8) -> u8;

                #[ensures((min..max).contains(ret))]
                #[inline(always)]
                fn random_number(min: u8, max: u8) -> u8 {
                    Self::__contracts_impl_random_number(min, max,)
                }
            }
        }
        .to_string();

        let got = super::contract_trait_item_trait(Default::default(), input).to_string();
        assert_eq!(got, want);
    }

    // On a trait impl, methods are only renamed; their attributes and bodies are untouched.
    #[test]
    fn attributes_stay_on_trait_impl() {
        let input: syn::ItemImpl = syn::parse_quote! {
            impl Random for AlwaysMin {
                #[no_panic]
                fn random_number(min: u8, max: u8) -> u8 {
                    min
                }
            }
        };

        let want = quote::quote! {
            impl Random for AlwaysMin {
                #[no_panic]
                fn __contracts_impl_random_number(min: u8, max: u8) -> u8 {
                    min
                }
            }
        }
        .to_string();

        let got = super::contract_trait_item_impl(Default::default(), input).to_string();
        assert_eq!(got, want);
    }
}