parity_scale_codec_derive/
lib.rs

1// Copyright 2017-2021 Parity Technologies
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Derives serialization and deserialization codec for complex structs for simple marshalling.
16
17#![recursion_limit = "128"]
18extern crate proc_macro;
19
20#[macro_use]
21extern crate syn;
22
23#[macro_use]
24extern crate quote;
25
26use crate::utils::{codec_crate_path, is_lint_attribute};
27use syn::{spanned::Spanned, Data, DeriveInput, Error, Field, Fields};
28
29mod decode;
30mod encode;
31mod max_encoded_len;
32mod trait_bounds;
33mod utils;
34
35/// Wraps the impl block in a "dummy const"
36fn wrap_with_dummy_const(
37	input: DeriveInput,
38	impl_block: proc_macro2::TokenStream,
39) -> proc_macro::TokenStream {
40	let attrs = input.attrs.into_iter().filter(is_lint_attribute);
41	let generated = quote! {
42		#[allow(deprecated)]
43		const _: () = {
44			#(#attrs)*
45			#impl_block
46		};
47	};
48
49	generated.into()
50}
51
52/// Derive `parity_scale_codec::Encode` and `parity_scale_codec::EncodeLike` for struct and enum.
53///
54/// # Top level attributes
55///
56/// By default the macro will add [`Encode`] and [`Decode`] bounds to all types, but the bounds can
57/// be specified manually with the top level attributes:
58/// * `#[codec(encode_bound(T: Encode))]`: a custom bound added to the `where`-clause when deriving
59///   the `Encode` trait, overriding the default.
60/// * `#[codec(decode_bound(T: Decode))]`: a custom bound added to the `where`-clause when deriving
61///   the `Decode` trait, overriding the default.
62///
63/// # Struct
64///
65/// A struct is encoded by encoding each of its fields successively.
66///
67/// Fields can have some attributes:
68/// * `#[codec(skip)]`: the field is not encoded. It must derive `Default` if Decode is derived.
69/// * `#[codec(compact)]`: the field is encoded in its compact representation i.e. the field must
70///   implement `parity_scale_codec::HasCompact` and will be encoded as `HasCompact::Type`.
71/// * `#[codec(encoded_as = "$EncodeAs")]`: the field is encoded as an alternative type. $EncodedAs
72///   type must implement `parity_scale_codec::EncodeAsRef<'_, $FieldType>` with $FieldType the type
73///   of the field with the attribute. This is intended to be used for types implementing
74///   `HasCompact` as shown in the example.
75///
76/// ```
77/// # use parity_scale_codec_derive::Encode;
78/// # use parity_scale_codec::{Encode as _, HasCompact};
79/// #[derive(Encode)]
80/// struct StructType {
81///     #[codec(skip)]
82///     a: u32,
83///     #[codec(compact)]
84///     b: u32,
85///     #[codec(encoded_as = "<u32 as HasCompact>::Type")]
86///     c: u32,
87/// }
88/// ```
89///
90/// # Enum
91///
92/// The variable is encoded with one byte for the variant and then the variant struct encoding.
93/// The variant number is:
94/// * if variant has attribute: `#[codec(index = "$n")]` then n
95/// * else if variant has discriminant (like 3 in `enum T { A = 3 }`) then the discriminant.
96/// * else its position in the variant set, excluding skipped variants, but including variant with
97///   discriminant or attribute. Warning this position does collision with discriminant or attribute
98///   index.
99///
100/// variant attributes:
101/// * `#[codec(skip)]`: the variant is not encoded.
102/// * `#[codec(index = "$n")]`: override variant index.
103///
104/// field attributes: same as struct fields attributes.
105///
106/// ```
107/// # use parity_scale_codec_derive::Encode;
108/// # use parity_scale_codec::Encode as _;
109/// #[derive(Encode)]
110/// enum EnumType {
111///     #[codec(index = 15)]
112///     A,
113///     #[codec(skip)]
114///     B,
115///     C = 3,
116///     D,
117/// }
118///
119/// assert_eq!(EnumType::A.encode(), vec![15]);
120/// assert_eq!(EnumType::B.encode(), vec![]);
121/// assert_eq!(EnumType::C.encode(), vec![3]);
122/// assert_eq!(EnumType::D.encode(), vec![2]);
123/// ```
124#[proc_macro_derive(Encode, attributes(codec))]
125pub fn encode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
126	let mut input: DeriveInput = match syn::parse(input) {
127		Ok(input) => input,
128		Err(e) => return e.to_compile_error().into(),
129	};
130
131	if let Err(e) = utils::check_attributes(&input) {
132		return e.to_compile_error().into();
133	}
134
135	let crate_path = match codec_crate_path(&input.attrs) {
136		Ok(crate_path) => crate_path,
137		Err(error) => return error.into_compile_error().into(),
138	};
139
140	if let Err(e) = trait_bounds::add(
141		&input.ident,
142		&mut input.generics,
143		&input.data,
144		utils::custom_encode_trait_bound(&input.attrs),
145		parse_quote!(#crate_path::Encode),
146		None,
147		utils::has_dumb_trait_bound(&input.attrs),
148		&crate_path,
149		false,
150	) {
151		return e.to_compile_error().into();
152	}
153
154	let name = &input.ident;
155	let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
156
157	let encode_impl = encode::quote(&input.data, name, &crate_path);
158
159	let impl_block = quote! {
160		#[automatically_derived]
161		impl #impl_generics #crate_path::Encode for #name #ty_generics #where_clause {
162			#encode_impl
163		}
164
165		#[automatically_derived]
166		impl #impl_generics #crate_path::EncodeLike for #name #ty_generics #where_clause {}
167	};
168
169	wrap_with_dummy_const(input, impl_block)
170}
171
172/// Derive `parity_scale_codec::Decode` for struct and enum.
173///
174/// see derive `Encode` documentation.
175#[proc_macro_derive(Decode, attributes(codec))]
176pub fn decode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
177	let mut input: DeriveInput = match syn::parse(input) {
178		Ok(input) => input,
179		Err(e) => return e.to_compile_error().into(),
180	};
181
182	if let Err(e) = utils::check_attributes(&input) {
183		return e.to_compile_error().into();
184	}
185
186	let crate_path = match codec_crate_path(&input.attrs) {
187		Ok(crate_path) => crate_path,
188		Err(error) => return error.into_compile_error().into(),
189	};
190
191	if let Err(e) = trait_bounds::add(
192		&input.ident,
193		&mut input.generics,
194		&input.data,
195		utils::custom_decode_trait_bound(&input.attrs),
196		parse_quote!(#crate_path::Decode),
197		Some(parse_quote!(Default)),
198		utils::has_dumb_trait_bound(&input.attrs),
199		&crate_path,
200		false,
201	) {
202		return e.to_compile_error().into();
203	}
204
205	let name = &input.ident;
206	let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
207	let ty_gen_turbofish = ty_generics.as_turbofish();
208
209	let input_ = quote!(__codec_input_edqy);
210	let decoding =
211		decode::quote(&input.data, name, &quote!(#ty_gen_turbofish), &input_, &crate_path);
212
213	let decode_into_body =
214		decode::quote_decode_into(&input.data, &crate_path, &input_, &input.attrs);
215
216	let impl_decode_into = if let Some(body) = decode_into_body {
217		quote! {
218			fn decode_into<__CodecInputEdqy: #crate_path::Input>(
219				#input_: &mut __CodecInputEdqy,
220				dst_: &mut ::core::mem::MaybeUninit<Self>,
221			) -> ::core::result::Result<#crate_path::DecodeFinished, #crate_path::Error> {
222				#body
223			}
224		}
225	} else {
226		quote! {}
227	};
228
229	let impl_block = quote! {
230		#[automatically_derived]
231		impl #impl_generics #crate_path::Decode for #name #ty_generics #where_clause {
232			fn decode<__CodecInputEdqy: #crate_path::Input>(
233				#input_: &mut __CodecInputEdqy
234			) -> ::core::result::Result<Self, #crate_path::Error> {
235				#decoding
236			}
237
238			#impl_decode_into
239		}
240	};
241
242	wrap_with_dummy_const(input, impl_block)
243}
244
245/// Derive `parity_scale_codec::DecodeWithMemTracking` for struct and enum.
246#[proc_macro_derive(DecodeWithMemTracking, attributes(codec))]
247pub fn decode_with_mem_tracking_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
248	let mut input: DeriveInput = match syn::parse(input) {
249		Ok(input) => input,
250		Err(e) => return e.to_compile_error().into(),
251	};
252
253	if let Err(e) = utils::check_attributes(&input) {
254		return e.to_compile_error().into();
255	}
256
257	let crate_path = match codec_crate_path(&input.attrs) {
258		Ok(crate_path) => crate_path,
259		Err(error) => return error.into_compile_error().into(),
260	};
261
262	if let Err(e) = trait_bounds::add(
263		&input.ident,
264		&mut input.generics,
265		&input.data,
266		utils::custom_decode_with_mem_tracking_trait_bound(&input.attrs),
267		parse_quote!(#crate_path::DecodeWithMemTracking),
268		Some(parse_quote!(Default)),
269		utils::has_dumb_trait_bound(&input.attrs),
270		&crate_path,
271		true,
272	) {
273		return e.to_compile_error().into();
274	}
275
276	let name = &input.ident;
277	let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
278
279	let decode_with_mem_tracking_checks =
280		decode::quote_decode_with_mem_tracking_checks(&input.data, &crate_path);
281	let impl_block = quote! {
282		fn check_struct #impl_generics() #where_clause {
283			#decode_with_mem_tracking_checks
284		}
285
286		#[automatically_derived]
287		impl #impl_generics #crate_path::DecodeWithMemTracking for #name #ty_generics #where_clause {
288		}
289	};
290
291	wrap_with_dummy_const(input, impl_block)
292}
293
294/// Derive `parity_scale_codec::Compact` and `parity_scale_codec::CompactAs` for struct with single
295/// field.
296///
297/// Attribute skip can be used to skip other fields.
298///
299/// # Example
300///
301/// ```
302/// # use parity_scale_codec_derive::CompactAs;
303/// # use parity_scale_codec::{Encode, HasCompact};
304/// # use std::marker::PhantomData;
305/// #[derive(CompactAs)]
306/// struct MyWrapper<T>(u32, #[codec(skip)] PhantomData<T>);
307/// ```
308#[proc_macro_derive(CompactAs, attributes(codec))]
309pub fn compact_as_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
310	let mut input: DeriveInput = match syn::parse(input) {
311		Ok(input) => input,
312		Err(e) => return e.to_compile_error().into(),
313	};
314
315	if let Err(e) = utils::check_attributes(&input) {
316		return e.to_compile_error().into();
317	}
318
319	let crate_path = match codec_crate_path(&input.attrs) {
320		Ok(crate_path) => crate_path,
321		Err(error) => return error.into_compile_error().into(),
322	};
323
324	if let Err(e) = trait_bounds::add::<()>(
325		&input.ident,
326		&mut input.generics,
327		&input.data,
328		None,
329		parse_quote!(#crate_path::CompactAs),
330		None,
331		utils::has_dumb_trait_bound(&input.attrs),
332		&crate_path,
333		false,
334	) {
335		return e.to_compile_error().into();
336	}
337
338	let name = &input.ident;
339	let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
340
341	fn val_or_default(field: &Field) -> proc_macro2::TokenStream {
342		let skip = utils::should_skip(&field.attrs);
343		if skip {
344			quote_spanned!(field.span()=> Default::default())
345		} else {
346			quote_spanned!(field.span()=> x)
347		}
348	}
349
350	let (inner_ty, inner_field, constructor) = match input.data {
351		Data::Struct(ref data) => match data.fields {
352			Fields::Named(ref fields) if utils::filter_skip_named(fields).count() == 1 => {
353				let recurse = fields.named.iter().map(|f| {
354					let name_ident = &f.ident;
355					let val_or_default = val_or_default(f);
356					quote_spanned!(f.span()=> #name_ident: #val_or_default)
357				});
358				let field = utils::filter_skip_named(fields).next().expect("Exactly one field");
359				let field_name = &field.ident;
360				let constructor = quote!( #name { #( #recurse, )* });
361				(&field.ty, quote!(&self.#field_name), constructor)
362			},
363			Fields::Unnamed(ref fields) if utils::filter_skip_unnamed(fields).count() == 1 => {
364				let recurse = fields.unnamed.iter().map(|f| {
365					let val_or_default = val_or_default(f);
366					quote_spanned!(f.span()=> #val_or_default)
367				});
368				let (id, field) =
369					utils::filter_skip_unnamed(fields).next().expect("Exactly one field");
370				let id = syn::Index::from(id);
371				let constructor = quote!( #name(#( #recurse, )*));
372				(&field.ty, quote!(&self.#id), constructor)
373			},
374			_ =>
375				return Error::new(
376					data.fields.span(),
377					"Only structs with a single non-skipped field can derive CompactAs",
378				)
379				.to_compile_error()
380				.into(),
381		},
382		Data::Enum(syn::DataEnum { enum_token: syn::token::Enum { span }, .. }) |
383		Data::Union(syn::DataUnion { union_token: syn::token::Union { span }, .. }) =>
384			return Error::new(span, "Only structs can derive CompactAs").to_compile_error().into(),
385	};
386
387	let impl_block = quote! {
388		#[automatically_derived]
389		impl #impl_generics #crate_path::CompactAs for #name #ty_generics #where_clause {
390			type As = #inner_ty;
391			fn encode_as(&self) -> &#inner_ty {
392				#inner_field
393			}
394			fn decode_from(x: #inner_ty)
395				-> ::core::result::Result<#name #ty_generics, #crate_path::Error>
396			{
397				::core::result::Result::Ok(#constructor)
398			}
399		}
400
401		#[automatically_derived]
402		impl #impl_generics From<#crate_path::Compact<#name #ty_generics>>
403			for #name #ty_generics #where_clause
404		{
405			fn from(x: #crate_path::Compact<#name #ty_generics>) -> #name #ty_generics {
406				x.0
407			}
408		}
409	};
410
411	wrap_with_dummy_const(input, impl_block)
412}
413
414/// Derive `parity_scale_codec::MaxEncodedLen` for struct and enum.
415///
416/// # Top level attribute
417///
418/// By default the macro will try to bound the types needed to implement `MaxEncodedLen`, but the
419/// bounds can be specified manually with the top level attribute:
420/// ```
421/// # use parity_scale_codec_derive::Encode;
422/// # use parity_scale_codec::MaxEncodedLen;
423/// # #[derive(Encode, MaxEncodedLen)]
424/// #[codec(mel_bound(T: MaxEncodedLen))]
425/// # struct MyWrapper<T>(T);
426/// ```
427#[cfg(feature = "max-encoded-len")]
428#[proc_macro_derive(MaxEncodedLen, attributes(max_encoded_len_mod))]
429pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
430	max_encoded_len::derive_max_encoded_len(input)
431}