bolt_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Ident, TokenStream as TokenStream2};
3use quote::{format_ident, quote, ToTokens};
4use syn::{parse_macro_input, Data, DataEnum, DeriveInput, Error, Fields, Lit, Type};
5
6/// Serialize a message with a type prefix, in BOLT style
7#[proc_macro_derive(SerBolt, attributes(message_id))]
8pub fn derive_ser_bolt(input: TokenStream) -> TokenStream {
9    let input1 = input.clone();
10    let DeriveInput { ident, attrs, .. } = parse_macro_input!(input1);
11    let message_id = attrs
12        .into_iter()
13        .filter(|a| a.path.is_ident("message_id"))
14        .next()
15        .map(|a| a.tokens)
16        .unwrap_or_else(|| {
17            Error::new(ident.span(), "missing message_id attribute").into_compile_error()
18        });
19
20    let output = quote! {
21        impl SerBolt for #ident {
22            fn as_vec(&self) -> Vec<u8> {
23                let message_type = Self::TYPE;
24                let mut buf = message_type.to_be_bytes().to_vec();
25                let mut val_buf = to_vec(&self).expect("serialize");
26                buf.append(&mut val_buf);
27                buf
28            }
29
30            fn name(&self) -> &'static str {
31                stringify!(#ident)
32            }
33        }
34
35        impl DeBolt for #ident {
36            const TYPE: u16 = #message_id;
37            fn from_vec(mut ser: Vec<u8>) -> Result<Self> {
38                let mut cursor = serde_bolt::io::Cursor::new(&ser);
39                let message_type = cursor.read_u16_be()?;
40                if message_type != Self::TYPE {
41                    return Err(Error::UnexpectedType(message_type));
42                }
43                let res = Decodable::consensus_decode(&mut cursor)?;
44                if cursor.position() as usize != ser.len() {
45                    return Err(Error::TrailingBytes(cursor.position() as usize - ser.len(), Self::TYPE));
46                }
47                Ok(res)
48            }
49        }
50    };
51    output.into()
52}
53
54#[proc_macro_derive(SerBoltTlvOptions, attributes(tlv_tag))]
55pub fn derive_ser_bolt_tlv(input: TokenStream) -> TokenStream {
56    let input = parse_macro_input!(input as DeriveInput);
57    let ident = &input.ident;
58
59    let mut encode_entries: Vec<(u64, proc_macro2::TokenStream)> = Vec::new();
60    let mut decode_entries: Vec<(u64, proc_macro2::TokenStream)> = Vec::new();
61    let mut decode_temp_declarations: Vec<proc_macro2::TokenStream> = Vec::new();
62    let mut decode_fields: Vec<proc_macro2::TokenStream> = Vec::new();
63
64    // traverse the fields, build the needed lists
65    if let Data::Struct(data_struct) = &input.data {
66        if let Fields::Named(fields_named) = &data_struct.fields {
67            for field in fields_named.named.iter() {
68                let field_name = field.ident.as_ref().unwrap();
69                let field_type = &field.ty;
70                let var_name = format_ident!("{}", field_name);
71
72                if let Some(attr) = field.attrs.iter().find(|a| a.path.is_ident("tlv_tag")) {
73                    match attr.parse_meta() {
74                        Ok(syn::Meta::NameValue(meta_name_value)) => {
75                            if let Lit::Int(lit_int) = meta_name_value.lit {
76                                let tlv_tag = lit_int
77                                    .base10_parse::<u64>()
78                                    .expect("tlv_tag should be a valid u64");
79                                encode_entries.push((
80                                    tlv_tag,
81                                    quote! {
82                                        (#tlv_tag, self.#var_name.as_ref().map(|f| crate::model::SerBoltTlvWriteWrap(f)), option),
83                                    },
84                                ));
85                                decode_entries.push((
86                                    tlv_tag,
87                                    quote! {
88                                        (#tlv_tag, #var_name, option),
89                                    },
90                                ));
91                                let inner_type =
92                                    unwrap_option(field_type).expect("Option type expected");
93                                decode_temp_declarations.push(quote! {
94                                    let mut #var_name: Option<crate::model::SerBoltTlvReadWrap<#inner_type>> = None;
95                                });
96                                decode_fields.push(quote! {
97                                    #var_name: #var_name.map(|w| w.0),
98                                });
99                            } else {
100                                eprintln!("Warning: `tlv_tag` attribute value must be an integer.");
101                            }
102                        }
103                        _ => eprintln!("Failed to parse `tlv_tag` attribute."),
104                    }
105                } else {
106                    eprintln!("Warning: Missing `tlv_tag` attribute for field `{}`.", field_name);
107                }
108            }
109        }
110    }
111
112    // sort the entries into ascending order
113    encode_entries.sort_by_key(|entry| entry.0);
114    decode_entries.sort_by_key(|entry| entry.0);
115    let sorted_encode_entries: Vec<_> = encode_entries.iter().map(|(_tag, ts)| ts).collect();
116    let sorted_decode_entries: Vec<_> = decode_entries.iter().map(|(_tag, ts)| ts).collect();
117
118    // generate the output
119    let output = quote! {
120        impl Encodable for #ident {
121            fn consensus_encode<W: io::Write + ?Sized>(
122                &self,
123                w: &mut W,
124            ) -> core::result::Result<usize, io::Error> {
125                let mut mw = crate::util::MeasuredWriter::wrap(w);
126                lightning::encode_tlv_stream!(&mut mw, {
127                    #( #sorted_encode_entries )*
128                });
129                Ok(mw.len())
130            }
131        }
132
133        impl Decodable for #ident {
134            fn consensus_decode<R: io::Read + ?Sized>(
135                r: &mut R,
136            ) -> core::result::Result<Self, bitcoin::consensus::encode::Error> {
137                #(#decode_temp_declarations)*
138                (|| -> core::result::Result<_, _> {
139                    lightning::decode_tlv_stream!(r, {
140                        #( #sorted_decode_entries )*
141                    });
142                    Ok(())
143                })()
144                    .map_err(|_e| bitcoin::consensus::encode::Error::ParseFailed(
145                        "decode_tlv_stream failed"))?;
146                Ok(Self { #(#decode_fields)* })
147            }
148        }
149    };
150
151    output.into()
152}
153
154fn unwrap_option(field_type: &Type) -> Option<&Type> {
155    if let syn::Type::Path(syn::TypePath { path, .. }) = &field_type {
156        if path.segments.len() == 1 && path.segments[0].ident == "Option" {
157            if let syn::PathArguments::AngleBracketed(args) = &path.segments[0].arguments {
158                if let Some(syn::GenericArgument::Type(ty)) = args.args.first() {
159                    return Some(ty);
160                }
161            }
162        }
163    }
164    None
165}
166
167#[proc_macro_derive(ReadMessage)]
168pub fn derive_read_message(input: TokenStream) -> TokenStream {
169    let DeriveInput { ident, data, .. } = parse_macro_input!(input);
170    let mut vs = Vec::new();
171    let mut ts = Vec::new();
172    let mut error: Option<Error> = None;
173
174    if let Data::Enum(DataEnum { variants, .. }) = data {
175        for v in variants {
176            if v.ident == "Unknown" {
177                continue;
178            }
179            let vident = v.ident.clone();
180            let field = extract_single_type(&vident, &v.fields);
181            match field {
182                Ok(f) => {
183                    vs.push(vident);
184                    ts.push(f);
185                }
186                Err(e) => match error.as_mut() {
187                    None => error = Some(e),
188                    Some(o) => o.combine(e),
189                },
190            }
191        }
192    } else {
193        unimplemented!()
194    }
195
196    if let Some(error) = error {
197        return error.into_compile_error().into();
198    }
199
200    let output = quote! {
201        impl #ident {
202            fn read_message<R: Read + ?Sized>(mut reader: &mut R, message_type: u16) -> Result<Message> {
203                let message = match message_type {
204                    #(#vs::TYPE => Message::#ts(Decodable::consensus_decode(reader)?)),*,
205                    _ => Message::Unknown(Unknown { message_type }),
206                };
207                Ok(message)
208            }
209
210            fn message_name(message_type: u16) -> &'static str {
211                match message_type {
212                    #(#vs::TYPE => stringify!(#vs)),*,
213                    _ => "Unknown",
214                }
215            }
216
217            pub fn inner(&self) -> alloc::boxed::Box<&dyn SerBolt> {
218                match self {
219                    #(#ident::#vs(inner) => alloc::boxed::Box::new(inner)),*,
220                    _ => alloc::boxed::Box::new(&UNKNOWN_PLACEHOLDER),
221                }
222            }
223        }
224    };
225
226    output.into()
227}
228
229fn extract_single_type(vident: &Ident, fields: &Fields) -> Result<TokenStream2, Error> {
230    let mut fields = fields.iter();
231    let field =
232        fields.next().ok_or_else(|| Error::new(vident.span(), "must have exactly one field"))?;
233    if fields.next().is_some() {
234        return Err(Error::new(vident.span(), "must have exactly one field"));
235    }
236    Ok(field.ty.clone().into_token_stream())
237}