1#![allow(clippy::manual_unwrap_or_default)]
19
20extern crate alloc;
21
22use alloc::string::ToString;
23use darling::FromAttributes;
24use proc_macro2::{Span, TokenStream as TokenStream2};
25use quote::quote;
26use syn::{parse_macro_input, punctuated::Punctuated, DeriveInput};
27
/// Name of the container-level attribute (`#[decode_as_type(...)]`) that `TopLevelAttrs::parse`
/// reads options from; attributes with any other name are skipped there.
const ATTR_NAME: &str = "decode_as_type";
29
30#[proc_macro_derive(DecodeAsType, attributes(decode_as_type, codec))]
32pub fn derive_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
33 let input = parse_macro_input!(input as DeriveInput);
34
35 let attrs = match TopLevelAttrs::parse(&input.attrs) {
37 Ok(attrs) => attrs,
38 Err(e) => return e.write_errors().into(),
39 };
40
41 derive_with_attrs(attrs, input).into()
42}
43
44fn derive_with_attrs(attrs: TopLevelAttrs, input: DeriveInput) -> TokenStream2 {
45 let visibility = &input.vis;
46 match &input.data {
48 syn::Data::Enum(details) => generate_enum_impl(attrs, visibility, &input, details),
49 syn::Data::Struct(details) => generate_struct_impl(attrs, visibility, &input, details),
50 syn::Data::Union(_) => syn::Error::new(
51 input.ident.span(),
52 "Unions are not supported by the DecodeAsType macro",
53 )
54 .into_compile_error(),
55 }
56}
57
58fn generate_enum_impl(
59 attrs: TopLevelAttrs,
60 visibility: &syn::Visibility,
61 input: &DeriveInput,
62 details: &syn::DataEnum,
63) -> TokenStream2 {
64 let path_to_scale_decode = &attrs.crate_path;
65 let path_to_type: syn::Path = input.ident.clone().into();
66 let variant_names = details.variants.iter().map(|v| v.ident.to_string());
67
68 let generic_types = handle_generics(&attrs, input.generics.clone());
69 let ty_generics = generic_types.ty_generics();
70 let impl_generics = generic_types.impl_generics();
71 let visitor_where_clause = generic_types.visitor_where_clause();
72 let visitor_ty_generics = generic_types.visitor_ty_generics();
73 let visitor_impl_generics = generic_types.visitor_impl_generics();
74 let visitor_phantomdata_type = generic_types.visitor_phantomdata_type();
75 let type_resolver_ident = generic_types.type_resolver_ident();
76
77 let variant_ifs = details.variants.iter().map(|variant| {
80 let variant_ident = &variant.ident;
81 let variant_name = variant_ident.to_string();
82
83 let visit_one_variant_body = match &variant.fields {
84 syn::Fields::Named(fields) => {
85 let (
86 field_count,
87 field_composite_keyvals,
88 field_tuple_keyvals
89 ) = named_field_keyvals(path_to_scale_decode, fields);
90
91 quote!{
92 let fields = value.fields();
93 return if fields.has_unnamed_fields() {
94 if fields.remaining() != #field_count {
95 return Err(#path_to_scale_decode::Error::new(#path_to_scale_decode::error::ErrorKind::WrongLength {
96 actual_len: fields.remaining(),
97 expected_len: #field_count
98 }));
99 }
100 let vals = fields;
101 Ok(#path_to_type::#variant_ident { #(#field_tuple_keyvals),* })
102 } else {
103 let vals: #path_to_scale_decode::BTreeMap<Option<&str>, _> = fields
104 .map(|res| res.map(|item| (item.name(), item)))
105 .collect::<Result<_, _>>()?;
106 Ok(#path_to_type::#variant_ident { #(#field_composite_keyvals),* })
107 }
108 }
109 },
110 syn::Fields::Unnamed(fields) => {
111 let (
112 field_count,
113 field_vals
114 ) = unnamed_field_vals(path_to_scale_decode, fields);
115
116 quote!{
117 let fields = value.fields();
118 if fields.remaining() != #field_count {
119 return Err(#path_to_scale_decode::Error::new(#path_to_scale_decode::error::ErrorKind::WrongLength {
120 actual_len: fields.remaining(),
121 expected_len: #field_count
122 }));
123 }
124 let vals = fields;
125 return Ok(#path_to_type::#variant_ident ( #(#field_vals),* ))
126 }
127 },
128 syn::Fields::Unit => {
129 quote!{
130 return Ok(#path_to_type::#variant_ident)
131 }
132 },
133 };
134
135 quote!{
136 if value.name() == #variant_name {
137 #visit_one_variant_body
138 }
139 }
140 });
141
142 quote!(
143 const _: () = {
144 #visibility struct Visitor #visitor_impl_generics (
145 ::core::marker::PhantomData<#visitor_phantomdata_type>
146 );
147
148 use #path_to_scale_decode::vec;
149 use #path_to_scale_decode::ToString;
150
151 impl #impl_generics #path_to_scale_decode::IntoVisitor for #path_to_type #ty_generics #visitor_where_clause {
152 type AnyVisitor<#type_resolver_ident: #path_to_scale_decode::TypeResolver> = Visitor #visitor_ty_generics;
153 fn into_visitor<#type_resolver_ident: #path_to_scale_decode::TypeResolver>() -> Self::AnyVisitor<#type_resolver_ident> {
154 Visitor(::core::marker::PhantomData)
155 }
156 }
157
158 impl #visitor_impl_generics #path_to_scale_decode::Visitor for Visitor #visitor_ty_generics #visitor_where_clause {
159 type Error = #path_to_scale_decode::Error;
160 type Value<'scale, 'info> = #path_to_type #ty_generics;
161 type TypeResolver = #type_resolver_ident;
162
163 fn visit_variant<'scale, 'info>(
164 self,
165 value: &mut #path_to_scale_decode::visitor::types::Variant<'scale, 'info, Self::TypeResolver>,
166 type_id: <Self::TypeResolver as #path_to_scale_decode::TypeResolver>::TypeId,
167 ) -> Result<Self::Value<'scale, 'info>, Self::Error> {
168 #(
169 #variant_ifs
170 )*
171 Err(#path_to_scale_decode::Error::new(#path_to_scale_decode::error::ErrorKind::CannotFindVariant {
172 got: value.name().to_string(),
173 expected: vec![#(#variant_names),*]
174 }))
175 }
176 fn visit_composite<'scale, 'info>(
178 self,
179 value: &mut #path_to_scale_decode::visitor::types::Composite<'scale, 'info, Self::TypeResolver>,
180 _type_id: <Self::TypeResolver as #path_to_scale_decode::TypeResolver>::TypeId,
181 ) -> Result<Self::Value<'scale, 'info>, Self::Error> {
182 if value.remaining() != 1 {
183 return self.visit_unexpected(#path_to_scale_decode::visitor::Unexpected::Composite);
184 }
185 value.decode_item(self).unwrap()
186 }
187 fn visit_tuple<'scale, 'info>(
188 self,
189 value: &mut #path_to_scale_decode::visitor::types::Tuple<'scale, 'info, Self::TypeResolver>,
190 _type_id: <Self::TypeResolver as #path_to_scale_decode::TypeResolver>::TypeId,
191 ) -> Result<Self::Value<'scale, 'info>, Self::Error> {
192 if value.remaining() != 1 {
193 return self.visit_unexpected(#path_to_scale_decode::visitor::Unexpected::Tuple);
194 }
195 value.decode_item(self).unwrap()
196 }
197 }
198 };
199 )
200}
201
202fn generate_struct_impl(
203 attrs: TopLevelAttrs,
204 visibility: &syn::Visibility,
205 input: &DeriveInput,
206 details: &syn::DataStruct,
207) -> TokenStream2 {
208 let path_to_scale_decode = &attrs.crate_path;
209 let path_to_type: syn::Path = input.ident.clone().into();
210
211 let generic_types = handle_generics(&attrs, input.generics.clone());
212 let ty_generics = generic_types.ty_generics();
213 let impl_generics = generic_types.impl_generics();
214 let visitor_where_clause = generic_types.visitor_where_clause();
215 let visitor_ty_generics = generic_types.visitor_ty_generics();
216 let visitor_impl_generics = generic_types.visitor_impl_generics();
217 let visitor_phantomdata_type = generic_types.visitor_phantomdata_type();
218 let type_resolver_ident = generic_types.type_resolver_ident();
219
220 let (visit_composite_body, visit_tuple_body) = match &details.fields {
223 syn::Fields::Named(fields) => {
224 let (field_count, field_composite_keyvals, field_tuple_keyvals) =
225 named_field_keyvals(path_to_scale_decode, fields);
226
227 (
228 quote! {
229 if value.has_unnamed_fields() {
230 return self.visit_tuple(&mut value.as_tuple(), type_id)
231 }
232
233 let vals: #path_to_scale_decode::BTreeMap<Option<&str>, _> =
234 value.map(|res| res.map(|item| (item.name(), item))).collect::<Result<_, _>>()?;
235
236 Ok(#path_to_type { #(#field_composite_keyvals),* })
237 },
238 quote! {
239 if value.remaining() != #field_count {
240 return Err(#path_to_scale_decode::Error::new(#path_to_scale_decode::error::ErrorKind::WrongLength { actual_len: value.remaining(), expected_len: #field_count }));
241 }
242
243 let vals = value;
244
245 Ok(#path_to_type { #(#field_tuple_keyvals),* })
246 },
247 )
248 }
249 syn::Fields::Unnamed(fields) => {
250 let (field_count, field_vals) = unnamed_field_vals(path_to_scale_decode, fields);
251
252 (
253 quote! {
254 self.visit_tuple(&mut value.as_tuple(), type_id)
255 },
256 quote! {
257 if value.remaining() != #field_count {
258 return Err(#path_to_scale_decode::Error::new(#path_to_scale_decode::error::ErrorKind::WrongLength { actual_len: value.remaining(), expected_len: #field_count }));
259 }
260
261 let vals = value;
262
263 Ok(#path_to_type ( #( #field_vals ),* ))
264 },
265 )
266 }
267 syn::Fields::Unit => (
268 quote! {
269 self.visit_tuple(&mut value.as_tuple(), type_id)
270 },
271 quote! {
272 if value.remaining() > 0 {
273 return Err(#path_to_scale_decode::Error::new(#path_to_scale_decode::error::ErrorKind::WrongLength { actual_len: value.remaining(), expected_len: 0 }));
274 }
275 Ok(#path_to_type)
276 },
277 ),
278 };
279
280 quote!(
281 const _: () = {
282 #visibility struct Visitor #visitor_impl_generics (
283 ::core::marker::PhantomData<#visitor_phantomdata_type>
284 );
285
286 use #path_to_scale_decode::vec;
287 use #path_to_scale_decode::ToString;
288
289 impl #impl_generics #path_to_scale_decode::IntoVisitor for #path_to_type #ty_generics #visitor_where_clause {
290 type AnyVisitor<#type_resolver_ident: #path_to_scale_decode::TypeResolver> = Visitor #visitor_ty_generics;
291 fn into_visitor<#type_resolver_ident: #path_to_scale_decode::TypeResolver>() -> Self::AnyVisitor<#type_resolver_ident> {
292 Visitor(::core::marker::PhantomData)
293 }
294 }
295
296 impl #visitor_impl_generics #path_to_scale_decode::Visitor for Visitor #visitor_ty_generics #visitor_where_clause {
297 type Error = #path_to_scale_decode::Error;
298 type Value<'scale, 'info> = #path_to_type #ty_generics;
299 type TypeResolver = #type_resolver_ident;
300
301 fn visit_composite<'scale, 'info>(
302 self,
303 value: &mut #path_to_scale_decode::visitor::types::Composite<'scale, 'info, Self::TypeResolver>,
304 type_id: <Self::TypeResolver as #path_to_scale_decode::TypeResolver>::TypeId,
305 ) -> Result<Self::Value<'scale, 'info>, Self::Error> {
306 #visit_composite_body
307 }
308 fn visit_tuple<'scale, 'info>(
309 self,
310 value: &mut #path_to_scale_decode::visitor::types::Tuple<'scale, 'info, Self::TypeResolver>,
311 type_id: <Self::TypeResolver as #path_to_scale_decode::TypeResolver>::TypeId,
312 ) -> Result<Self::Value<'scale, 'info>, Self::Error> {
313 #visit_tuple_body
314 }
315 }
316
317 impl #impl_generics #path_to_scale_decode::DecodeAsFields for #path_to_type #ty_generics #visitor_where_clause {
318 fn decode_as_fields<'info, R: #path_to_scale_decode::TypeResolver>(
319 input: &mut &[u8],
320 fields: &mut dyn #path_to_scale_decode::FieldIter<'info, R::TypeId>,
321 types: &'info R
322 ) -> Result<Self, #path_to_scale_decode::Error>
323 {
324 let mut composite = #path_to_scale_decode::visitor::types::Composite::new(core::iter::empty(), input, fields, types, false);
325 use #path_to_scale_decode::{ Visitor, IntoVisitor };
326 let val = <#path_to_type #ty_generics>::into_visitor().visit_composite(&mut composite, Default::default());
327
328 composite.skip_decoding()?;
330 *input = composite.bytes_from_undecoded();
331
332 val.map_err(From::from)
333 }
334 }
335 };
336 )
337}
338
339fn named_field_keyvals<'f>(
341 path_to_scale_decode: &'f syn::Path,
342 fields: &'f syn::FieldsNamed,
343) -> (usize, impl Iterator<Item = TokenStream2> + 'f, impl Iterator<Item = TokenStream2> + 'f) {
344 let field_keyval_impls = fields.named.iter().map(move |f| {
345 let field_attrs = FieldAttrs::from_attributes(&f.attrs).unwrap_or_default();
346 let field_ident = f.ident.as_ref().expect("named field has ident");
347 let field_name = field_ident.to_string();
348 let skip_field = field_attrs.skip;
349
350 if skip_field {
352 return (
353 false,
354 quote!(#field_ident: ::core::default::Default::default()),
355 quote!(#field_ident: ::core::default::Default::default())
356 )
357 }
358
359 (
360 true,
362 quote!(#field_ident: {
364 let val = vals
365 .get(&Some(#field_name))
366 .ok_or_else(|| #path_to_scale_decode::Error::new(#path_to_scale_decode::error::ErrorKind::CannotFindField { name: #field_name.to_string() }))?
367 .clone();
368 val.decode_as_type().map_err(|e| e.at_field(#field_name))?
369 }),
370 quote!(#field_ident: {
372 let val = vals.next().expect("field count should have been checked already on tuple type; please file a bug report")?;
373 val.decode_as_type().map_err(|e| e.at_field(#field_name))?
374 })
375 )
376 });
377
378 let field_count = field_keyval_impls.clone().filter(|f| f.0).count();
380 let field_composite_keyvals = field_keyval_impls.clone().map(|v| v.1);
381 let field_tuple_keyvals = field_keyval_impls.map(|v| v.2);
382
383 (field_count, field_composite_keyvals, field_tuple_keyvals)
384}
385
386fn unnamed_field_vals<'f>(
388 _path_to_scale_decode: &'f syn::Path,
389 fields: &'f syn::FieldsUnnamed,
390) -> (usize, impl Iterator<Item = TokenStream2> + 'f) {
391 let field_val_impls = fields.unnamed.iter().enumerate().map(|(idx, f)| {
392 let field_attrs = FieldAttrs::from_attributes(&f.attrs).unwrap_or_default();
393 let skip_field = field_attrs.skip;
394
395 if skip_field {
397 return (false, quote!(::core::default::Default::default()));
398 }
399
400 (
401 true,
403 quote!({
405 let val = vals.next().expect("field count should have been checked already on tuple type; please file a bug report")?;
406 val.decode_as_type().map_err(|e| e.at_idx(#idx))?
407 }),
408 )
409 });
410
411 let field_count = field_val_impls.clone().filter(|f| f.0).count();
413 let field_vals = field_val_impls.map(|v| v.1);
414
415 (field_count, field_vals)
416}
417
418fn handle_generics(attrs: &TopLevelAttrs, generics: syn::Generics) -> GenericTypes {
419 let path_to_crate = &attrs.crate_path;
420
421 let type_resolver_ident =
422 syn::Ident::new(GenericTypes::TYPE_RESOLVER_IDENT_STR, Span::call_site());
423
424 let visitor_where_clause = {
426 let (_, _, where_clause) = generics.split_for_impl();
427 let mut where_clause = where_clause.cloned().unwrap_or(syn::parse_quote!(where));
428 if let Some(where_predicates) = &attrs.trait_bounds {
429 where_clause.predicates.extend(where_predicates.clone());
431 } else {
432 for param in generics.type_params() {
434 let ty = ¶m.ident;
435 where_clause.predicates.push(syn::parse_quote!(#ty: #path_to_crate::IntoVisitor));
436 }
437 }
438 where_clause
439 };
440
441 let visitor_phantomdata_type = {
443 let tys = generics.params.iter().filter_map::<syn::Type, _>(|p| match p {
444 syn::GenericParam::Type(ty) => {
445 let ty = &ty.ident;
446 Some(syn::parse_quote!(#ty))
447 }
448 syn::GenericParam::Lifetime(lt) => {
449 let lt = <.lifetime;
450 Some(syn::parse_quote!(& #lt ()))
451 }
452 syn::GenericParam::Const(_) => None,
454 });
455
456 let tys = tys.chain(core::iter::once(syn::parse_quote!(#type_resolver_ident)));
458
459 syn::parse_quote!( (#( #tys, )*) )
460 };
461
462 let visitor_generics = {
464 let mut type_generics = generics.clone();
465 let type_resolver_generic_param: syn::GenericParam =
466 syn::parse_quote!(#type_resolver_ident: #path_to_crate::TypeResolver);
467
468 type_generics.params.push(type_resolver_generic_param);
469 type_generics
470 };
471
472 let type_generics = generics;
474
475 GenericTypes {
476 type_generics,
477 type_resolver_ident,
478 visitor_generics,
479 visitor_phantomdata_type,
480 visitor_where_clause,
481 }
482}
483
484struct GenericTypes {
485 type_resolver_ident: syn::Ident,
486 type_generics: syn::Generics,
487 visitor_generics: syn::Generics,
488 visitor_where_clause: syn::WhereClause,
489 visitor_phantomdata_type: syn::Type,
490}
491
492impl GenericTypes {
493 const TYPE_RESOLVER_IDENT_STR: &'static str = "ScaleDecodeTypeResolver";
494
495 pub fn ty_generics(&self) -> syn::TypeGenerics<'_> {
496 let (_, ty_generics, _) = self.type_generics.split_for_impl();
497 ty_generics
498 }
499 pub fn impl_generics(&self) -> syn::ImplGenerics<'_> {
500 let (impl_generics, _, _) = self.type_generics.split_for_impl();
501 impl_generics
502 }
503 pub fn visitor_where_clause(&self) -> &syn::WhereClause {
504 &self.visitor_where_clause
505 }
506 pub fn visitor_ty_generics(&self) -> syn::TypeGenerics<'_> {
507 let (_, ty_generics, _) = self.visitor_generics.split_for_impl();
508 ty_generics
509 }
510 pub fn visitor_impl_generics(&self) -> syn::ImplGenerics<'_> {
511 let (impl_generics, _, _) = self.visitor_generics.split_for_impl();
512 impl_generics
513 }
514 pub fn visitor_phantomdata_type(&self) -> &syn::Type {
515 &self.visitor_phantomdata_type
516 }
517 pub fn type_resolver_ident(&self) -> &syn::Ident {
518 &self.type_resolver_ident
519 }
520}
521
522struct TopLevelAttrs {
523 crate_path: syn::Path,
525 trait_bounds: Option<Punctuated<syn::WherePredicate, syn::Token!(,)>>,
527}
528
529impl TopLevelAttrs {
530 fn parse(attrs: &[syn::Attribute]) -> darling::Result<Self> {
531 use darling::FromMeta;
532
533 #[derive(FromMeta)]
534 struct TopLevelAttrsInner {
535 #[darling(default)]
536 crate_path: Option<syn::Path>,
537 #[darling(default)]
538 trait_bounds: Option<Punctuated<syn::WherePredicate, syn::Token!(,)>>,
539 }
540
541 let mut res =
542 TopLevelAttrs { crate_path: syn::parse_quote!(::scale_decode), trait_bounds: None };
543
544 for attr in attrs {
546 if !attr.path().is_ident(ATTR_NAME) {
547 continue;
548 }
549 let meta = &attr.meta;
550 let parsed_attrs = TopLevelAttrsInner::from_meta(meta)?;
551
552 res.trait_bounds = parsed_attrs.trait_bounds;
553 if let Some(crate_path) = parsed_attrs.crate_path {
554 res.crate_path = crate_path;
555 }
556 }
557
558 Ok(res)
559 }
560}
561
/// Per-field options, parsed by darling from `#[decode_as_type(...)]` and `#[codec(...)]`
/// attributes on a field.
#[derive(Debug, FromAttributes, Default)]
#[darling(attributes(decode_as_type, codec))]
struct FieldAttrs {
    /// `skip`: the field is not decoded. It is initialized with `Default::default()` and is not
    /// counted towards the number of encoded values expected for the type.
    #[darling(default)]
    skip: bool,
}