use proc_macro::TokenStream;
use quote::{format_ident, quote, ToTokens};
use syn::FnArg::Typed;
use syn::__private::Span;
use syn::__private::{str, Default};
use syn::parse::Parser;
use syn::punctuated::Punctuated;
use syn::{
parse_macro_input, Block, Expr, ExprCall, ExprForLoop, ExprMatch, FnArg, GenericParam, Ident,
ImplItem, ImplItemMethod, ItemEnum, ItemImpl, Lit, MetaList, NestedMeta, Pat, ReturnType,
Signature, Stmt, Type,
};
use std::collections::HashSet;
#[proc_macro_attribute]
pub fn invoke_impl(args: TokenStream, item: TokenStream) -> TokenStream {
let mut input = parse_macro_input!(item as ItemImpl);
let (name, clones) = parse_args(args);
let methods = input
.items
.iter()
.filter_map(|item| match item {
ImplItem::Method(method) => Some(method),
_ => None,
})
.collect::<Vec<_>>();
let count = methods.len();
let names = methods
.iter()
.map(|iim| iim.sig.ident.to_string())
.collect::<Vec<_>>();
validate_signatures(methods[0], &methods);
let struct_ident = get_struct_identifier_as_path(&input).unwrap();
let enum_tokenstream = create_enum(&methods, &struct_ident, &name);
let invoke_all = create_invoke_function(
methods[0],
&methods,
&struct_ident,
InvokeType::All,
&name,
&clones,
);
let invoke_subset = create_invoke_function(
methods[0],
&methods,
&struct_ident,
InvokeType::Subset,
&name,
&clones,
);
let invoke_all_enumerated = create_invoke_function(
methods[0],
&methods,
&struct_ident,
InvokeType::SpecifiedAll(SpecificationType::Enumerated),
&name,
&clones,
);
let invoke_all_enum = create_invoke_function(
methods[0],
&methods,
&struct_ident,
InvokeType::SpecifiedAll(SpecificationType::Enum),
&name,
&clones,
);
let invoke_enumerated = create_invoke_function(
methods[0],
&methods,
&struct_ident,
InvokeType::Specified(SpecificationType::Enumerated),
&name,
&clones,
);
let invoke_enum = create_invoke_function(
methods[0],
&methods,
&struct_ident,
InvokeType::Specified(SpecificationType::Enum),
&name,
&clones,
);
input.items.push(invoke_all);
input.items.push(invoke_subset);
input.items.push(invoke_all_enumerated);
input.items.push(invoke_all_enum);
input.items.push(invoke_enumerated);
input.items.push(invoke_enum);
let mc_ident = if let Some(ref s) = name {
format_ident!("METHOD_COUNT_{}", s)
} else {
format_ident!("METHOD_COUNT")
};
input
.items
.push(syn::parse(quote!(pub const #mc_ident: usize = #count;).into()).unwrap());
let ml_ident = if let Some(ref s) = name {
format_ident!("METHOD_LIST_{}", s)
} else {
format_ident!("METHOD_LIST")
};
input.items.push(
syn::parse(quote!(pub const #ml_ident: [&'static str; #count] = [#(#names),*];).into())
.unwrap(),
);
let mut revised_impl: TokenStream = input.into_token_stream().into();
revised_impl.extend(enum_tokenstream);
revised_impl
}
/// How a generated `invoke_*` method identifies each wrapped method when
/// reporting it to the consumer closure.
#[derive(Clone, Copy)]
enum SpecificationType {
    /// By position: a `usize` index in declaration order.
    Enumerated,
    /// By variant of the generated companion enum.
    Enum,
}
/// Which flavour of `invoke_*` dispatch method to generate.
#[derive(Clone, Copy)]
enum InvokeType {
    /// `invoke_all`: call every method, in declaration order.
    All,
    /// `invoke_subset`: call the methods whose `usize` indices an iterator yields.
    Subset,
    /// `invoke_all_enum` / `invoke_all_enumerated`: call every method and tag
    /// each call with the method's identity.
    SpecifiedAll(SpecificationType),
    /// `invoke_enum` / `invoke_enumerated`: call the methods an iterator
    /// selects, tagging each call with the method's identity.
    Specified(SpecificationType),
}
fn create_invoke_function(
base_method: &ImplItemMethod,
methods: &Vec<&ImplItemMethod>,
struct_ident: &Ident,
invoke_type: InvokeType,
name: &Option<String>,
clone: &Option<HashSet<usize>>,
) -> ImplItem {
let output_type = base_method.sig.output.clone();
let invoke_name = generate_invoke_name(name, invoke_type);
let enum_name = generate_enum_name(struct_ident, name);
let mut invoke_sig = Signature {
ident: invoke_name,
output: ReturnType::Default,
..base_method.sig.clone()
};
let mut is_method = false;
let param_ids = invoke_sig
.inputs
.iter()
.cloned()
.enumerate()
.filter_map(|(index, fnarg)| match fnarg {
FnArg::Receiver(receiver) => {
if receiver.reference.is_some() {
is_method = true;
} else {
panic!("invoke_impl cannot be used with methods taking self as move!");
}
None
}
Typed(pattype) => Some((index, pattype)),
})
.filter_map(|(index, pat)| match *pat.pat {
Pat::Ident(patident) => Some({
let id = patident.ident;
if let Some(hs) = clone {
if hs.contains(&index) {
Expr::MethodCall(syn::parse(quote!(#id.clone()).into()).unwrap())
} else {
Expr::Path(syn::parse(quote!(#id).into()).unwrap())
}
} else {
Expr::Path(syn::parse(quote!(#id).into()).unwrap())
}
}),
_ => None,
})
.collect::<Vec<_>>();
let generic_params = invoke_sig
.generics
.params
.iter()
.cloned()
.filter_map(|gp| match gp {
GenericParam::Type(tp) => Some(tp.ident),
_ => None,
})
.collect::<Vec<_>>();
let closure_ident = Ident::new("consumer", Span::call_site());
if output_type != generate_trailing_return_type() && output_type != ReturnType::Default {
let arg = if let ReturnType::Type(_, bx) = output_type.clone() {
let bxtype = *bx;
match invoke_type {
InvokeType::Specified(st) | InvokeType::SpecifiedAll(st) => match st {
SpecificationType::Enum => syn::parse(
quote!(mut #closure_ident: impl FnMut(#enum_name, #bxtype)).into(),
)
.unwrap(),
SpecificationType::Enumerated => {
syn::parse(quote!(mut #closure_ident: impl FnMut(usize, #bxtype)).into())
.unwrap()
}
},
InvokeType::All | InvokeType::Subset => {
syn::parse(quote!(mut #closure_ident: impl FnMut(#bxtype)).into()).unwrap()
}
}
} else {
panic!("Shouldn't detect an empty return after the if statement!")
};
invoke_sig.inputs.push(arg);
} else {
let arg = match invoke_type {
InvokeType::Specified(st) | InvokeType::SpecifiedAll(st) => match st {
SpecificationType::Enum => Some(
syn::parse(quote!(mut #closure_ident: impl FnMut(#enum_name)).into()).unwrap(),
),
SpecificationType::Enumerated => {
Some(syn::parse(quote!(mut #closure_ident: impl FnMut(usize)).into()).unwrap())
}
},
InvokeType::Subset | InvokeType::All => None,
};
if let Some(fnarg) = arg {
invoke_sig.inputs.push(fnarg);
}
}
let specifier = match invoke_type {
InvokeType::Specified(st) => match st {
SpecificationType::Enum => Some(
syn::parse(quote!(mut invoke_impl_iter: impl Iterator<Item=#enum_name>).into())
.unwrap(),
),
SpecificationType::Enumerated => Some(
syn::parse(quote!(mut invoke_impl_iter: impl Iterator<Item=usize>).into()).unwrap(),
),
},
InvokeType::Subset => Some(
syn::parse(quote!(mut invoke_impl_iter: impl Iterator<Item=usize>).into()).unwrap(),
),
InvokeType::All | InvokeType::SpecifiedAll(_) => None,
};
if let Some(fnarg) = specifier {
invoke_sig.inputs.push(fnarg);
}
let invoke_block = match invoke_type {
InvokeType::Specified(st) => invoke_enum_block(
is_method,
st,
&output_type,
methods,
&closure_ident,
&struct_ident,
name,
&generic_params,
¶m_ids,
),
InvokeType::SpecifiedAll(st) => invoke_all_enum_block(
is_method,
st,
&output_type,
methods,
&closure_ident,
&struct_ident,
name,
&generic_params,
¶m_ids,
),
InvokeType::Subset => invoke_some_block(
is_method,
&output_type,
methods,
&closure_ident,
&struct_ident,
&generic_params,
¶m_ids,
),
InvokeType::All => invoke_all_block(
is_method,
&output_type,
methods,
&closure_ident,
&struct_ident,
&generic_params,
¶m_ids,
),
};
ImplItem::Method(ImplItemMethod {
sig: invoke_sig,
block: invoke_block,
..base_method.clone()
})
}
fn invoke_all_block(
is_method: bool,
output_type: &ReturnType,
methods: &Vec<&ImplItemMethod>,
closure_ident: &Ident,
struct_ident: &Ident,
generic_params: &Vec<Ident>,
param_ids: &Vec<Expr>,
) -> Block {
let mut invoke_block = Block {
brace_token: Default::default(),
stmts: vec![],
};
for &method in methods {
let inner_call =
get_inner_call_expr(is_method, method, struct_ident, generic_params, param_ids);
if output_type != &generate_trailing_return_type() && output_type != &ReturnType::Default {
let outer_call: ExprCall =
syn::parse(quote!(#closure_ident(#inner_call)).into()).unwrap();
invoke_block
.stmts
.push(Stmt::Semi(Expr::Call(outer_call), Default::default()));
} else {
invoke_block
.stmts
.push(Stmt::Semi(inner_call, Default::default()));
}
}
invoke_block
}
/// Generates the body of `invoke_subset`.
///
/// The body loops over `invoke_impl_iter` (which yields `usize` method
/// indices) and `match`es each index to the method at that position. A value
/// result is handed to the consumer closure. An index with no matching method
/// panics at runtime with "Iter contains invalid function index!".
fn invoke_some_block(
    is_method: bool,
    output_type: &ReturnType,
    methods: &Vec<&ImplItemMethod>,
    closure_ident: &Ident,
    struct_ident: &Ident,
    generic_params: &Vec<Ident>,
    param_ids: &Vec<Expr>,
) -> Block {
    let forwards_result =
        output_type != &generate_trailing_return_type() && output_type != &ReturnType::Default;
    let arms = methods
        .iter()
        .enumerate()
        .map(|(index, &method)| {
            let call =
                get_inner_call_expr(is_method, method, struct_ident, generic_params, param_ids);
            let arm_body: Expr = if forwards_result {
                syn::parse(quote!(#closure_ident(#call)).into()).unwrap()
            } else {
                call
            };
            quote!(#index => #arm_body,)
        })
        .collect::<Vec<_>>();
    let for_loop: ExprForLoop = syn::parse(
        quote!(for invoke_impl_i in invoke_impl_iter {
            match invoke_impl_i {
                #(#arms)*
                _ => panic!("Iter contains invalid function index!")
            }
        })
        .into(),
    )
    .unwrap();
    Block {
        brace_token: Default::default(),
        stmts: vec![Stmt::Expr(Expr::ForLoop(for_loop))],
    }
}
fn invoke_all_enum_block(
is_method: bool,
specification_type: SpecificationType,
output_type: &ReturnType,
methods: &Vec<&ImplItemMethod>,
closure_ident: &Ident,
struct_ident: &Ident,
name: &Option<String>,
generic_params: &Vec<Ident>,
param_ids: &Vec<Expr>,
) -> Block {
let mut invoke_block = Block {
brace_token: Default::default(),
stmts: vec![],
};
let enum_name = generate_enum_name(struct_ident, name);
let identifiers = methods
.into_iter()
.map(|im| im.sig.ident.clone())
.collect::<Vec<_>>();
for (index, (enum_ident, &method)) in
identifiers.into_iter().zip(methods.into_iter()).enumerate()
{
let inner_call =
get_inner_call_expr(is_method, method, struct_ident, generic_params, param_ids);
let outer_call = if output_type != &generate_trailing_return_type()
&& output_type != &ReturnType::Default
{
match specification_type {
SpecificationType::Enum => {
syn::parse(quote!(#closure_ident(#enum_name::#enum_ident, #inner_call)).into())
.unwrap()
}
SpecificationType::Enumerated => {
syn::parse(quote!(#closure_ident(#index, #inner_call)).into()).unwrap()
}
}
} else {
invoke_block
.stmts
.push(Stmt::Semi(inner_call, Default::default()));
match specification_type {
SpecificationType::Enum => {
syn::parse(quote!(#closure_ident(#enum_name::#enum_ident)).into()).unwrap()
}
SpecificationType::Enumerated => {
syn::parse(quote!(#closure_ident(#index)).into()).unwrap()
}
}
};
invoke_block
.stmts
.push(Stmt::Semi(Expr::Call(outer_call), Default::default()));
}
invoke_block
}
fn invoke_enum_block(
is_method: bool,
specification_type: SpecificationType,
output_type: &ReturnType,
methods: &Vec<&ImplItemMethod>,
closure_ident: &Ident,
struct_ident: &Ident,
name: &Option<String>,
generic_params: &Vec<Ident>,
param_ids: &Vec<Expr>,
) -> Block {
let mut invoke_block = Block {
brace_token: Default::default(),
stmts: vec![],
};
let enum_name = generate_enum_name(struct_ident, name);
let identifiers = methods
.into_iter()
.map(|im| im.sig.ident.clone())
.collect::<Vec<_>>();
let mut match_statement: ExprMatch = syn::parse(quote!(match invoke_impl_i {}).into()).unwrap();
for (index, (enum_ident, &method)) in
identifiers.into_iter().zip(methods.into_iter()).enumerate()
{
let inner_call =
get_inner_call_expr(is_method, method, struct_ident, generic_params, param_ids);
let outer_call: Expr = if output_type != &generate_trailing_return_type()
&& output_type != &ReturnType::Default
{
match specification_type {
SpecificationType::Enum => syn::parse(
quote!({#closure_ident(#enum_name::#enum_ident, #inner_call);}).into(),
)
.unwrap(),
SpecificationType::Enumerated => {
syn::parse(quote!({#closure_ident(#index, #inner_call);}).into()).unwrap()
}
}
} else {
match specification_type {
SpecificationType::Enum => syn::parse(
quote!({
#inner_call;
#closure_ident(#enum_name::#enum_ident);
})
.into(),
)
.unwrap(),
SpecificationType::Enumerated => syn::parse(
quote!({
#inner_call;
#closure_ident(#index);
})
.into(),
)
.unwrap(),
}
};
match specification_type {
SpecificationType::Enum => {
match_statement.arms.push(
syn::parse(quote!(#enum_name::#enum_ident => #outer_call,).into()).unwrap(),
);
}
SpecificationType::Enumerated => {
match_statement
.arms
.push(syn::parse(quote!(#index => #outer_call,).into()).unwrap());
}
}
}
match specification_type {
SpecificationType::Enum => {}
SpecificationType::Enumerated => {
match_statement.arms.push(
syn::parse(quote!(_ => panic!("Iter contains invalid function index!")).into())
.unwrap(),
);
}
}
let loopexpr: ExprForLoop = syn::parse(
quote!(for invoke_impl_i in invoke_impl_iter {
#match_statement
})
.into(),
)
.unwrap();
invoke_block.stmts.push(Stmt::Expr(Expr::ForLoop(loopexpr)));
invoke_block
}
/// Builds the expression that forwards the shared arguments to one wrapped method.
///
/// * With a reference receiver (`is_method`), emits `self.method::<G..>(args..)`.
/// * Otherwise emits `Self::method::<G..>(args..)`. Going through `Self` keeps
///   the call valid for generic self types (`impl<T> Foo<T>`) and for impls
///   written against a qualified path (`impl crate::a::Foo`), where a bare
///   struct identifier may not resolve at the expansion site.
///
/// `generic_params` are the parameter names forwarded explicitly via
/// turbofish. `param_ids` are the argument expressions to pass, in order.
fn get_inner_call_expr(
    is_method: bool,
    method: &ImplItemMethod,
    struct_ident: &Ident,
    generic_params: &Vec<Ident>,
    param_ids: &Vec<Expr>,
) -> Expr {
    // Kept for signature compatibility with existing callers; the call path
    // uses `Self`, which always names the impl's self type.
    let _ = struct_ident;
    let method_name = &method.sig.ident;
    if is_method {
        Expr::MethodCall(
            syn::parse(quote!(self.#method_name::<#(#generic_params),*>(#(#param_ids),*)).into())
                .unwrap(),
        )
    } else {
        Expr::Call(
            syn::parse(
                quote!(Self::#method_name::<#(#generic_params),*>(#(#param_ids),*)).into(),
            )
            .unwrap(),
        )
    }
}
/// Emits the companion enum for the impl block, plus its helpers.
///
/// The enum has one variant per method, named exactly like the method, and
/// derives `Debug, Clone, Copy`. The helpers are:
/// * `iter()`, yielding `&'static` references to every variant in
///   declaration order;
/// * `TryFrom<&str>`, mapping a method name to its variant (with a
///   `&'static str` error for unknown names);
/// * `From<Enum> for &str`, mapping a variant back to the method name.
///
/// Standard-library items are referenced through `::core` so the output
/// compiles in 2018-edition crates (where `TryFrom` is not in the prelude)
/// and is unaffected by user-defined `Result`/`Ok`/`Err` shadowing.
fn create_enum(methods: &Vec<&ImplItemMethod>, struct_ident: &Ident, name: &Option<String>) -> TokenStream {
    let identifiers = methods
        .iter()
        .map(|im| im.sig.ident.clone())
        .collect::<Vec<_>>();
    let names = identifiers
        .iter()
        .map(|i| i.to_string())
        .collect::<Vec<_>>();
    let num_members = identifiers.len();
    let enum_name = generate_enum_name(struct_ident, name);
    // Fully-qualified variant paths. Nested items such as the `static` below
    // cannot use `Self`, and explicit paths avoid relying on a glob import.
    let variant_paths = identifiers
        .iter()
        .map(|ident| quote!(#enum_name::#ident))
        .collect::<Vec<_>>();
    let enum_declaration: ItemEnum = syn::parse(
        quote!(
            #[allow(non_camel_case_types)]
            #[derive(Debug, Clone, Copy)]
            pub enum #enum_name {
                #(#identifiers),*
            })
        .into(),
    )
    .unwrap();
    let enum_impl: ItemImpl = syn::parse(
        quote!(impl #enum_name {
            pub fn iter() -> impl ::core::iter::Iterator<Item = &'static #enum_name> {
                static MEMBERS: [#enum_name; #num_members] = [#(#variant_paths),*];
                MEMBERS.iter()
            }
        })
        .into(),
    )
    .unwrap();
    let try_from_str: ItemImpl = syn::parse(
        quote!(
            impl ::core::convert::TryFrom<&str> for #enum_name {
                type Error = &'static str;
                fn try_from(value: &str) -> ::core::result::Result<Self, Self::Error> {
                    match value {
                        #(#names => ::core::result::Result::Ok(Self::#identifiers),)*
                        _ => ::core::result::Result::Err("Input str does not match any enums in Self!")
                    }
                }
            }
        )
        .into(),
    )
    .unwrap();
    let from_num: ItemImpl = syn::parse(
        quote!(
            impl ::core::convert::From<#enum_name> for &str {
                fn from(en: #enum_name) -> Self {
                    match en {
                        #(#variant_paths => #names,)*
                    }
                }
            }
        )
        .into(),
    )
    .unwrap();
    let mut enum_tokenstream: TokenStream = enum_declaration.into_token_stream().into();
    enum_tokenstream.extend::<TokenStream>(enum_impl.into_token_stream().into());
    enum_tokenstream.extend::<TokenStream>(try_from_str.into_token_stream().into());
    enum_tokenstream.extend::<TokenStream>(from_num.into_token_stream().into());
    enum_tokenstream
}
/// Checks that every method in `methods` has the same shape as `base_method`.
///
/// Each method is normalised before comparison: its name is replaced by a
/// placeholder, its attributes are cleared, and its body is emptied. All
/// remaining parts must therefore be equal: visibility, receiver, parameter
/// patterns and types, generics, and return type.
///
/// # Panics
/// On the first method that differs, printing both normalised methods.
fn validate_signatures(base_method: &ImplItemMethod, methods: &Vec<&ImplItemMethod>) {
    let normalize = |m: &ImplItemMethod| ImplItemMethod {
        attrs: vec![],
        sig: Signature {
            ident: Ident::new("name", Span::call_site()),
            ..m.sig.clone()
        },
        block: Block {
            brace_token: Default::default(),
            stmts: vec![],
        },
        ..m.clone()
    };
    let reference = normalize(base_method);
    let first_mismatch = methods
        .iter()
        .map(|&m| normalize(m))
        .find(|candidate| candidate != &reference);
    if let Some(mismatch) = first_mismatch {
        panic!(
            "ImplItemMethods different! \
             Base Method: {:?} \
             Method: {:?}",
            reference.to_token_stream().to_string(),
            mismatch.to_token_stream().to_string()
        );
    }
}
/// Returns the identifier naming the type the impl block is for.
///
/// For a path self type the *last* path segment is used, so
/// `impl crate::model::Foo` and `impl Foo<T>` both yield `Foo`.
///
/// # Errors
/// Returns `Err("No struct name detected!")` when the self type is not a
/// path (e.g. a reference or tuple type) or the path has no segments.
fn get_struct_identifier_as_path(input: &ItemImpl) -> Result<Ident, &str> {
    const NO_STRUCT: &str = "No struct name detected!";
    match &*input.self_ty {
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .map(|segment| segment.ident.clone())
            .ok_or(NO_STRUCT),
        _ => Err(NO_STRUCT),
    }
}
/// Parses the attribute arguments of `#[invoke_impl(...)]`.
///
/// Accepted form: up to two `;`-separated lists, `name("suffix")` and
/// `clone(i, j, ...)`, in either order. Keys are matched case-insensitively.
/// Returns `(name, clone_indices)`; each is `None` when its argument is absent.
///
/// # Panics
/// On unparsable input, more than two arguments, an unknown or repeated key,
/// a `name` that is not exactly one string literal, or a `clone` entry that
/// is not an integer literal fitting in `usize`.
fn parse_args(args: TokenStream) -> (Option<String>, Option<HashSet<usize>>) {
    let punctuated_args = Punctuated::<MetaList, syn::Token![;]>::parse_terminated
        .parse(args)
        .unwrap_or_else(|err| panic!("Could not parse invoke_impl arguments: {}", err));
    if punctuated_args.len() > 2 {
        panic!(
            "invoke_impl currently only supports args name and clone in the format \
             #[invoke_impl(name(\"name\"); clone(2, 3, 4))], and more than two args were passed in!"
        );
    }
    let mut name: Option<String> = None;
    let mut clones: Option<HashSet<usize>> = None;
    for arg in &punctuated_args {
        // A multi-segment path such as `a::b` can never be a valid key.
        let key = match arg.path.get_ident() {
            Some(ident) => ident.to_string().to_lowercase(),
            None => panic!("The only valid arguments to invoke_impl are name and clone!"),
        };
        match key.as_str() {
            "name" => {
                if name.is_some() {
                    panic!("Argument name passed to invoke_impl twice!")
                }
                let mut nested = arg.nested.iter();
                name = match (nested.next(), nested.next()) {
                    (Some(NestedMeta::Lit(Lit::Str(litstr))), None) => Some(litstr.value()),
                    _ => panic!("There can only be a single literal str argument to name!"),
                };
            }
            "clone" => {
                if clones.is_some() {
                    panic!("Argument clone passed to invoke_impl twice!")
                }
                let indices = arg
                    .nested
                    .iter()
                    .map(|nested| match nested {
                        NestedMeta::Lit(Lit::Int(litint)) => {
                            litint.base10_parse::<usize>().unwrap_or_else(|err| {
                                panic!("Arguments to clone must be literal ints! ({})", err)
                            })
                        }
                        _ => panic!("Arguments to clone must be literal ints!"),
                    })
                    .collect::<HashSet<_>>();
                clones = Some(indices);
            }
            _ => panic!("The only valid arguments to invoke_impl are name and clone!"),
        }
    }
    (name, clones)
}
/// Returns the identifier of the generated dispatch method for `invoke_type`.
///
/// The identifier is one of `invoke_all`, `invoke_subset`, `invoke_all_enum`,
/// `invoke_all_enumerated`, `invoke_enum` or `invoke_enumerated`. When the
/// user supplied a name, `_<name>` is appended.
fn generate_invoke_name(name: &Option<String>, invoke_type: InvokeType) -> Ident {
    let stem = match invoke_type {
        InvokeType::All => "invoke_all",
        InvokeType::Subset => "invoke_subset",
        InvokeType::SpecifiedAll(SpecificationType::Enum) => "invoke_all_enum",
        InvokeType::SpecifiedAll(SpecificationType::Enumerated) => "invoke_all_enumerated",
        InvokeType::Specified(SpecificationType::Enum) => "invoke_enum",
        InvokeType::Specified(SpecificationType::Enumerated) => "invoke_enumerated",
    };
    match name {
        Some(suffix) => format_ident!("{}_{}", stem, suffix),
        None => format_ident!("{}", stem),
    }
}
/// Returns the companion enum's identifier: `<Struct>_invoke_impl_enum`,
/// with `_<name>` appended when the user supplied a name.
fn generate_enum_name(struct_ident: &Ident, name: &Option<String>) -> Ident {
    match name {
        None => format_ident!("{}_invoke_impl_enum", struct_ident),
        Some(suffix) => format_ident!("{}_invoke_impl_enum_{}", struct_ident, suffix),
    }
}
/// Returns the explicit unit return type `-> ()`.
///
/// Callers compare against it so that `-> ()` is handled the same as having
/// no return type at all.
fn generate_trailing_return_type() -> ReturnType {
    syn::parse::<ReturnType>(quote! { -> () }.into()).unwrap()
}