extern crate proc_macro;
use inflector::Inflector;
use proc_macro::TokenStream as RawTokenStream;
use proc_macro2::{Ident, Span};
use quote::{format_ident, quote, ToTokens};
use syn::{
parse_macro_input, spanned::Spanned, AngleBracketedGenericArguments, Attribute, Lifetime, Meta,
MetaList, MetaNameValue, NestedMeta, PatType, Path, ReturnType, Signature, TraitItem,
TraitItemMethod, Type, TypePath,
};
/// How the generated async method gets hold of the value whose poll method it drives.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum AsyncMethodType {
    /// The async method takes the same borrowed receiver as the poll method and
    /// the future stores that borrow. Used when `owned` is not specified.
    Ref,
    /// `#[async_method(owned)]`: the async method consumes `self` and the
    /// future owns it.
    Owned,
}
/// Settings collected from a trait method's `#[async_method(...)]` attribute.
#[derive(Clone, Debug)]
struct MethodMeta {
    /// Whether the generated future borrows or owns the receiver.
    ty: AsyncMethodType,
    /// Explicit `future_name = "..."`. When absent, the name is derived from
    /// the trait name and the poll method name.
    future_name: Option<String>,
    /// Explicit `method_name = "..."`. When absent, the `poll_` prefix is
    /// stripped from the poll method name.
    async_method_name: Option<String>,
}
/// Receiver form of the user's poll method.
#[derive(Clone, Copy, Debug)]
enum PollMethodReceiverType {
    /// `&self`
    Ref,
    /// `&mut self`
    MutRef,
    /// `self: Pin<&mut Self>`
    Pinned,
}
fn extract_output_type(ret: &ReturnType) -> Result<&Type, RawTokenStream> {
match *ret {
syn::ReturnType::Type(_, ref ty) => match **ty {
syn::Type::Path(ref path) => {
let tail_segment = path.path.segments.last().unwrap();
if tail_segment.ident.to_string() != "Poll" {
return Err(syn::Error::new(
ret.span(),
"polling method must return a Poll value",
)
.to_compile_error()
.into());
}
let args = &tail_segment.arguments;
match *args {
syn::PathArguments::AngleBracketed(AngleBracketedGenericArguments {
args: ref generics,
..
}) if generics.len() != 1 => Err(syn::Error::new(
args.span(),
"Poll return type should have exactly 1 generic parameter",
)
.to_compile_error()
.into()),
syn::PathArguments::AngleBracketed(AngleBracketedGenericArguments {
args: ref generics,
..
}) => match *generics.first().unwrap() {
syn::GenericArgument::Type(ref ty) => Ok(ty),
_ => Err(syn::Error::new(
args.span(),
"Error parsing generics of Poll type",
)
.to_compile_error()
.into()),
},
_ => Err(syn::Error::new(
ret.span(),
"Poll return type must include the <Output> type",
)
.to_compile_error()
.into()),
}
}
_ => Err(
syn::Error::new(ret.span(), "polling method must return a Poll value")
.to_compile_error()
.into(),
),
},
_ => Err(
syn::Error::new(ret.span(), "polling method must return a Poll value")
.to_compile_error()
.into(),
),
}
}
fn extract_poll_self_type(sig: &Signature) -> Option<PollMethodReceiverType> {
match *sig.inputs.first()? {
syn::FnArg::Receiver(ref recv) => {
if recv.reference.is_none() {
None
} else if recv.mutability.is_some() {
Some(PollMethodReceiverType::MutRef)
} else {
Some(PollMethodReceiverType::Ref)
}
}
syn::FnArg::Typed(PatType {
ref pat, ref ty, ..
}) => {
let pat_ident = match **pat {
syn::Pat::Ident(ref pat_ident) => pat_ident,
_ => return None,
};
if pat_ident.by_ref.is_some() || pat_ident.subpat.is_some() {
return None;
}
if pat_ident.ident != "self" {
return None;
}
let ty = match **ty {
Type::Path(TypePath {
qself: None,
path: Path { ref segments, .. },
}) => segments.last()?,
_ => return None,
};
if ty.ident != "Pin" {
return None;
}
let generics = match ty.arguments {
syn::PathArguments::AngleBracketed(ref generics) => &generics.args,
_ => return None,
};
if generics.len() != 1 {
return None;
}
let ty = match generics.first()? {
syn::GenericArgument::Type(Type::Reference(ty)) => ty,
_ => return None,
};
if ty.mutability.is_none() {
return None;
}
let self_ident = match *ty.elem {
Type::Path(TypePath {
qself: None,
ref path,
}) => path.get_ident()?,
_ => return None,
};
if self_ident != "Self" {
return None;
}
Some(PollMethodReceiverType::Pinned)
}
}
}
fn extract_meta<'a>(attrs: &'a mut Vec<Attribute>) -> Option<Result<MethodMeta, RawTokenStream>> {
for (index, attr) in attrs.iter_mut().enumerate() {
let meta = match attr.parse_meta() {
Ok(meta) => meta,
Err(..) => continue,
};
let (path, nested) = match meta {
syn::Meta::Path(path) => (path, None),
syn::Meta::List(MetaList { path, nested, .. }) => (path, Some(nested)),
_ => continue,
};
match path.get_ident() {
Some(ident) if ident == "async_method" => {}
_ => continue,
}
attrs.remove(index);
let mut result = MethodMeta {
ty: AsyncMethodType::Ref,
async_method_name: None,
future_name: None,
};
if let Some(meta_args) = nested {
for arg in meta_args.iter() {
match arg {
NestedMeta::Meta(Meta::NameValue(MetaNameValue {
path,
lit: syn::Lit::Str(name),
..
})) => {
let ident = match path.get_ident() {
Some(ident) => ident,
None => {
return Some(Err(syn::Error::new(
path.span(),
"Unrecognized meta argument",
)
.to_compile_error()
.into()))
}
};
if ident == "method_name" {
result.async_method_name = Some(name.value())
} else if ident == "future_name" {
result.future_name = Some(name.value())
} else {
return Some(Err(syn::Error::new(
path.span(),
"Unrecognized meta argument",
)
.to_compile_error()
.into()));
}
}
NestedMeta::Meta(Meta::Path(path)) => {
let ident = match path.get_ident() {
Some(ident) => ident,
None => {
return Some(Err(syn::Error::new(
path.span(),
"Unrecognized meta argument",
)
.to_compile_error()
.into()))
}
};
if ident == "owned" {
result.ty = AsyncMethodType::Owned;
} else {
return Some(Err(syn::Error::new(
path.span(),
"Unrecognized meta argument",
)
.to_compile_error()
.into()));
}
}
_ => {
return Some(Err(syn::Error::new(
arg.span(),
"Unrecognized meta argument",
)
.to_compile_error()
.into()))
}
}
}
}
return Some(Ok(result));
}
None
}
#[proc_macro_attribute]
pub fn async_poll_trait(_attr: RawTokenStream, item: RawTokenStream) -> RawTokenStream {
let mut parsed = parse_macro_input!(item as syn::ItemTrait);
let trait_ident = &parsed.ident;
let trait_name = trait_ident.to_string();
let vis = &parsed.vis;
let mut new_methods = Vec::new();
let mut new_structs = Vec::new();
for item in &mut parsed.items {
let method = match item {
TraitItem::Method(method) => method,
_ => continue,
};
let meta = match extract_meta(&mut method.attrs) {
None => continue,
Some(Err(err)) => return err,
Some(Ok(meta)) => meta,
};
let output_type = match extract_output_type(&method.sig.output) {
Ok(ty) => ty,
Err(err) => return err,
};
let receiver_type =
match extract_poll_self_type(&method.sig) {
Some(receiver_type) => receiver_type,
None => return syn::Error::new(
method.sig.span(),
"poll function must be a method that takes &self, &mut self, or Pin<&mut Self>",
)
.to_compile_error()
.into(),
};
let poll_method_ident = &method.sig.ident;
let poll_method_name = poll_method_ident.to_string();
let base_name = poll_method_name.strip_prefix("poll_");
let async_method_name = match meta.async_method_name.as_deref().or(base_name) {
Some(name) => name,
None => {
return syn::Error::new(
poll_method_ident.span(),
"poll method must start with poll_",
)
.to_compile_error()
.into()
}
};
let async_method_ident = Ident::new(
async_method_name,
Span::call_site().resolved_at(poll_method_ident.span()),
);
let future_name = match meta
.future_name
.or_else(|| base_name.map(|name| format!("{}{}", trait_name, name.to_class_case())))
{
Some(name) => name,
None => {
return syn::Error::new(
poll_method_ident.span(),
"poll method must start with poll_",
)
.to_compile_error()
.into()
}
};
let future_ident = Ident::new(
future_name.as_str(),
Span::call_site().resolved_at(trait_ident.span()),
);
let self_ident = format_ident!("self");
let cx_ident = format_ident!("cx");
let inner_ident = format_ident!("inner");
let generic_ident = format_ident!("T");
let generic_lt = Lifetime::new("'a", Span::call_site());
let (async_def, future_def) = match meta.ty {
AsyncMethodType::Owned => {
let async_method_definition = quote! {
fn #async_method_ident(self) -> #future_ident<Self>
where Self: Sized
{
#future_ident { #inner_ident: self }
}
};
let future_poll_definition = match receiver_type {
PollMethodReceiverType::MutRef => quote! {
unsafe { #self_ident.get_unchecked_mut() }.#inner_ident.#poll_method_ident(#cx_ident)
},
PollMethodReceiverType::Ref => quote! {
#self_ident.into_ref().get_ref().#inner_ident.#poll_method_ident(#cx_ident)
},
PollMethodReceiverType::Pinned => quote! {
unsafe { Pin::new_unchecked(&mut #self_ident.get_unchecked_mut().#inner_ident) }.#poll_method_ident(#cx_ident)
},
};
let future_definition = quote! {
#[derive(Debug)]
#vis struct #future_ident<T: #trait_ident> {
#inner_ident: T,
}
impl<T: #trait_ident> ::core::future::Future for #future_ident<T> {
type Output = #output_type;
fn poll(
#self_ident: ::core::pin::Pin<&mut Self>,
#cx_ident: &mut ::core::task::Context<'_>,
) -> ::core::task::Poll<Self::Output>
{
#future_poll_definition
}
}
};
(async_method_definition, future_definition)
}
AsyncMethodType::Ref => {
let async_method_receiver = match receiver_type {
PollMethodReceiverType::Ref => quote! { &#self_ident },
PollMethodReceiverType::MutRef => quote! { &mut #self_ident },
PollMethodReceiverType::Pinned => {
quote! { #self_ident: ::core::pin::Pin<&mut Self> }
}
};
let async_method_definition = quote! {
fn #async_method_ident(#async_method_receiver) -> #future_ident<Self> {
#future_ident { #inner_ident: #self_ident }
}
};
let future_inner_type = match receiver_type {
PollMethodReceiverType::Ref => quote! {& #generic_lt #generic_ident },
PollMethodReceiverType::MutRef => quote! { & #generic_lt mut #generic_ident },
PollMethodReceiverType::Pinned => {
quote! { Pin<& #generic_lt mut #generic_ident> }
}
};
let future_poll_definition = match receiver_type {
PollMethodReceiverType::Ref | PollMethodReceiverType::MutRef => quote! {
#self_ident.get_mut().#inner_ident.#poll_method_ident(#cx_ident)
},
PollMethodReceiverType::Pinned => quote! {
#self_ident.get_mut().#inner_ident.as_mut().#poll_method_ident(#cx_ident)
},
};
let future_definition = quote! {
#[derive(Debug)]
#vis struct #future_ident<#generic_lt, #generic_ident: #trait_ident + ?Sized> {
#inner_ident: #future_inner_type,
}
impl<'a, T: #trait_ident + ?Sized> ::core::future::Future for #future_ident<'a, T> {
type Output = #output_type;
fn poll(
#self_ident: ::core::pin::Pin<&mut Self>,
#cx_ident: &mut ::core::task::Context<'_>,
) -> ::core::task::Poll<Self::Output>
{
#future_poll_definition
}
}
};
(async_method_definition, future_definition)
}
};
let async_def = async_def.into();
let async_def = parse_macro_input!(async_def as TraitItemMethod);
new_methods.push(async_def);
new_structs.push(future_def);
}
parsed
.items
.extend(new_methods.into_iter().map(TraitItem::Method));
let mut output = parsed.into_token_stream();
output.extend(new_structs);
output.into()
}