extern crate proc_macro;
use std::mem;
use attributes::AssociatedConstant;
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::{quote, ToTokens};
use syn::parse::{Parse, ParseStream};
use syn::spanned::Spanned;
use syn::visit_mut::VisitMut;
use syn::{parse_quote, Error, Expr, ExprField, ExprMethodCall, FnArg, GenericParam, Meta};
use crate::attributes::{
combine_attributes, parse_method_attributes, parse_segment_attributes, ReturnExpression,
SegmentAttributes, TargetSpecifier,
};
mod attributes;
/// Custom keywords recognised by the `delegate!` parser.
mod kw {
    // `to`: introduces a delegation segment (`to <expr> { ... }`).
    syn::custom_keyword!(to);
    // `target`: the former segment keyword. It is only parsed so that a
    // deprecation error asking to replace it with `to` can be reported.
    syn::custom_keyword!(target);
}
/// Transformation applied to a delegated parameter before it is forwarded to
/// the delegation target. It is selected by an attribute on the parameter.
#[derive(Clone)]
enum ArgumentModifier {
    /// `#[into]`: the parameter is forwarded as `arg.into()`.
    Into,
    /// `#[as_ref]`: the parameter is forwarded as `arg.as_ref()`.
    AsRef,
    /// `#[newtype]`: the wrapped value `arg.0` is forwarded.
    Newtype,
}
/// One entry of a delegated method's parenthesised parameter list.
#[derive(Clone)]
enum DelegatedInput {
    /// An ordinary parameter of the generated method (e.g. `a: u32` or
    /// `&self`), together with the modifier taken from its attribute, if any.
    Input {
        parameter: syn::FnArg,
        modifier: Option<ArgumentModifier>,
    },
    /// A bracketed expression (`[expr]`). It is forwarded to the target as an
    /// extra argument but does not appear in the generated signature.
    Argument(syn::Expr),
}
fn get_argument_modifier(attribute: syn::Attribute) -> Result<ArgumentModifier, Error> {
if let Meta::Path(mut path) = attribute.meta {
if path.segments.len() == 1 {
let segment = path.segments.pop().unwrap();
if segment.value().arguments.is_empty() {
let ident = segment.value().ident.to_string();
let ident = ident.as_str();
match ident {
"into" => return Ok(ArgumentModifier::Into),
"as_ref" => return Ok(ArgumentModifier::AsRef),
"newtype" => return Ok(ArgumentModifier::Newtype),
_ => (),
}
}
}
};
panic!("The attribute argument has to be `into` or `as_ref`, like this: `#[into] a: u32`.")
}
impl syn::parse::Parse for DelegatedInput {
fn parse(input: ParseStream) -> Result<Self, Error> {
let lookahead = input.lookahead1();
if lookahead.peek(syn::token::Bracket) {
let content;
let _bracket_token = syn::bracketed!(content in input);
let expression: syn::Expr = content.parse()?;
Ok(Self::Argument(expression))
} else {
let (input, modifier) = if lookahead.peek(syn::token::Pound) {
let mut attributes = input.call(tolerant_outer_attributes)?;
if attributes.len() > 1 {
panic!("You can specify at most a single attribute for each parameter in a delegated method");
}
let modifier = get_argument_modifier(attributes.pop().unwrap())
.expect("Could not parse argument modifier attribute");
let input: syn::FnArg = input.parse()?;
(input, Some(modifier))
} else {
(input.parse()?, None)
};
Ok(Self::Input {
parameter: input,
modifier,
})
}
}
}
/// A parsed method declaration inside a delegation segment.
struct DelegatedMethod {
    /// The declared signature, stored as a body-less trait method item.
    method: syn::TraitItemFn,
    /// Outer attributes written before the method. They are interpreted later
    /// by `parse_method_attributes` from the `attributes` module.
    attributes: Vec<syn::Attribute>,
    /// Visibility given to the generated method.
    visibility: syn::Visibility,
    /// Expressions passed to the target, in call order. Non-receiver
    /// parameters (possibly wrapped by their modifier) and bracketed `[expr]`
    /// extra arguments are interleaved as written.
    arguments: syn::punctuated::Punctuated<syn::Expr, syn::Token![,]>,
}
/// Converts a parameter of a delegated method into the expression that is
/// forwarded to the delegation target.
///
/// Receivers (`self`, `&self`, ...) and typed parameters bound to the
/// identifier `self` produce `None`, because they are not forwarded as
/// ordinary arguments. A parameter with a plain identifier pattern
/// (`a: u32`) produces the path expression `a`.
///
/// # Panics
///
/// Panics when a typed parameter uses any pattern other than a plain
/// identifier (e.g. a tuple pattern). `function_name` is included in the
/// message.
fn parse_input_into_argument_expression(
    function_name: &Ident,
    input: &syn::FnArg,
) -> Option<syn::Expr> {
    let typed = match input {
        syn::FnArg::Receiver(_) => return None,
        syn::FnArg::Typed(typed) => typed,
    };

    let binding = match &*typed.pat {
        syn::Pat::Ident(pat_ident) => &pat_ident.ident,
        _ => panic!(
            "You have to use simple identifiers for delegated method parameters ({})",
            function_name
        ),
    };

    if *binding == "self" {
        return None;
    }

    Some(syn::Expr::Path(syn::ExprPath {
        attrs: Vec::new(),
        qself: None,
        path: syn::Path::from(binding.clone()),
    }))
}
impl syn::parse::Parse for DelegatedMethod {
fn parse(input: ParseStream) -> Result<Self, Error> {
let attributes = input.call(tolerant_outer_attributes)?;
let visibility = input.call(syn::Visibility::parse)?;
let constness: Option<syn::Token![const]> = input.parse()?;
let asyncness: Option<syn::Token![async]> = input.parse()?;
let unsafety: Option<syn::Token![unsafe]> = input.parse()?;
let abi: Option<syn::Abi> = input.parse()?;
let fn_token: syn::Token![fn] = input.parse()?;
let ident: Ident = input.parse()?;
let generics: syn::Generics = input.parse()?;
let content;
let paren_token = syn::parenthesized!(content in input);
let delegated_inputs = content.parse_terminated(DelegatedInput::parse, syn::Token![,])?;
let mut inputs: syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]> =
syn::punctuated::Punctuated::new();
let mut arguments: syn::punctuated::Punctuated<syn::Expr, syn::Token![,]> =
syn::punctuated::Punctuated::new();
delegated_inputs
.into_pairs()
.map(|punctuated_pair| match punctuated_pair {
syn::punctuated::Pair::Punctuated(item, comma) => (item, Some(comma)),
syn::punctuated::Pair::End(item) => (item, None),
})
.for_each(|pair| match pair {
(DelegatedInput::Argument(argument), maybe_comma) => {
arguments.push_value(argument);
if let Some(comma) = maybe_comma {
arguments.push_punct(comma)
}
}
(
DelegatedInput::Input {
parameter,
modifier,
},
maybe_comma,
) => {
inputs.push_value(parameter.clone());
if let Some(comma) = maybe_comma {
inputs.push_punct(comma);
}
let maybe_argument = parse_input_into_argument_expression(&ident, ¶meter);
if let Some(mut argument) = maybe_argument {
let span = argument.span();
if let Some(modifier) = modifier {
let method_call = |name: &str| {
syn::Expr::from(ExprMethodCall {
attrs: vec![],
receiver: Box::new(argument.clone()),
dot_token: Default::default(),
method: Ident::new(name, span),
turbofish: None,
paren_token,
args: Default::default(),
})
};
let field_call = || {
syn::Expr::from(ExprField {
attrs: vec![],
base: Box::new(argument.clone()),
dot_token: Default::default(),
member: syn::Member::Unnamed(0.into()),
})
};
match modifier {
ArgumentModifier::Into => {
argument = method_call("into");
}
ArgumentModifier::AsRef => {
argument = method_call("as_ref");
}
ArgumentModifier::Newtype => argument = field_call(),
}
}
arguments.push(argument);
if let Some(comma) = maybe_comma {
arguments.push_punct(comma);
}
}
}
});
let output: syn::ReturnType = input.parse()?;
let where_clause: Option<syn::WhereClause> = input.parse()?;
let signature = syn::Signature {
constness,
asyncness,
unsafety,
abi,
fn_token,
ident,
paren_token,
inputs,
output,
variadic: None,
generics: syn::Generics {
where_clause,
..generics
},
};
let lookahead = input.lookahead1();
let semi_token: Option<syn::Token![;]> = if lookahead.peek(syn::Token![;]) {
Some(input.parse()?)
} else {
panic!(
"Do not include implementation of delegated functions ({})",
signature.ident
);
};
let method = syn::TraitItemFn {
attrs: Vec::new(),
sig: signature,
default: None,
semi_token,
};
Ok(DelegatedMethod {
method,
attributes,
visibility,
arguments,
})
}
}
/// One `to <delegator> { <methods> }` segment of a `delegate!` invocation.
struct DelegatedSegment {
    /// The expression the methods are forwarded to. Closures and `match`
    /// expressions receive special handling when the code is generated.
    delegator: syn::Expr,
    /// The method declarations listed inside the segment's braces.
    methods: Vec<DelegatedMethod>,
    /// Attributes written before `to`. Each method's own attributes are
    /// combined with these via `combine_attributes`.
    segment_attrs: SegmentAttributes,
}
impl syn::parse::Parse for DelegatedSegment {
fn parse(input: ParseStream) -> Result<Self, Error> {
let attributes = input.call(tolerant_outer_attributes)?;
let segment_attrs = parse_segment_attributes(&attributes);
if let Ok(keyword) = input.parse::<kw::target>() {
return Err(Error::new(keyword.span(), "You are using the old `target` expression, which is deprecated. Please replace `target` with `to`."));
} else {
input.parse::<kw::to>()?;
}
syn::Expr::parse_without_eager_brace(input).and_then(|delegator| {
let content;
syn::braced!(content in input);
let mut methods = vec![];
while !content.is_empty() {
methods.push(
content
.parse::<DelegatedMethod>()
.expect("Cannot parse delegated method"),
);
}
Ok(DelegatedSegment {
delegator,
methods,
segment_attrs,
})
})
}
}
/// The complete input of `delegate!`: one or more delegation segments.
struct DelegationBlock {
    segments: Vec<DelegatedSegment>,
}

impl syn::parse::Parse for DelegationBlock {
    /// Parses segments back to back until the input is exhausted. The first
    /// segment that fails to parse aborts with its error.
    fn parse(input: ParseStream) -> Result<Self, Error> {
        let mut parsed = Vec::new();
        loop {
            if input.is_empty() {
                break;
            }
            let segment: DelegatedSegment = input.parse()?;
            parsed.push(segment);
        }
        Ok(Self { segments: parsed })
    }
}
/// Returns `true` when any outer attribute in `attrs` has the path `inline`
/// (this covers `#[inline]` as well as forms such as `#[inline(always)]`).
/// Inner attributes are ignored.
fn has_inline_attribute(attrs: &[&syn::Attribute]) -> bool {
    attrs
        .iter()
        .filter(|attr| matches!(attr.style, syn::AttrStyle::Outer))
        .any(|attr| attr.path().is_ident("inline"))
}
/// Syntax-tree visitor that replaces the body of every arm of a `match`
/// expression. The wrapped function receives the original arm body and
/// returns the replacement tokens.
///
/// Only the arms of the `match` handed to `visit_expr_match_mut` are
/// rewritten. The scrutinee is left untouched, and arm bodies are not
/// searched for nested `match` expressions.
struct MatchVisitor<F>(F);
impl<F: Fn(&Expr) -> proc_macro2::TokenStream> VisitMut for MatchVisitor<F> {
    fn visit_expr_match_mut(&mut self, expr_match: &mut syn::ExprMatch) {
        // The default traversal also walks the scrutinee, which would rewrite
        // the arms of any `match` nested inside it. Visit this match's own
        // arms only.
        for arm in &mut expr_match.arms {
            self.visit_arm_mut(arm);
        }
    }

    fn visit_arm_mut(&mut self, arm: &mut syn::Arm) {
        let transformed = self.0(&arm.body);
        arm.body = parse_quote!(#transformed);
    }
}
/// Procedural-macro entry point: expands a `delegate! { ... }` invocation into
/// methods that forward to a delegation target.
///
/// For every segment (`to <expr> { <methods> }`) and every method declared in
/// it, this emits a function with the declared signature. Its body calls the
/// corresponding item on `<expr>`; the item may be renamed or redirected via
/// attributes.
///
/// # Panics
///
/// Panics, which surfaces as a compile error at the call site, when a closure
/// delegator has a parameter that is not a type pattern (`a: u32`), or when a
/// delegated method carries a body.
#[proc_macro]
pub fn delegate(tokens: TokenStream) -> TokenStream {
    let block: DelegationBlock = syn::parse_macro_input!(tokens);
    let sections = block.segments.iter().map(|delegator| {
        let delegated_expr = &delegator.delegator;
        let functions = delegator.methods.iter().map(|method| {
            let input = &method.method;
            let mut signature = input.sig.clone();
            // A closure delegator (`to |a: u32| ...`) contributes its typed
            // parameters to the generated signature. They go right after the
            // receiver when the method has one, otherwise at the very start.
            if let Expr::Closure(closure) = delegated_expr {
                let additional_inputs: Vec<FnArg> = closure
                    .inputs
                    .iter()
                    .map(|input| {
                        if let syn::Pat::Type(pat_type) = input {
                            syn::parse_quote!(#pat_type)
                        } else {
                            panic!(
                                "Use a type pattern (`a: u32`) for delegation closure arguments"
                            );
                        }
                    })
                    .collect();
                let mut origin_inputs = mem::take(&mut signature.inputs).into_iter();
                let first_input = origin_inputs.next();
                match first_input {
                    Some(FnArg::Receiver(receiver)) => {
                        signature.inputs.push(FnArg::Receiver(receiver));
                        signature.inputs.extend(additional_inputs);
                    }
                    Some(first_input) => {
                        signature.inputs.extend(additional_inputs);
                        signature.inputs.push(first_input);
                    }
                    _ => {
                        signature.inputs.extend(additional_inputs);
                    }
                }
                // Append the method's remaining original parameters.
                signature.inputs.extend(origin_inputs);
            }
            // Interpret the method's own attributes, then merge in the
            // segment-level ones (both helpers live in the `attributes` module).
            let attributes = parse_method_attributes(&method.attributes, input);
            let attributes = combine_attributes(attributes, &delegator.segment_attrs);
            if input.default.is_some() {
                panic!(
                    "Do not include implementation of delegated functions ({})",
                    signature.ident
                );
            }
            let args: Vec<Expr> = method.arguments.clone().into_iter().collect();
            // Name of the member to call on the target: the method's own name,
            // unless a target specifier attribute provides another one.
            let name = match &attributes.target_specifier {
                Some(target) => target.get_member(&input.sig.ident),
                None => input.sig.ident.clone().into(),
            };
            // Add `#[inline]` unless the user already supplied an inline attribute.
            let inline = if has_inline_attribute(&attributes.attributes) {
                quote!()
            } else {
                quote! { #[inline] }
            };
            let visibility = &method.visibility;
            let is_method = method.method.sig.receiver().is_some();
            let associated_const = &attributes.associated_constant;
            let expr_attr = &attributes.expr_attr;
            // For a closure delegator, the call is built on the closure's body.
            let delegated_body = if let Expr::Closure(closure) = delegated_expr {
                &closure.body
            } else {
                delegated_expr
            };
            let span = input.span();
            // Append `.await` when explicitly requested, or by default for `async fn`.
            let generate_await = attributes
                .generate_await
                .unwrap_or_else(|| method.method.sig.asyncness.is_some());
            // Build a turbofish (`::<T, N>`) that forwards the method's type
            // and const generic parameters. Lifetime parameters are skipped.
            let generic_params = &method.method.sig.generics.params;
            let generics = if generic_params.is_empty() {
                quote::quote! {}
            } else {
                let span = generic_params.span();
                let mut params: syn::punctuated::Punctuated<
                    proc_macro2::TokenStream,
                    syn::Token![,],
                > = syn::punctuated::Punctuated::new();
                for param in generic_params.iter() {
                    let token = match param {
                        GenericParam::Lifetime(_) => {
                            continue;
                        }
                        GenericParam::Type(t) => {
                            let token = &t.ident;
                            let span = t.span();
                            quote::quote_spanned! {span=> #token }
                        }
                        GenericParam::Const(c) => {
                            let token = &c.ident;
                            let span = c.span();
                            quote::quote_spanned! {span=> #token }
                        }
                    };
                    params.push(token);
                }
                quote::quote_spanned! {span=> ::<#params> }
            };
            // Builds the forwarding expression for a given target expression.
            let modify_expr = |expr: &Expr| {
                let body = if let Some(target_trait) = &attributes.target_trait {
                    // Fully qualified trait call with the target as first argument.
                    quote::quote! { #target_trait::#name#generics(#expr, #(#args),*) }
                } else if let Some(AssociatedConstant {
                    const_name,
                    trait_path,
                }) = associated_const
                {
                    // The target expression is used only to infer `T`; the
                    // returned value is the trait's associated constant.
                    let return_type = &signature.output;
                    quote::quote! {{
                        const fn get_const<T: #trait_path>(t: &T) #return_type {
                            <T as #trait_path>::#const_name
                        }
                        get_const(#expr)
                    }}
                } else if is_method {
                    match &attributes.target_specifier {
                        None | Some(TargetSpecifier::Method(_)) => {
                            quote::quote! { #expr.#name#generics(#(#args),*) }
                        }
                        Some(TargetSpecifier::Field(target)) => {
                            // Field access instead of a call, optionally
                            // prefixed by the tokens from `reference_tokens`.
                            let reference = target.reference_tokens();
                            quote::quote! { #reference#expr.#name }
                        }
                    }
                } else {
                    // No receiver: call an associated function via a path.
                    quote::quote! { #expr::#name#generics(#(#args),*) }
                };
                let mut body = if generate_await {
                    quote::quote! { #body.await }
                } else {
                    body
                };
                // Apply the requested return conversions in attribute order;
                // each one wraps the result of the previous.
                for expression in &attributes.expressions {
                    match expression {
                        ReturnExpression::Into(type_name) => {
                            body = match type_name {
                                Some(name) => {
                                    quote::quote! { ::core::convert::Into::<#name>::into(#body) }
                                }
                                None => quote::quote! { ::core::convert::Into::into(#body) },
                            };
                        }
                        ReturnExpression::TryInto => {
                            body = quote::quote! { ::core::convert::TryInto::try_into(#body) };
                        }
                        ReturnExpression::Unwrap => {
                            body = quote::quote! { #body.unwrap() };
                        }
                    }
                }
                body
            };
            // A `match` delegator forwards inside each arm; any other
            // delegator is forwarded directly.
            let mut body = if let Expr::Match(expr_match) = delegated_body {
                let mut expr_match = expr_match.clone();
                MatchVisitor(modify_expr).visit_expr_match_mut(&mut expr_match);
                expr_match.into_token_stream()
            } else {
                modify_expr(delegated_body)
            };
            // For methods without a declared return type, turn the forwarded
            // expression into a statement so its value is discarded.
            if let syn::ReturnType::Default = &signature.output {
                body = quote::quote! { #body; };
            };
            if let Some(expr_template) = expr_attr {
                body = expr_template.expand_template(&body);
            }
            let attrs = &attributes.attributes;
            quote::quote_spanned! {span=>
                #(#attrs)*
                #inline
                #visibility #signature {
                    #body
                }
            }
        });
        quote! { #(#functions)* }
    });
    let result = quote! {
        #(#sections)*
    };
    result.into()
}
/// Parses zero or more outer attributes (`#[...]`) at the start of `input`.
///
/// Each attribute's content is first parsed as a regular `syn::Meta`. If that
/// fails, it is re-parsed leniently as a single identifier (keywords allowed)
/// followed by:
/// * a delimited group, giving `Meta::List`,
/// * `= value`, giving `Meta::NameValue`, or
/// * nothing, giving `Meta::Path`.
///
/// Every attribute is recorded with `AttrStyle::Outer`.
///
/// # Errors
///
/// Fails if an attribute is malformed even for the lenient parser.
fn tolerant_outer_attributes(input: ParseStream) -> syn::Result<Vec<syn::Attribute>> {
    use proc_macro2::{Delimiter, TokenTree};
    use syn::{
        bracketed,
        ext::IdentExt,
        parse::discouraged::Speculative,
        token::{Brace, Bracket, Paren},
        AttrStyle, Attribute, ExprLit, Lit, MacroDelimiter, MetaList, MetaNameValue, Path, Result,
        Token,
    };

    // Parses one `#[...]` attribute whose content goes through `tolerant_meta`.
    fn tolerant_attr(input: ParseStream) -> Result<Attribute> {
        let content;
        Ok(Attribute {
            pound_token: input.parse()?,
            style: AttrStyle::Outer,
            bracket_token: bracketed!(content in input),
            meta: content.call(tolerant_meta)?,
        })
    }

    // Parses attribute content strictly as `Meta`, falling back to a lenient parser.
    fn tolerant_meta(input: ParseStream) -> Result<Meta> {
        // Run the strict parser on a fork. If it fails part-way, `input` must
        // still sit at the start of the content for the fallback below.
        let strict = input.fork();
        if let Ok(meta) = strict.call(Meta::parse) {
            input.advance_to(&strict);
            return Ok(meta);
        }

        let path = Path::from(input.call(Ident::parse_any)?);
        if input.peek(Paren) || input.peek(Bracket) || input.peek(Brace) {
            // `path(...)`, `path[...]` or `path{...}`: keep the group's tokens verbatim.
            input.step(|cursor| {
                if let Some((TokenTree::Group(g), rest)) = cursor.token_tree() {
                    let span = g.delim_span();
                    let delimiter = match g.delimiter() {
                        Delimiter::Parenthesis => MacroDelimiter::Paren(Paren(span)),
                        Delimiter::Brace => MacroDelimiter::Brace(Brace(span)),
                        Delimiter::Bracket => MacroDelimiter::Bracket(Bracket(span)),
                        Delimiter::None => {
                            return Err(cursor.error("expected delimiter"));
                        }
                    };
                    Ok((
                        Meta::List(MetaList {
                            path,
                            delimiter,
                            tokens: g.stream(),
                        }),
                        rest,
                    ))
                } else {
                    Err(cursor.error("expected delimiter"))
                }
            })
        } else if input.peek(Token![=]) {
            let eq_token = input.parse()?;
            // If the remaining content is a single literal, use it directly;
            // otherwise parse a full expression.
            let ahead = input.fork();
            let value = match ahead.parse::<Option<Lit>>()? {
                Some(lit) if ahead.is_empty() => {
                    input.advance_to(&ahead);
                    Expr::Lit(ExprLit {
                        attrs: Vec::new(),
                        lit,
                    })
                }
                _ => input.parse()?,
            };
            Ok(Meta::NameValue(MetaNameValue {
                path,
                eq_token,
                value,
            }))
        } else {
            Ok(Meta::Path(path))
        }
    }

    let mut attrs = Vec::new();
    while input.peek(Token![#]) {
        attrs.push(input.call(tolerant_attr)?);
    }
    Ok(attrs)
}