use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{
parse_macro_input, FnArg, GenericArgument, ItemTrait, Lifetime, PathArguments, ReturnType,
TraitItem, Type,
};
#[proc_macro_attribute]
pub fn stub(_: TokenStream, input: TokenStream) -> TokenStream {
let item_trait = parse_macro_input!(input as ItemTrait);
let trait_ident = &item_trait.ident;
let trait_generic_params = &item_trait.generics.params;
let stub_struct_ident = format_ident!("Stub{}", trait_ident);
let mut attrs = vec![];
let mut attrs_init = vec![];
let mut impl_fns = vec![];
let mut stub_fns = vec![];
for trait_item in &item_trait.items {
if let TraitItem::Method(item_method) = trait_item {
let method_inputs = &item_method.sig.inputs;
let mut method_inputs_iter = method_inputs.iter();
match method_inputs_iter.next() {
Some(FnArg::Receiver(_)) => {}
_ => panic!("The trait must be can made into an object"),
}
let mut method_arg_names = vec![];
let mut method_arg_types = vec![];
for arg in method_inputs_iter {
if let FnArg::Typed(arg) = arg {
method_arg_names.push(arg.pat.as_ref());
method_arg_types.push(arg.ty.as_ref());
}
}
let method_ident = &item_method.sig.ident;
let attr_ident = format_ident!("{}_stub", method_ident);
let count_calls_of_fn_ident = format_ident!("count_calls_of_{}", method_ident);
let with_stub_of_fn_ident = format_ident!("with_stub_of_{}", method_ident);
let mut method_output = item_method.sig.output.clone();
let method_output = match method_output {
ReturnType::Default => quote! {},
ReturnType::Type(_, ref mut ty) => match ty.as_mut() {
Type::Path(ty) => {
let mut segments = ty.path.segments.clone();
let last_segment = segments.last_mut().unwrap();
if let PathArguments::AngleBracketed(ty) = &mut last_segment.arguments {
if let GenericArgument::Lifetime(lifetime) =
ty.args.first_mut().unwrap()
{
if lifetime.ident == format_ident!("_") {
lifetime.ident = format_ident!("static");
}
}
}
quote! { -> #segments }
}
Type::Reference(ref mut ty) => {
if ty.lifetime.is_none() {
ty.lifetime = Some(Lifetime::new("'static", Span::call_site()));
}
quote! { -> #ty }
}
ty => quote! { -> #ty },
},
};
let closure_type = quote! {
Fn(usize, #(#method_arg_types),*) #method_output + 'static
};
let attr = quote! {
#attr_ident: Option<(Box<dyn #closure_type>, std::sync::atomic::AtomicUsize)>
};
let attr_init = quote! {
#attr_ident: None
};
let fns = quote! {
pub fn #count_calls_of_fn_ident(&self) -> usize {
self.#attr_ident.as_ref()
.map(|stub| stub.1.load(std::sync::atomic::Ordering::Relaxed))
.unwrap_or_default()
}
pub fn #with_stub_of_fn_ident<F: #closure_type>(mut self, f: F) -> Self {
self.#attr_ident = Some((Box::new(f), std::sync::atomic::AtomicUsize::new(0)));
self
}
};
let stub_fn = quote! {
fn #method_ident(#method_inputs) #method_output {
match &self.#attr_ident {
Some(stub) => {
let i = stub.1.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
stub.0(i, #(#method_arg_names),*)
}
None => panic!("unexpected invocation of {}", stringify!(#method_ident)),
}
}
};
attrs.push(attr);
attrs_init.push(attr_init);
impl_fns.push(fns);
stub_fns.push(stub_fn);
}
}
let expanded = quote! {
#item_trait
pub struct #stub_struct_ident<#trait_generic_params> {
#(#attrs),*
}
impl<#trait_generic_params> #stub_struct_ident<#trait_generic_params> {
pub fn new() -> Self {
Self {
#(#attrs_init),*
}
}
#(#impl_fns)*
}
impl<#trait_generic_params> Default for #stub_struct_ident<#trait_generic_params> {
fn default() -> Self {
Self::new()
}
}
impl<#trait_generic_params> std::fmt::Debug for #stub_struct_ident<#trait_generic_params> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct(stringify!(#stub_struct_ident)).finish()
}
}
impl<#trait_generic_params> #trait_ident<#trait_generic_params> for #stub_struct_ident<#trait_generic_params> {
#(#stub_fns)*
}
};
TokenStream::from(expanded)
}