bitcoin_consensus_derive/
lib.rs

//! proc-macro to derive a bitcoin `Encodable` and `Decodable` implementation for a struct
2use proc_macro::TokenStream;
3use proc_macro2::{Ident, TokenStream as TokenStream2};
4use quote::{quote, ToTokens};
5use syn::{
6    parse_macro_input, Data, DataStruct, DeriveInput, Fields, FieldsNamed, FieldsUnnamed, Index,
7    Type,
8};
9
10/// Derive `Encodable` for a struct.
11///
12/// Notes:
13/// - all number fields will be encoded in big endian, unlike rust-bitcoin
14/// - all `Option` fields will be encoded with a `bool` indicating whether the field is `Some` or `None`
15#[proc_macro_derive(Encodable)]
16pub fn derive_encodable(input: TokenStream) -> TokenStream {
17    let DeriveInput { ident, data, .. } = parse_macro_input!(input);
18    // handle struct
19    let output = if let Data::Struct(DataStruct {
20        fields: Fields::Named(FieldsNamed { named: fields, .. }),
21        ..
22    }) = data
23    {
24        let field_tokens = fields.iter().map(|field| {
25            let field_name = field.ident.as_ref().unwrap();
26            let field_type = &field.ty;
27            generate_field_encode(true, field_name, field_type)
28        });
29        let output = quote! {
30            impl bitcoin::consensus::Encodable for #ident {
31                fn consensus_encode<W: serde_bolt::io::Write + ?Sized>(
32                    &self,
33                    w: &mut W,
34                ) -> core::result::Result<usize, serde_bolt::io::Error> {
35                    let mut len = 0;
36                    #( #field_tokens )*
37                    Ok(len)
38                }
39            }
40        };
41        output
42    } else if let Data::Struct(DataStruct {
43        fields: Fields::Unnamed(FieldsUnnamed {
44            unnamed: fields, ..
45        }),
46        ..
47    }) = data
48    {
49        let field_tokens = fields.iter().enumerate().map(|(i, field)| {
50            let field_name = Index::from(i);
51            let field_type = &field.ty;
52            generate_field_encode(true, &field_name, field_type)
53        });
54        let output = quote! {
55            impl bitcoin::consensus::Encodable for #ident {
56                fn consensus_encode<W: serde_bolt::io::Write + ?Sized>(
57                    &self,
58                    w: &mut W,
59                ) -> core::result::Result<usize, serde_bolt::io::Error> {
60                    let mut len = 0;
61                    #( #field_tokens )*
62                    Ok(len)
63                }
64            }
65        };
66        output
67    } else {
68        unimplemented!()
69    };
70    let output = quote! {
71        #output
72
73        // delegate to `consensus_encode` because that is already big-endian for top-level fields
74        use serde_bolt::bitcoin::consensus::{Encodable as _, Decodable as _};
75        impl serde_bolt::BigEndianEncodable for #ident {
76            fn consensus_encode_be<W: serde_bolt::io::Write + ?Sized>(
77                &self,
78                w: &mut W,
79            ) -> core::result::Result<usize, bitcoin::consensus::encode::Error> {
80                self.consensus_encode(w).map_err(bitcoin::consensus::encode::Error::from)
81            }
82            fn consensus_decode_be<R: serde_bolt::io::Read + ?Sized>(
83                r: &mut R,
84            ) -> core::result::Result<Self, bitcoin::consensus::encode::Error> {
85                Self::consensus_decode(r)
86            }
87        }
88    };
89    output.into()
90}
91
92/// Derive `Decodable` for a struct.
93///
94/// See [`derive_encodable`] for notes.
95#[proc_macro_derive(Decodable)]
96pub fn derive_decodable(input: TokenStream) -> TokenStream {
97    let DeriveInput { ident, data, .. } = parse_macro_input!(input);
98    // handle struct
99    if let syn::Data::Struct(syn::DataStruct {
100        fields: syn::Fields::Named(FieldsNamed { named: fields, .. }),
101        ..
102    }) = data
103    {
104        let field_tokens = fields.iter().map(|field| {
105            let field_name = field.ident.as_ref().unwrap();
106            let field_type = &field.ty;
107            generate_field_decode(field_name, field_type)
108        });
109        let field_names = fields.iter().map(|field| {
110            let field_name = field.ident.as_ref().unwrap();
111            quote! {
112                #field_name,
113            }
114        });
115        let output = quote! {
116            impl bitcoin::consensus::Decodable for #ident {
117                fn consensus_decode<R: serde_bolt::io::Read + ?Sized>(
118                    r: &mut R,
119                ) -> core::result::Result<Self, bitcoin::consensus::encode::Error> {
120                    #( #field_tokens )*
121                    Ok(Self {
122                        #( #field_names )*
123                    })
124                }
125            }
126        };
127        return output.into();
128    } else if let Data::Struct(DataStruct {
129        fields: Fields::Unnamed(FieldsUnnamed {
130            unnamed: fields, ..
131        }),
132        ..
133    }) = data
134    {
135        let field_tokens = fields.iter().enumerate().map(|(i, field)| {
136            let field_name = Ident::new(&format!("field_{}", i), proc_macro2::Span::call_site());
137            let field_type = &field.ty;
138            generate_field_decode(&field_name, field_type)
139        });
140        let field_names = fields.iter().enumerate().map(|(i, _field)| {
141            let field_name = Ident::new(&format!("field_{}", i), proc_macro2::Span::call_site());
142            quote! {
143                #field_name,
144            }
145        });
146        let output = quote! {
147            impl bitcoin::consensus::Decodable for #ident {
148                fn consensus_decode<R: serde_bolt::io::Read + ?Sized>(
149                    r: &mut R,
150                ) -> core::result::Result<Self, bitcoin::consensus::encode::Error> {
151                    #( #field_tokens )*
152                    Ok(Self(
153                        #( #field_names )*
154                    ))
155                }
156            }
157        };
158        return output.into();
159    } else {
160        unimplemented!()
161    }
162}
163
164fn generate_field_encode(
165    is_self: bool,
166    field_name: &dyn ToTokens,
167    field_type: &Type,
168) -> TokenStream2 {
169    let field_access = if is_self {
170        quote! {
171            self.#field_name
172        }
173    } else {
174        quote! {
175            #field_name
176        }
177    };
178    if get_array_length(field_type).is_some() {
179        quote! {
180            for el in &#field_access {
181                len += el.consensus_encode(w)?;
182            }
183        }
184    } else if is_numeric_type(field_type) {
185        quote! {
186            let buf = #field_access.to_be_bytes();
187            w.write_all(&buf)?;
188        }
189    } else if let Some(inner_type) = extract_option_type(field_type) {
190        let inner_tokens = generate_field_encode(
191            false,
192            &Ident::new("inner", proc_macro2::Span::call_site()),
193            inner_type,
194        );
195        quote! {
196            len += #field_access.is_some().consensus_encode(w)?;
197            if let Some(inner) = &#field_access {
198                #inner_tokens
199            }
200        }
201    } else {
202        quote! {
203            len += #field_access.consensus_encode(w)?;
204        }
205    }
206}
207
208fn generate_field_decode(var: &Ident, field_type: &Type) -> TokenStream2 {
209    let output = if let Some(size) = get_array_length(field_type) {
210        quote! {
211            use core::convert::TryInto;
212            use alloc::vec::Vec;
213            let mut v = Vec::with_capacity(#size);
214            for _ in 0..#size {
215                let el = bitcoin::consensus::Decodable::consensus_decode(r)?;
216                v.push(el);
217            }
218            let #var = v.try_into().unwrap();
219        }
220    } else if is_numeric_type(field_type) {
221        quote! {
222            let mut buf = [0u8; core::mem::size_of::<#field_type>()];
223            r.read_exact(&mut buf)?;
224            let #var = #field_type::from_be_bytes(buf);
225        }
226    } else if let Some(inner_type) = extract_option_type(field_type) {
227        let inner_tokens = generate_field_decode(
228            &Ident::new("inner", proc_macro2::Span::call_site()),
229            inner_type,
230        );
231        quote! {
232            let is_some: bool = bitcoin::consensus::Decodable::consensus_decode(r)?;
233            let #var = if is_some {
234                let inner = {
235                    #inner_tokens
236                    inner
237                };
238                Some(inner)
239            } else {
240                None
241            };
242        }
243    } else {
244        quote! {
245            let #var = bitcoin::consensus::Decodable::consensus_decode(r)?;
246        }
247    };
248    output
249}
250
251fn extract_option_type(ty: &syn::Type) -> Option<&syn::Type> {
252    if let syn::Type::Path(syn::TypePath {
253        path: syn::Path { segments, .. },
254        ..
255    }) = ty
256    {
257        if let Some(syn::PathSegment {
258            ident,
259            arguments:
260                syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments { args, .. }),
261            ..
262        }) = segments.first()
263        {
264            if ident == "Option" {
265                if let Some(syn::GenericArgument::Type(inner_type)) = args.first() {
266                    return Some(inner_type);
267                }
268            }
269        }
270    }
271    None
272}
273
274fn is_numeric_type(ty: &syn::Type) -> bool {
275    if let syn::Type::Path(syn::TypePath {
276        path: syn::Path { segments, .. },
277        ..
278    }) = ty
279    {
280        if let Some(syn::PathSegment { ident, .. }) = segments.first() {
281            if ident == "u8"
282                || ident == "u16"
283                || ident == "u32"
284                || ident == "u64"
285                || ident == "u128"
286                || ident == "i8"
287                || ident == "i16"
288                || ident == "i32"
289                || ident == "i64"
290                || ident == "i128"
291            {
292                return true;
293            }
294        }
295    }
296    false
297}
298
299fn get_array_length(ty: &syn::Type) -> Option<usize> {
300    if let syn::Type::Array(syn::TypeArray { len, .. }) = ty {
301        if let syn::Expr::Lit(syn::ExprLit {
302            lit: syn::Lit::Int(int),
303            ..
304        }) = len
305        {
306            if let Ok(value) = int.base10_parse::<usize>() {
307                return Some(value);
308            }
309        }
310    }
311    None
312}