#![recursion_limit = "128"]
extern crate proc_macro;
use proc_macro2;
#[macro_use]
extern crate syn;
#[macro_use]
extern crate quote;
use proc_macro::TokenStream;
use proc_macro2::Span;
use syn::{Data, Field, Fields, DeriveInput, Ident, parse::Error, spanned::Spanned};
use proc_macro_crate::crate_name;
use std::env;
mod decode;
mod encode;
mod utils;
mod trait_bounds;
/// Builds the `extern crate … as _parity_scale_codec;` item that the generated
/// impls refer to.
///
/// * When the crate currently being compiled is `parity-scale-codec` itself
///   (according to `CARGO_PKG_NAME`), the crate is referenced by its own name.
/// * Otherwise `proc_macro_crate::crate_name` is asked for the name of the
///   `parity-scale-codec` dependency, and that name is aliased.
///
/// If the lookup fails, the returned token stream is a compile error carrying
/// the lookup's message instead of an `extern crate` item.
fn include_parity_scale_codec_crate() -> proc_macro2::TokenStream {
    // A missing or non-unicode `CARGO_PKG_NAME` (e.g. when not driven by cargo)
    // must not panic inside the macro: treat it as "not the codec crate" and let
    // the regular lookup below produce a proper diagnostic if it fails too.
    let building_codec_itself = env::var("CARGO_PKG_NAME")
        .map(|pkg_name| pkg_name == "parity-scale-codec")
        .unwrap_or(false);

    if building_codec_itself {
        quote!( extern crate parity_scale_codec as _parity_scale_codec; )
    } else {
        match crate_name("parity-scale-codec") {
            Ok(parity_codec_crate) => {
                let ident = Ident::new(&parity_codec_crate, Span::call_site());
                quote!( extern crate #ident as _parity_scale_codec; )
            },
            Err(e) => Error::new(Span::call_site(), &e).to_compile_error(),
        }
    }
}
fn wrap_with_dummy_const(input: &DeriveInput, prefix: &str, impl_block: proc_macro2::TokenStream) -> TokenStream {
let parity_codec_crate = include_parity_scale_codec_crate();
let mut new_name = prefix.to_string();
new_name.push_str(input.ident.to_string().trim_start_matches("r#"));
let dummy_const = Ident::new(&new_name, Span::call_site());
let generated = quote! {
#[allow(non_upper_case_globals, unused_attributes, unused_qualifications)]
const #dummy_const: () = {
#[allow(unknown_lints)]
#[cfg_attr(feature = "cargo-clippy", allow(useless_attribute))]
#[allow(rust_2018_idioms)]
#parity_codec_crate
#impl_block
};
};
generated.into()
}
#[proc_macro_derive(Encode, attributes(codec))]
pub fn encode_derive(input: TokenStream) -> TokenStream {
let mut input: DeriveInput = match syn::parse(input) {
Ok(input) => input,
Err(e) => return e.to_compile_error().into(),
};
if let Some(span) = utils::get_skip(&input.attrs) {
return Error::new(span, "invalid attribute `skip` on root input")
.to_compile_error().into();
}
if let Err(e) = trait_bounds::add(
&input.ident,
&mut input.generics,
&input.data,
parse_quote!(_parity_scale_codec::Encode),
None,
) {
return e.to_compile_error().into();
}
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let encode_impl = encode::quote(&input.data, name);
let impl_block = quote! {
impl #impl_generics _parity_scale_codec::Encode for #name #ty_generics #where_clause {
#encode_impl
}
};
wrap_with_dummy_const(&input, "_IMPL_ENCODE_FOR_", impl_block)
}
#[proc_macro_derive(Decode, attributes(codec))]
pub fn decode_derive(input: TokenStream) -> TokenStream {
let mut input: DeriveInput = match syn::parse(input) {
Ok(input) => input,
Err(e) => return e.to_compile_error().into(),
};
if let Some(span) = utils::get_skip(&input.attrs) {
return Error::new(span, "invalid attribute `skip` on root input")
.to_compile_error().into();
}
if let Err(e) = trait_bounds::add(
&input.ident,
&mut input.generics,
&input.data,
parse_quote!(_parity_scale_codec::Decode),
Some(parse_quote!(Default))
) {
return e.to_compile_error().into();
}
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let input_ = quote!(input);
let decoding = decode::quote(&input.data, name, &input_);
let impl_block = quote! {
impl #impl_generics _parity_scale_codec::Decode for #name #ty_generics #where_clause {
fn decode<DecIn: _parity_scale_codec::Input>(
#input_: &mut DecIn
) -> core::result::Result<Self, _parity_scale_codec::Error> {
#decoding
}
}
};
wrap_with_dummy_const(&input, "_IMPL_DECODE_FOR_", impl_block)
}
#[proc_macro_derive(CompactAs, attributes(codec))]
pub fn compact_as_derive(input: TokenStream) -> TokenStream {
let mut input: DeriveInput = match syn::parse(input) {
Ok(input) => input,
Err(e) => return e.to_compile_error().into(),
};
if let Err(e) = trait_bounds::add(
&input.ident,
&mut input.generics,
&input.data,
parse_quote!(_parity_scale_codec::CompactAs),
None,
) {
return e.to_compile_error().into();
}
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
fn val_or_default(field: &Field) -> proc_macro2::TokenStream {
let skip = utils::get_skip(&field.attrs).is_some();
if skip {
quote_spanned!(field.span()=> Default::default())
} else {
quote_spanned!(field.span()=> x)
}
}
let (inner_ty, inner_field, constructor) = match input.data {
Data::Struct(ref data) => {
match data.fields {
Fields::Named(ref fields) if utils::filter_skip_named(fields).count() == 1 => {
let recurse = fields.named.iter().map(|f| {
let name_ident = &f.ident;
let val_or_default = val_or_default(&f);
quote_spanned!(f.span()=> #name_ident: #val_or_default)
});
let field = utils::filter_skip_named(fields).next().expect("Exactly one field");
let field_name = &field.ident;
let constructor = quote!( #name { #( #recurse, )* });
(&field.ty, quote!(&self.#field_name), constructor)
},
Fields::Unnamed(ref fields) if utils::filter_skip_unnamed(fields).count() == 1 => {
let recurse = fields.unnamed.iter().enumerate().map(|(_, f) | {
let val_or_default = val_or_default(&f);
quote_spanned!(f.span()=> #val_or_default)
});
let (id, field) = utils::filter_skip_unnamed(fields).next().expect("Exactly one field");
let id = syn::Index::from(id);
let constructor = quote!( #name(#( #recurse, )*));
(&field.ty, quote!(&self.#id), constructor)
},
_ => {
return Error::new(
data.fields.span(),
"Only structs with a single non-skipped field can derive CompactAs"
).to_compile_error().into();
},
}
},
Data::Enum(syn::DataEnum { enum_token: syn::token::Enum { span }, .. }) |
Data::Union(syn::DataUnion { union_token: syn::token::Union { span }, .. }) => {
return Error::new(span, "Only structs can derive CompactAs").to_compile_error().into();
},
};
let impl_block = quote! {
impl #impl_generics _parity_scale_codec::CompactAs for #name #ty_generics #where_clause {
type As = #inner_ty;
fn encode_as(&self) -> &#inner_ty {
#inner_field
}
fn decode_from(x: #inner_ty) -> #name #ty_generics {
#constructor
}
}
impl #impl_generics From<_parity_scale_codec::Compact<#name #ty_generics>> for #name #ty_generics #where_clause {
fn from(x: _parity_scale_codec::Compact<#name #ty_generics>) -> #name #ty_generics {
x.0
}
}
};
wrap_with_dummy_const(&input, "_IMPL_COMPACTAS_FOR_", impl_block)
}