use heck::{CamelCase, SnakeCase};
use proc_macro::TokenStream;
use proc_macro2::{Literal, Span, TokenStream as TokenStream2};
use quote::quote;
use std::collections::HashSet;
use syn::*;
use super::utils::*;
/// Implements `#[derive(Call)]`.
///
/// Parses the annotated type, reads the file returned by `parse_parent()`,
/// and emits a `Call` enum (via `create_call_enum`) plus the matching
/// `::orga::call::Call` impl (via `create_call_impl`). Both are wrapped in a
/// `pub mod` named after the type with a `_call` suffix, snake-cased
/// (e.g. `FooBar` -> `foo_bar_call`).
pub fn derive(item: TokenStream) -> TokenStream {
    let input = parse_macro_input!(item as DeriveInput);
    let parent_file = parse_parent();

    let module_name = format!("{}_call", input.ident).to_snake_case();
    let module_ident = Ident::new(&module_name, Span::call_site());

    let (enum_tokens, enum_item) = create_call_enum(&input, &parent_file);
    // Only the token stream is needed here; the parsed `ItemImpl` is dropped.
    let impl_tokens = create_call_impl(&input, &parent_file, &enum_item).0;

    let expanded = quote! {
        use ::orga::macros::*;
        pub mod #module_ident {
            use super::*;
            #enum_tokens
            #impl_tokens
        }
    };
    expanded.into()
}
pub fn attr(_args: TokenStream, input: TokenStream) -> TokenStream {
let method = parse_macro_input!(input as ImplItemMethod);
if !matches!(method.vis, Visibility::Public(_)) {
panic!("Call methods must be public");
}
if method.sig.unsafety.is_some() {
panic!("Call methods cannot be unsafe");
}
if method.sig.asyncness.is_some() {
panic!("Call methods cannot be async");
}
if method.sig.abi.is_some() {
panic!("Call methods cannot specify ABI");
}
quote!(#method).into()
}
/// Generates the `::orga::call::Call` impl for the derived type, plus one
/// `MaybeCall…` helper trait (with a specialized impl) per method returned by
/// `relevant_methods(name, "call", source)`.
///
/// * `item` - the type the derive is applied to.
/// * `source` - the parsed file searched for call methods.
/// * `call_enum` - the enum produced by `create_call_enum`; its ident and
///   generics become the impl's `type Call`.
///
/// Returns the full tokens (impl + helper traits) and the `Call` impl parsed
/// on its own.
///
/// # Panics
///
/// Panics on unions, on enums (`todo!`), and if a call method has a
/// non-typed argument after its receiver.
pub(super) fn create_call_impl(
    item: &DeriveInput,
    source: &File,
    call_enum: &ItemEnum,
) -> (TokenStream2, ItemImpl) {
    let name = &item.ident;
    let generics = &item.generics;

    // `impl<...>` headers may not repeat parameter defaults
    // (`T = u32`, `const N: usize = 1`), so strip them.
    let mut generics_sanitized = generics.clone();
    for param in generics_sanitized.params.iter_mut() {
        match param {
            GenericParam::Type(t) => t.default = None,
            GenericParam::Const(c) => c.default = None,
            GenericParam::Lifetime(_) => {}
        }
    }
    let generic_params = gen_param_input(generics, true);
    let call_type = &call_enum.ident;
    let call_generics = &call_enum.generics;

    // Render each user-written where-predicate followed by its own comma.
    // Interpolating the raw `Punctuated` omits the final comma when the user
    // wrote none, which would fuse with the generated bounds appended after
    // it (`where T: A U: B, ...`) and fail to parse.
    let where_preds = {
        let preds = item
            .generics
            .where_clause
            .iter()
            .flat_map(|w| w.predicates.iter());
        quote!(#(#preds,)*)
    };

    // Generic params used in call-method argument types are carried inside
    // the `Call` enum, so they must be encodable.
    let encoding_bounds = relevant_methods(name, "call", source)
        .into_iter()
        .flat_map(|(method, _)| {
            let inputs: Vec<_> = method
                .sig
                .inputs
                .iter()
                .skip(1)
                .map(|input| match input {
                    FnArg::Typed(input) => *input.ty.clone(),
                    _ => panic!("unexpected input"),
                })
                .collect();
            get_generic_requirements(
                inputs.iter().cloned(),
                item.generics.params.iter().cloned(),
            )
        })
        .map(|p| quote!(#p: ::orga::encoding::Encode + ::orga::encoding::Decode + ::orga::encoding::Terminated,));
    let encoding_bounds = quote!(#(#encoding_bounds)*);

    // Generic params used in call-method return types must implement `Call`,
    // because the method output is itself dispatched with the subcall bytes.
    let call_bounds = relevant_methods(name, "call", source)
        .into_iter()
        .map(|(method, _)| {
            let unit_tuple: Type = parse2(quote!(())).unwrap();
            match method.sig.output {
                ReturnType::Type(_, ref ty) => *(ty.clone()),
                ReturnType::Default => unit_tuple,
            }
        })
        .flat_map(|ty| {
            get_generic_requirements(
                std::iter::once(&ty).cloned(),
                generics.params.iter().cloned(),
            )
        })
        .map(|t| quote!(#t: ::orga::call::Call,));
    let call_bounds = quote!(#(#call_bounds)*);

    let fields = match &item.data {
        Data::Struct(data) => data.fields.iter(),
        Data::Enum(_) => todo!("Enums are not supported yet"),
        Data::Union(_) => panic!("Unions are not supported"),
    };

    // Pair each field with its declaration index *before* filtering. Tuple
    // fields must be accessed by their real position (`self.1`). The variant
    // name uses the position among `pub` fields, which is how
    // `create_call_enum` names the `Field…` variants.
    let field_call_arms: Vec<_> = fields
        .enumerate()
        .filter(|(_, field)| matches!(field.vis, Visibility::Public(_)))
        .enumerate()
        .map(|(variant_index, (field_index, field))| {
            let (variant_name, field_access) = match &field.ident {
                Some(ident) => (
                    Ident::new(
                        &format!("Field{}", ident.to_string().to_camel_case()),
                        Span::call_site(),
                    ),
                    quote!(#ident),
                ),
                None => {
                    let index = Literal::usize_unsuffixed(field_index);
                    (
                        Ident::new(&format!("Field{}", variant_index), Span::call_site()),
                        quote!(#index),
                    )
                }
            };
            quote! {
                Call::#variant_name(subcall) => {
                    ::orga::call::maybe_call(&mut self.#field_access, subcall)
                }
            }
        })
        .collect();

    let mut maybe_call_defs = vec![];
    let method_call_arms: Vec<_> = relevant_methods(name, "call", source)
        .into_iter()
        .map(|(method, parent)| {
            let method_name = &method.sig.ident;
            let name_camel = method_name.to_string().to_camel_case();
            let variant_name =
                Ident::new(&format!("Method{}", &name_camel), Span::call_site());
            // One binding (`var1`, `var2`, ...) per non-receiver argument.
            let inputs: Vec<_> = (1..method.sig.inputs.len())
                .map(|i| Ident::new(&format!("var{}", i), Span::call_site()))
                .collect();
            let input_types: Vec<_> = method
                .sig
                .inputs
                .iter()
                .skip(1)
                .filter_map(|input| match input {
                    FnArg::Typed(input) => Some(*input.ty.clone()),
                    _ => None,
                })
                .collect();
            let full_inputs = quote! {
                #(, #inputs: #input_types)*, subcall: Vec<u8>
            };
            let unit_tuple: Type = parse2(quote!(())).unwrap();
            let output_type = match method.sig.output {
                ReturnType::Type(_, ref ty) => *(ty.clone()),
                ReturnType::Default => unit_tuple,
            };
            let requirements = get_generic_requirements(
                input_types
                    .iter()
                    .chain(std::iter::once(&output_type))
                    .cloned(),
                generics.params.iter().cloned(),
            );
            let generic_reqs = if requirements.is_empty() {
                quote!()
            } else {
                quote!(<#(#requirements),*>)
            };
            let parent_generics = &parent.generics;
            // Same comma-per-predicate rendering as `where_preds` above.
            let parent_where_preds = {
                let preds = parent
                    .generics
                    .where_clause
                    .iter()
                    .flat_map(|w| w.predicates.iter());
                quote!(#(#preds,)*)
            };
            let trait_name = Ident::new(
                &format!("MaybeCall{}", &variant_name),
                Span::call_site(),
            );
            // The blanket `default fn` returns an error. The specialized impl
            // for the derived type (when all bounds hold) actually invokes the
            // method and forwards `subcall` to its output.
            maybe_call_defs.push(quote! {
                trait #trait_name#generic_reqs {
                    fn maybe_call(&mut self #full_inputs) -> ::orga::Result<()>;
                }
                impl<__Self, #(#requirements),*> #trait_name#generic_reqs for __Self {
                    default fn maybe_call(&mut self #full_inputs) -> ::orga::Result<()> {
                        Err(::orga::Error::Call("This call cannot be called because not all bounds are met".into()))
                    }
                }
                impl#parent_generics #trait_name#generic_reqs for #name#generic_params
                where #where_preds #encoding_bounds #call_bounds #parent_where_preds
                {
                    fn maybe_call(&mut self #full_inputs) -> ::orga::Result<()> {
                        let output = self.#method_name(#(#inputs),*);
                        ::orga::call::maybe_call(output, subcall)
                    }
                }
            });
            // Turbofish form for naming the trait in expression position.
            let dotted_generic_reqs = if generic_reqs.is_empty() {
                quote!()
            } else {
                quote!(::#generic_reqs)
            };
            quote! {
                Call::#variant_name(#(#inputs,)* subcall) => {
                    #trait_name#dotted_generic_reqs::maybe_call(self, #(#inputs,)* subcall)
                }
            }
        })
        .collect();

    let impl_output = quote! {
        impl#generics_sanitized ::orga::call::Call for #name#generic_params
        where #where_preds #encoding_bounds
        {
            type Call = #call_type#call_generics;
            fn call(&mut self, call: Self::Call) -> ::orga::Result<()> {
                match call {
                    #call_type::Noop => Ok(()),
                    #(#field_call_arms),*
                    #(#method_call_arms),*
                }
            }
        }
    };
    let output = quote! {
        #impl_output
        #(#maybe_call_defs)*
    };
    (
        output,
        syn::parse2(impl_output).expect("generated `Call` impl should be valid Rust"),
    )
}
pub(super) fn create_call_enum(item: &DeriveInput, source: &File) -> (TokenStream2, ItemEnum) {
let name = &item.ident;
let generics = &item.generics;
let mut generic_params = vec![];
let fields = match &item.data {
Data::Struct(data) => data.fields.iter(),
Data::Enum(_) => todo!("Enums are not supported yet"),
Data::Union(_) => panic!("Unions are not supported"),
};
let field_variants: Vec<_> = fields
.filter(|field| matches!(field.vis, Visibility::Public(_)))
.enumerate()
.map(|(i, field)| {
let name = field.ident.as_ref().map_or(
Ident::new(format!("Field{}", i).as_str(), Span::call_site()),
|f| {
Ident::new(
format!("Field{}", f.to_string().to_camel_case()).as_str(),
Span::call_site(),
)
},
);
let requirements = get_generic_requirements(
vec![field.ty.clone()].into_iter(),
generics.params.iter().cloned(),
);
generic_params.extend(requirements.clone());
quote!(#name(Vec<u8>))
})
.collect();
let method_variants: Vec<_> = relevant_methods(name, "call", source)
.into_iter()
.map(|(method, _)| {
let name_camel = method.sig.ident.to_string().to_camel_case();
let name = Ident::new(format!("Method{}", &name_camel).as_str(), Span::call_site());
let fields = if method.sig.inputs.len() == 1 {
quote!()
} else {
let inputs: Vec<_> = method
.sig
.inputs
.iter()
.skip(1)
.map(|input| match input {
FnArg::Typed(input) => *input.ty.clone(),
_ => panic!("unexpected input"),
})
.collect();
let requirements = get_generic_requirements(
inputs.iter().cloned(),
generics.params.iter().cloned(),
);
generic_params.extend(requirements);
quote! { #(#inputs),*, }
};
quote! {
#name(#fields Vec<u8>)
}
})
.collect();
let generic_params = if generic_params.is_empty() {
quote!()
} else {
let params: HashSet<_> = generic_params.into_iter().collect();
let params = params.into_iter();
quote!(<#(#params),*>)
};
let struct_output = quote! {
#[derive(::orga::encoding::Encode, ::orga::encoding::Decode)]
pub enum Call#generic_params {
Noop,
#(#field_variants,)*
#(#method_variants,)*
}
};
let output = quote! {
#struct_output
impl#generic_params Default for Call#generic_params {
fn default() -> Self {
Call::Noop
}
}
};
(output, syn::parse2(struct_output).unwrap())
}
/// Renders the *names* of `generics`' parameters, without bounds or defaults,
/// as used when naming a type (`Foo<'a, T, N>`).
///
/// With `bracketed`, the list is wrapped in `<...>`; otherwise it is a bare
/// comma-separated list. An empty parameter list renders as nothing either way.
fn gen_param_input(generics: &Generics, bracketed: bool) -> TokenStream2 {
    if generics.params.is_empty() {
        return quote!();
    }
    let names = generics.params.iter().map(|param| match param {
        GenericParam::Type(p) => {
            let ident = &p.ident;
            quote!(#ident)
        }
        // Interpolate the whole `Lifetime`: its `ident` field lacks the
        // leading apostrophe, so it would render `'a` as `a`.
        GenericParam::Lifetime(p) => {
            let lifetime = &p.lifetime;
            quote!(#lifetime)
        }
        GenericParam::Const(p) => {
            let ident = &p.ident;
            quote!(#ident)
        }
    });
    if bracketed {
        quote!(<#(#names),*>)
    } else {
        quote!(#(#names),*)
    }
}