encoding_derive_helpers/
encode.rs

1// LNP/BP client-side-validation foundation libraries implementing LNPBP
2// specifications & standards (LNPBP-4, 7, 8, 9, 42, 81)
3//
4// Written in 2019-2021 by
5//     Dr. Maxim Orlovsky <orlovsky@pandoracore.com>
6//
7// To the extent possible under law, the author(s) have dedicated all
8// copyright and related and neighboring rights to this software to
9// the public domain worldwide. This software is distributed without
10// any warranty.
11//
12// You should have received a copy of the Apache 2.0 License along with this
13// software. If not, see <https://opensource.org/licenses/Apache-2.0>.
14
15use amplify::proc_attr::ParametrizedAttr;
16use proc_macro2::{Span, TokenStream as TokenStream2};
17use quote::{ToTokens, TokenStreamExt};
18use syn::spanned::Spanned;
19use syn::{
20    Data, DataEnum, DataStruct, DeriveInput, Error, Field, Fields, Ident,
21    ImplGenerics, Index, Result, TypeGenerics, WhereClause,
22};
23
24use crate::param::{EncodingDerive, TlvDerive, CRATE, REPR, USE_TLV};
25
26/// Performs actual derivation of the encode trait using the provided
27/// information about trait parameters and requirements for TLV support.
28///
29/// You will find example of the function use in the
30/// [crate top-level documentation][crate].
31pub fn encode_derive(
32    attr_name: &'static str,
33    crate_name: Ident,
34    trait_name: Ident,
35    encode_name: Ident,
36    serialize_name: Ident,
37    input: DeriveInput,
38    tlv_encoding: bool,
39) -> Result<TokenStream2> {
40    let (impl_generics, ty_generics, where_clause) =
41        input.generics.split_for_impl();
42    let ident_name = &input.ident;
43
44    let global_param = ParametrizedAttr::with(attr_name, &input.attrs)?;
45
46    match input.data {
47        Data::Struct(data) => encode_struct_impl(
48            attr_name,
49            &crate_name,
50            &trait_name,
51            &encode_name,
52            &serialize_name,
53            data,
54            ident_name,
55            global_param,
56            impl_generics,
57            ty_generics,
58            where_clause,
59            tlv_encoding,
60        ),
61        Data::Enum(data) => encode_enum_impl(
62            attr_name,
63            &crate_name,
64            &trait_name,
65            &encode_name,
66            &serialize_name,
67            data,
68            ident_name,
69            global_param,
70            impl_generics,
71            ty_generics,
72            where_clause,
73        ),
74        Data::Union(_) => Err(Error::new_spanned(
75            &input,
76            format!("Deriving `{}` is not supported in unions", trait_name),
77        )),
78    }
79}
80
81#[allow(clippy::too_many_arguments)]
82fn encode_struct_impl(
83    attr_name: &'static str,
84    crate_name: &Ident,
85    trait_name: &Ident,
86    encode_name: &Ident,
87    serialize_name: &Ident,
88    data: DataStruct,
89    ident_name: &Ident,
90    mut global_param: ParametrizedAttr,
91    impl_generics: ImplGenerics,
92    ty_generics: TypeGenerics,
93    where_clause: Option<&WhereClause>,
94    tlv_encoding: bool,
95) -> Result<TokenStream2> {
96    let encoding = EncodingDerive::with(
97        &mut global_param,
98        crate_name,
99        true,
100        false,
101        false,
102    )?;
103
104    if !tlv_encoding && encoding.tlv.is_some() {
105        return Err(Error::new(
106            ident_name.span(),
107            format!("TLV extensions are not allowed in `{}`", attr_name),
108        ));
109    }
110
111    let inner_impl = match data.fields {
112        Fields::Named(ref fields) => encode_fields_impl(
113            attr_name,
114            crate_name,
115            trait_name,
116            encode_name,
117            serialize_name,
118            &fields.named,
119            global_param,
120            false,
121            tlv_encoding,
122        )?,
123        Fields::Unnamed(ref fields) => encode_fields_impl(
124            attr_name,
125            crate_name,
126            trait_name,
127            encode_name,
128            serialize_name,
129            &fields.unnamed,
130            global_param,
131            false,
132            tlv_encoding,
133        )?,
134        Fields::Unit => quote! { Ok(0) },
135    };
136
137    let import = encoding.use_crate;
138
139    Ok(quote! {
140        impl #impl_generics #import::#trait_name for #ident_name #ty_generics #where_clause {
141            fn #encode_name<E: ::std::io::Write>(&self, mut e: E) -> ::core::result::Result<usize, #import::Error> {
142                use #import::#trait_name;
143                let mut len = 0;
144                let data = self;
145                #inner_impl
146                Ok(len)
147            }
148        }
149    })
150}
151
152#[allow(clippy::too_many_arguments)]
153fn encode_enum_impl(
154    attr_name: &'static str,
155    crate_name: &Ident,
156    trait_name: &Ident,
157    encode_name: &Ident,
158    serialize_name: &Ident,
159    data: DataEnum,
160    ident_name: &Ident,
161    mut global_param: ParametrizedAttr,
162    impl_generics: ImplGenerics,
163    ty_generics: TypeGenerics,
164    where_clause: Option<&WhereClause>,
165) -> Result<TokenStream2> {
166    let encoding =
167        EncodingDerive::with(&mut global_param, crate_name, true, true, false)?;
168    let repr = encoding.repr;
169
170    let mut inner_impl = TokenStream2::new();
171
172    for (order, variant) in data.variants.iter().enumerate() {
173        let mut local_param =
174            ParametrizedAttr::with(attr_name, &variant.attrs)?;
175
176        // First, test individual attribute
177        let _ = EncodingDerive::with(
178            &mut local_param,
179            crate_name,
180            false,
181            true,
182            false,
183        )?;
184        // Second, combine global and local together
185        let mut combined = global_param.clone().merged(local_param.clone())?;
186        combined.args.remove(REPR);
187        combined.args.remove(CRATE);
188        let encoding = EncodingDerive::with(
189            &mut combined,
190            crate_name,
191            false,
192            true,
193            false,
194        )?;
195
196        if encoding.skip {
197            continue;
198        }
199
200        let captures = variant
201            .fields
202            .iter()
203            .enumerate()
204            .map(|(i, f)| {
205                f.ident.as_ref().map(Ident::to_token_stream).unwrap_or_else(
206                    || {
207                        Ident::new(&format!("_{}", i), Span::call_site())
208                            .to_token_stream()
209                    },
210                )
211            })
212            .collect::<Vec<_>>();
213
214        let (field_impl, bra_captures_ket) = match variant.fields {
215            Fields::Named(ref fields) => (
216                encode_fields_impl(
217                    attr_name,
218                    crate_name,
219                    trait_name,
220                    encode_name,
221                    serialize_name,
222                    &fields.named,
223                    local_param,
224                    true,
225                    false,
226                )?,
227                quote! { { #( #captures ),* } },
228            ),
229            Fields::Unnamed(ref fields) => (
230                encode_fields_impl(
231                    attr_name,
232                    crate_name,
233                    trait_name,
234                    encode_name,
235                    serialize_name,
236                    &fields.unnamed,
237                    local_param,
238                    true,
239                    false,
240                )?,
241                quote! { ( #( #captures ),* ) },
242            ),
243            Fields::Unit => (TokenStream2::new(), TokenStream2::new()),
244        };
245
246        let captures = match captures.len() {
247            0 => quote! {},
248            _ => quote! { let data = ( #( #captures ),* , ); },
249        };
250
251        let ident = &variant.ident;
252        let value = match (encoding.value, encoding.by_order) {
253            (Some(val), _) => val.to_token_stream(),
254            (None, true) => Index::from(order as usize).to_token_stream(),
255            (None, false) => quote! { Self::#ident },
256        };
257
258        inner_impl.append_all(quote_spanned! { variant.span() =>
259            Self::#ident #bra_captures_ket => {
260                len += (#value as #repr).#encode_name(&mut e)?;
261                #captures
262                #field_impl
263            }
264        });
265    }
266
267    let import = encoding.use_crate;
268
269    Ok(quote! {
270        impl #impl_generics #import::#trait_name for #ident_name #ty_generics #where_clause {
271            #[inline]
272            fn #encode_name<E: ::std::io::Write>(&self, mut e: E) -> ::core::result::Result<usize, #import::Error> {
273                use #import::#trait_name;
274                let mut len = 0;
275                match self {
276                    #inner_impl
277                }
278                Ok(len)
279            }
280        }
281    })
282}
283
284#[allow(clippy::too_many_arguments)]
285fn encode_fields_impl<'a>(
286    attr_name: &'static str,
287    crate_name: &Ident,
288    _trait_name: &Ident,
289    encode_name: &Ident,
290    serialize_name: &Ident,
291    fields: impl IntoIterator<Item = &'a Field>,
292    mut parent_param: ParametrizedAttr,
293    is_enum: bool,
294    tlv_encoding: bool,
295) -> Result<TokenStream2> {
296    let mut stream = TokenStream2::new();
297
298    let use_tlv = parent_param.args.contains_key(USE_TLV);
299    parent_param.args.remove(CRATE);
300    parent_param.args.remove(USE_TLV);
301
302    if !tlv_encoding && use_tlv {
303        return Err(Error::new(
304            Span::call_site(),
305            format!("TLV extensions are not allowed in `{}`", attr_name),
306        ));
307    }
308
309    let mut strict_fields = vec![];
310    let mut tlv_fields = bmap! {};
311    let mut tlv_aggregator = None;
312
313    for (index, field) in fields.into_iter().enumerate() {
314        let mut local_param = ParametrizedAttr::with(attr_name, &field.attrs)?;
315
316        // First, test individual attribute
317        let _ = EncodingDerive::with(
318            &mut local_param,
319            crate_name,
320            false,
321            is_enum,
322            use_tlv,
323        )?;
324        // Second, combine global and local together
325        let mut combined = parent_param.clone().merged(local_param)?;
326        let encoding = EncodingDerive::with(
327            &mut combined,
328            crate_name,
329            false,
330            is_enum,
331            use_tlv,
332        )?;
333
334        if encoding.skip {
335            continue;
336        }
337
338        let index = Index::from(index).to_token_stream();
339        let name = if is_enum {
340            index
341        } else {
342            field
343                .ident
344                .as_ref()
345                .map(Ident::to_token_stream)
346                .unwrap_or(index)
347        };
348
349        encoding.tlv.unwrap_or(TlvDerive::None).process(
350            field,
351            name,
352            &mut strict_fields,
353            &mut tlv_fields,
354            &mut tlv_aggregator,
355        )?;
356    }
357
358    for name in strict_fields {
359        stream.append_all(quote_spanned! { Span::call_site() =>
360            len += data.#name.#encode_name(&mut e)?;
361        })
362    }
363
364    if use_tlv {
365        stream.append_all(quote_spanned! { Span::call_site() =>
366            let mut tlvs = internet2::tlv::Stream::default();
367        });
368        for (type_no, (name, optional)) in tlv_fields {
369            if optional {
370                stream.append_all(quote_spanned! { Span::call_site() =>
371                    if let Some(val) = &data.#name {
372                        tlvs.insert(#type_no.into(), val.#serialize_name()?);
373                    }
374                });
375            } else {
376                stream.append_all(quote_spanned! { Span::call_site() =>
377                    if data.#name.iter().count() > 0 {
378                        tlvs.insert(#type_no.into(), data.#name.#serialize_name()?);
379                    }
380                });
381            }
382        }
383        if let Some(name) = tlv_aggregator {
384            stream.append_all(quote_spanned! { Span::call_site() =>
385                for (type_no, val) in &data.#name {
386                    tlvs.insert(*type_no, val);
387                }
388            });
389        }
390
391        stream.append_all(quote_spanned! { Span::call_site() =>
392            len += tlvs.#encode_name(&mut e)?;
393        })
394    }
395
396    Ok(stream)
397}