1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
// Copyright 2017, 2018 Parity Technologies
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Derives serialization and deserialization codec for complex structs for simple marshalling.

#![cfg_attr(not(feature = "std"), no_std)]

extern crate proc_macro;
extern crate proc_macro2;

#[macro_use]
extern crate syn;

#[macro_use]
extern crate quote;

use proc_macro::TokenStream;
use syn::{DeriveInput, Generics, GenericParam, Ident};

mod decode;
mod encode;

const ENCODE_ERR: &str = "derive(Encode) failed";

#[proc_macro_derive(Encode, attributes(codec))]
pub fn encode_derive(input: TokenStream) -> TokenStream {
	let input: DeriveInput = syn::parse(input).expect(ENCODE_ERR);
	let name = &input.ident;

	let generics = add_trait_bounds(input.generics, parse_quote!(::codec::Encode));
	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

	let self_ = quote!(self);
	let dest_ = quote!(dest);
	let encoding = encode::quote(&input.data, name, &self_, &dest_);

	let expanded = quote! {
		impl #impl_generics ::codec::Encode for #name #ty_generics #where_clause {
			fn encode_to<EncOut: ::codec::Output>(&#self_, #dest_: &mut EncOut) {
				#encoding
			}
		}
	};

	expanded.into()
}

/// Entry point for `#[derive(Decode)]`.
///
/// Parses the annotated item and appends a `::codec::Decode` bound to each of
/// its type parameters. It then emits an `impl ::codec::Decode` whose `decode`
/// body is generated by `decode::quote`, reading from the `input` argument.
///
/// # Panics
/// Panics with `"derive(Decode) failed"` when the input cannot be parsed as a
/// `DeriveInput`.
#[proc_macro_derive(Decode, attributes(codec))]
pub fn decode_derive(input: TokenStream) -> TokenStream {
	// Dedicated message so a parse failure names the derive that actually failed.
	const DECODE_ERR: &str = "derive(Decode) failed";

	let input: DeriveInput = syn::parse(input).expect(DECODE_ERR);
	let name = &input.ident;

	// Every type parameter must itself be decodable for the generated impl to typecheck.
	let generics = add_trait_bounds(input.generics, parse_quote!(::codec::Decode));
	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

	let input_ = quote!(input);
	let decoding = decode::quote(&input.data, name, &input_);

	let expanded = quote! {
		impl #impl_generics ::codec::Decode for #name #ty_generics #where_clause {
			fn decode<DecIn: ::codec::Input>(#input_: &mut DecIn) -> Option<Self> {
				#decoding
			}
		}
	};

	expanded.into()
}

/// Returns `generics` with a copy of `bounds` appended to the bound list of
/// every type parameter. Lifetime and const parameters are left as they are.
fn add_trait_bounds(mut generics: Generics, bounds: syn::TypeParamBound) -> Generics {
	let type_params = generics.params.iter_mut().filter_map(|param| match *param {
		GenericParam::Type(ref mut type_param) => Some(type_param),
		_ => None,
	});
	for type_param in type_params {
		type_param.bounds.push(bounds.clone());
	}
	generics
}

/// Computes the tokens for the index of enum variant `v`.
///
/// Resolution order:
/// 1. the first `#[codec(index = "N")]` attribute, where `N` must parse as a `u8`;
/// 2. the variant's explicit discriminant expression (`Variant = expr`);
/// 3. the fallback position `i` supplied by the caller.
///
/// # Panics
/// Panics when a `codec` attribute is malformed. That covers:
/// - a multi-segment path;
/// - a form that is not a list, or an empty list;
/// - a last entry that is not `name = value`, or whose key is not `index`;
/// - a value that is not a string literal, or not a valid `u8`.
fn index(v: &syn::Variant, i: usize) -> proc_macro2::TokenStream {
	// Look for an explicit index in attributes; the first `codec` attribute wins.
	let index = v.attrs.iter().filter_map(|attr| {
		let pair = attr.path.segments.first()?;
		let seg = pair.value();

		// Attributes belonging to other derives/tools are not ours to interpret.
		if seg.ident != Ident::new("codec", seg.ident.span()) {
			return None;
		}
		assert_eq!(
			attr.path.segments.len(),
			1,
			"Invalid syntax for `codec` attribute: Expected a single-segment path."
		);

		let meta = attr.interpret_meta();
		if let Some(syn::Meta::List(ref l)) = meta {
			// An empty list (`#[codec()]`) falls through to the syntax error below
			// instead of panicking on an unwrap of `None`.
			if let Some(last) = l.nested.last() {
				if let syn::NestedMeta::Meta(syn::Meta::NameValue(ref nv)) = last.value() {
					assert_eq!(
						nv.ident,
						Ident::new("index", nv.ident.span()),
						"Invalid syntax for `codec` attribute: Expected `index` key."
					);
					if let syn::Lit::Str(ref s) = nv.lit {
						let byte: u8 = s.value().parse().expect("Numeric index expected.");
						return Some(byte)
					}
					panic!("Invalid syntax for `codec` attribute: Expected string literal.")
				}
			}
		}
		panic!("Invalid syntax for `codec` attribute: Expected `name = value` pair.")
	}).next();

	// then fallback to discriminant or just index
	index.map(|i| quote! { #i })
		.unwrap_or_else(|| v.discriminant
			.as_ref()
			.map(|&(_, ref expr)| quote! { #expr })
			.unwrap_or_else(|| quote! { #i })
		)
}