encoding_derive_helpers/
decode.rs

1// LNP/BP client-side-validation foundation libraries implementing LNPBP
2// specifications & standards (LNPBP-4, 7, 8, 9, 42, 81)
3//
4// Written in 2019-2021 by
5//     Dr. Maxim Orlovsky <orlovsky@pandoracore.com>
6//
7// To the extent possible under law, the author(s) have dedicated all
8// copyright and related and neighboring rights to this software to
9// the public domain worldwide. This software is distributed without
10// any warranty.
11//
12// You should have received a copy of the Apache 2.0 License along with this
13// software. If not, see <https://opensource.org/licenses/Apache-2.0>.
14
15use amplify::proc_attr::ParametrizedAttr;
16use proc_macro2::{Span, TokenStream as TokenStream2};
17use quote::{ToTokens, TokenStreamExt};
18use syn::spanned::Spanned;
19use syn::{
20    Data, DataEnum, DataStruct, DeriveInput, Error, Field, Fields, Ident,
21    ImplGenerics, Index, LitStr, Result, TypeGenerics, WhereClause,
22};
23
24use crate::param::{EncodingDerive, TlvDerive, CRATE, REPR, USE_TLV};
25
26/// Performs actual derivation of the decode trait using the provided
27/// information about trait parameters and requirements for TLV support.
28///
29/// You will find example of the function use in the
30/// [crate top-level documentation][crate].
31pub fn decode_derive(
32    attr_name: &'static str,
33    crate_name: Ident,
34    trait_name: Ident,
35    decode_name: Ident,
36    deserialize_name: Ident,
37    input: DeriveInput,
38    tlv_encoding: bool,
39) -> Result<TokenStream2> {
40    let (impl_generics, ty_generics, where_clause) =
41        input.generics.split_for_impl();
42    let ident_name = &input.ident;
43
44    let global_param = ParametrizedAttr::with(attr_name, &input.attrs)?;
45
46    match input.data {
47        Data::Struct(data) => decode_struct_impl(
48            attr_name,
49            &crate_name,
50            &trait_name,
51            &decode_name,
52            &deserialize_name,
53            data,
54            ident_name,
55            global_param,
56            impl_generics,
57            ty_generics,
58            where_clause,
59            tlv_encoding,
60        ),
61        Data::Enum(data) => decode_enum_impl(
62            attr_name,
63            &crate_name,
64            &trait_name,
65            &decode_name,
66            &deserialize_name,
67            data,
68            ident_name,
69            global_param,
70            impl_generics,
71            ty_generics,
72            where_clause,
73        ),
74        Data::Union(_) => Err(Error::new_spanned(
75            &input,
76            format!("Deriving `{}` is not supported in unions", trait_name),
77        )),
78    }
79}
80
81#[allow(clippy::too_many_arguments)]
82fn decode_struct_impl(
83    attr_name: &'static str,
84    crate_name: &Ident,
85    trait_name: &Ident,
86    decode_name: &Ident,
87    deserialize_name: &Ident,
88    data: DataStruct,
89    ident_name: &Ident,
90    mut global_param: ParametrizedAttr,
91    impl_generics: ImplGenerics,
92    ty_generics: TypeGenerics,
93    where_clause: Option<&WhereClause>,
94    tlv_encoding: bool,
95) -> Result<TokenStream2> {
96    let encoding = EncodingDerive::with(
97        &mut global_param,
98        crate_name,
99        true,
100        false,
101        false,
102    )?;
103
104    if !tlv_encoding && encoding.tlv.is_some() {
105        return Err(Error::new(
106            ident_name.span(),
107            format!("TLV extensions are not allowed in `{}`", attr_name),
108        ));
109    }
110
111    let inner_impl = match data.fields {
112        Fields::Named(ref fields) => decode_fields_impl(
113            attr_name,
114            crate_name,
115            trait_name,
116            decode_name,
117            deserialize_name,
118            ident_name,
119            &fields.named,
120            global_param,
121            false,
122            tlv_encoding,
123        )?,
124        Fields::Unnamed(ref fields) => decode_fields_impl(
125            attr_name,
126            crate_name,
127            trait_name,
128            decode_name,
129            deserialize_name,
130            ident_name,
131            &fields.unnamed,
132            global_param,
133            false,
134            tlv_encoding,
135        )?,
136        Fields::Unit => quote! {},
137    };
138
139    let import = encoding.use_crate;
140
141    Ok(quote! {
142        impl #impl_generics #import::#trait_name for #ident_name #ty_generics #where_clause {
143            #[inline]
144            fn #decode_name<D: ::std::io::Read>(mut d: D) -> ::core::result::Result<Self, #import::Error> {
145                use #import::#trait_name;
146                #inner_impl
147            }
148        }
149    })
150}
151
152#[allow(clippy::too_many_arguments)]
153fn decode_enum_impl(
154    attr_name: &'static str,
155    crate_name: &Ident,
156    trait_name: &Ident,
157    decode_name: &Ident,
158    deserialize_name: &Ident,
159    data: DataEnum,
160    ident_name: &Ident,
161    mut global_param: ParametrizedAttr,
162    impl_generics: ImplGenerics,
163    ty_generics: TypeGenerics,
164    where_clause: Option<&WhereClause>,
165) -> Result<TokenStream2> {
166    let encoding =
167        EncodingDerive::with(&mut global_param, crate_name, true, true, false)?;
168    let repr = encoding.repr;
169
170    let mut inner_impl = TokenStream2::new();
171
172    for (order, variant) in data.variants.iter().enumerate() {
173        let mut local_param =
174            ParametrizedAttr::with(attr_name, &variant.attrs)?;
175
176        // First, test individual attribute
177        let _ = EncodingDerive::with(
178            &mut local_param,
179            crate_name,
180            false,
181            true,
182            false,
183        )?;
184        // Second, combine global and local together
185        let mut combined = global_param.clone().merged(local_param.clone())?;
186        combined.args.remove(REPR);
187        combined.args.remove(CRATE);
188        let encoding = EncodingDerive::with(
189            &mut combined,
190            crate_name,
191            false,
192            true,
193            false,
194        )?;
195
196        if encoding.skip {
197            continue;
198        }
199
200        let field_impl = match variant.fields {
201            Fields::Named(ref fields) => decode_fields_impl(
202                attr_name,
203                crate_name,
204                trait_name,
205                decode_name,
206                deserialize_name,
207                ident_name,
208                &fields.named,
209                local_param,
210                true,
211                false,
212            )?,
213            Fields::Unnamed(ref fields) => decode_fields_impl(
214                attr_name,
215                crate_name,
216                trait_name,
217                decode_name,
218                deserialize_name,
219                ident_name,
220                &fields.unnamed,
221                local_param,
222                true,
223                false,
224            )?,
225            Fields::Unit => TokenStream2::new(),
226        };
227
228        let ident = &variant.ident;
229        let value = match (encoding.value, encoding.by_order) {
230            (Some(val), _) => val.to_token_stream(),
231            (None, true) => Index::from(order as usize).to_token_stream(),
232            (None, false) => quote! { Self::#ident as #repr },
233        };
234
235        inner_impl.append_all(quote_spanned! { variant.span() =>
236            x if x == #value => {
237                Self::#ident {
238                    #field_impl
239                }
240            }
241        });
242    }
243
244    let import = encoding.use_crate;
245    let enum_name = LitStr::new(&ident_name.to_string(), Span::call_site());
246
247    Ok(quote! {
248        impl #impl_generics #import::#trait_name for #ident_name #ty_generics #where_clause {
249            fn #decode_name<D: ::std::io::Read>(mut d: D) -> ::core::result::Result<Self, #import::Error> {
250                use #import::#trait_name;
251                Ok(match #repr::#decode_name(&mut d)? {
252                    #inner_impl
253                    unknown => Err(#import::Error::EnumValueNotKnown(#enum_name, unknown as usize))?
254                })
255            }
256        }
257    })
258}
259
260#[allow(clippy::too_many_arguments)]
261fn decode_fields_impl<'a>(
262    attr_name: &'static str,
263    crate_name: &Ident,
264    trait_name: &Ident,
265    decode_name: &Ident,
266    deserialize_name: &Ident,
267    ident_name: &Ident,
268    fields: impl IntoIterator<Item = &'a Field>,
269    mut parent_param: ParametrizedAttr,
270    is_enum: bool,
271    tlv_encoding: bool,
272) -> Result<TokenStream2> {
273    let mut stream = TokenStream2::new();
274
275    let use_tlv = parent_param.args.contains_key(USE_TLV);
276    parent_param.args.remove(CRATE);
277    parent_param.args.remove(USE_TLV);
278
279    if !tlv_encoding && use_tlv {
280        return Err(Error::new(
281            Span::call_site(),
282            format!("TLV extensions are not allowed in `{}`", attr_name),
283        ));
284    }
285
286    let parent_attr = EncodingDerive::with(
287        &mut parent_param.clone(),
288        crate_name,
289        false,
290        is_enum,
291        false,
292    )?;
293    let import = parent_attr.use_crate;
294
295    let mut skipped_fields = vec![];
296    let mut strict_fields = vec![];
297    let mut tlv_fields = bmap! {};
298    let mut tlv_aggregator = None;
299
300    for (index, field) in fields.into_iter().enumerate() {
301        let mut local_param = ParametrizedAttr::with(attr_name, &field.attrs)?;
302
303        // First, test individual attribute
304        let _ = EncodingDerive::with(
305            &mut local_param,
306            crate_name,
307            false,
308            is_enum,
309            use_tlv,
310        )?;
311        // Second, combine global and local together
312        let mut combined = parent_param.clone().merged(local_param)?;
313        let encoding = EncodingDerive::with(
314            &mut combined,
315            crate_name,
316            false,
317            is_enum,
318            use_tlv,
319        )?;
320
321        let name = field
322            .ident
323            .as_ref()
324            .map(Ident::to_token_stream)
325            .unwrap_or_else(|| Index::from(index).to_token_stream());
326
327        if encoding.skip {
328            skipped_fields.push(name);
329            continue;
330        }
331
332        encoding.tlv.unwrap_or(TlvDerive::None).process(
333            field,
334            name,
335            &mut strict_fields,
336            &mut tlv_fields,
337            &mut tlv_aggregator,
338        )?;
339    }
340
341    for name in strict_fields {
342        stream.append_all(quote_spanned! { Span::call_site() =>
343            #name: #import::#trait_name::#decode_name(&mut d)?,
344        });
345    }
346
347    let mut default_fields = skipped_fields;
348    default_fields.extend(tlv_fields.values().map(|(n, _)| n).cloned());
349    default_fields.extend(tlv_aggregator.clone());
350    for name in default_fields {
351        stream.append_all(quote_spanned! { Span::call_site() =>
352            #name: Default::default(),
353        });
354    }
355
356    if !is_enum {
357        if use_tlv {
358            let mut inner = TokenStream2::new();
359            for (type_no, (name, optional)) in tlv_fields {
360                if optional {
361                    inner.append_all(quote_spanned! { Span::call_site() =>
362                        #type_no => s.#name = Some(#import::#trait_name::#deserialize_name(bytes)?),
363                    });
364                } else {
365                    inner.append_all(quote_spanned! { Span::call_site() =>
366                        #type_no => s.#name = #import::#trait_name::#deserialize_name(bytes)?,
367                    });
368                }
369            }
370
371            let aggregator = if let Some(ref tlv_aggregator) = tlv_aggregator {
372                quote_spanned! { Span::call_site() =>
373                    _ if *type_no % 2 == 0 => return Err(#import::TlvError::UnknownEvenType(*type_no).into()),
374                    _ => { s.#tlv_aggregator.insert(type_no, bytes); },
375                }
376            } else {
377                quote_spanned! { Span::call_site() =>
378                    _ if *type_no % 2 == 0 => return Err(#import::TlvError::UnknownEvenType(*type_no).into()),
379                    _ => {}
380                }
381            };
382
383            stream = quote_spanned! { Span::call_site() =>
384                let mut s = #ident_name { #stream };
385                let tlvs = internet2::tlv::Stream::#decode_name(&mut d)?;
386            };
387
388            stream.append_all(quote_spanned! { Span::call_site() =>
389                for (type_no, bytes) in tlvs {
390                    match *type_no as usize {
391                        #inner
392
393                        #aggregator
394                    }
395                }
396                Ok(s)
397            });
398        } else {
399            stream = quote_spanned! { Span::call_site() =>
400                Ok(#ident_name { #stream })
401            };
402        }
403    }
404
405    Ok(stream)
406}