bolt_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Ident, TokenStream as TokenStream2};
3use quote::{format_ident, quote, ToTokens};
4use syn::{
5    parse_macro_input, Data, DataEnum, DeriveInput, Error, Expr, Fields, GenericArgument, Lit,
6    LitInt, Meta, PathArguments, Type, TypePath,
7};
8
9/// Serialize a message with a type prefix, in BOLT style
10#[proc_macro_derive(SerBolt, attributes(message_id))]
11pub fn derive_ser_bolt(input: TokenStream) -> TokenStream {
12    let input1 = input.clone();
13    let DeriveInput { ident, attrs, .. } = parse_macro_input!(input1);
14    let message_id = attrs
15        .into_iter()
16        .find(|a| a.path().is_ident("message_id"))
17        .map(|a| {
18            let lit: LitInt = a.parse_args().expect("expected integer literal for message_id");
19            lit.to_token_stream()
20        })
21        .unwrap_or_else(|| {
22            Error::new(ident.span(), "missing message_id attribute").into_compile_error()
23        });
24
25    let output = quote! {
26        impl SerBolt for #ident {
27            fn as_vec(&self) -> Vec<u8> {
28                let message_type = Self::TYPE;
29                let mut buf = message_type.to_be_bytes().to_vec();
30                let mut val_buf = to_vec(&self).expect("serialize");
31                buf.append(&mut val_buf);
32                buf
33            }
34
35            fn name(&self) -> &'static str {
36                stringify!(#ident)
37            }
38        }
39
40        impl DeBolt for #ident {
41            const TYPE: u16 = #message_id;
42            fn from_vec(mut ser: Vec<u8>) -> Result<Self> {
43                let mut cursor = serde_bolt::io::Cursor::new(&ser);
44                let message_type = cursor.read_u16_be()?;
45                if message_type != Self::TYPE {
46                    return Err(Error::UnexpectedType(message_type));
47                }
48                let res = Decodable::consensus_decode(&mut cursor)?;
49                if cursor.position() as usize != ser.len() {
50                    return Err(Error::TrailingBytes(cursor.position() as usize - ser.len(), Self::TYPE));
51                }
52                Ok(res)
53            }
54        }
55    };
56    output.into()
57}
58
59#[proc_macro_derive(SerBoltTlvOptions, attributes(tlv_tag))]
60pub fn derive_ser_bolt_tlv(input: TokenStream) -> TokenStream {
61    let input = parse_macro_input!(input as DeriveInput);
62    let ident = &input.ident;
63
64    let mut encode_entries: Vec<(u64, proc_macro2::TokenStream)> = Vec::new();
65    let mut decode_entries: Vec<(u64, proc_macro2::TokenStream)> = Vec::new();
66    let mut decode_temp_declarations: Vec<proc_macro2::TokenStream> = Vec::new();
67    let mut decode_fields: Vec<proc_macro2::TokenStream> = Vec::new();
68
69    // traverse the fields, build the needed lists
70    if let Data::Struct(data_struct) = &input.data {
71        if let Fields::Named(fields_named) = &data_struct.fields {
72            for field in fields_named.named.iter() {
73                let field_name = field.ident.as_ref().unwrap();
74                let field_type = &field.ty;
75                let var_name = format_ident!("{}", field_name);
76
77                if let Some(attr) = field.attrs.iter().find(|a| a.path().is_ident("tlv_tag")) {
78                    match &attr.meta {
79                        Meta::NameValue(name_value) => {
80                            if let Expr::Lit(expr_lit) = &name_value.value {
81                                if let Lit::Int(lit_int) = &expr_lit.lit {
82                                    let tlv_tag = lit_int
83                                        .base10_parse::<u64>()
84                                        .expect("tlv_tag should be a valid u64");
85                                    encode_entries.push((
86                                        tlv_tag,
87                                        quote! {
88                                            (#tlv_tag, self.#var_name.as_ref().map(|f| crate::model::SerBoltTlvWriteWrap(f)), option),
89                                        },
90                                    ));
91                                    decode_entries.push((
92                                        tlv_tag,
93                                        quote! {
94                                            (#tlv_tag, #var_name, option),
95                                        },
96                                    ));
97                                    let inner_type =
98                                        unwrap_option(field_type).expect("Option type expected");
99                                    decode_temp_declarations.push(quote! {
100                                        let mut #var_name: Option<crate::model::SerBoltTlvReadWrap<#inner_type>> = None;
101                                    });
102                                    decode_fields.push(quote! {
103                                        #var_name: #var_name.map(|w| w.0),
104                                    });
105                                } else {
106                                    eprintln!("Warning: `tlv_tag` attribute value must be an integer literal.");
107                                }
108                            } else {
109                                eprintln!("Warning: `tlv_tag` attribute value is not a literal expression.");
110                            }
111                        }
112                        _ => eprintln!("Failed to parse `tlv_tag` attribute."),
113                    }
114                } else {
115                    eprintln!("Warning: Missing `tlv_tag` attribute for field `{}`.", field_name);
116                }
117            }
118        }
119    }
120
121    // sort the entries into ascending order
122    encode_entries.sort_by_key(|entry| entry.0);
123    decode_entries.sort_by_key(|entry| entry.0);
124    let sorted_encode_entries: Vec<_> = encode_entries.iter().map(|(_tag, ts)| ts).collect();
125    let sorted_decode_entries: Vec<_> = decode_entries.iter().map(|(_tag, ts)| ts).collect();
126
127    // generate the output
128    let output = quote! {
129        impl Encodable for #ident {
130            fn consensus_encode<W: bitcoin::io::Write + ?Sized>(
131                &self,
132                w: &mut W,
133            ) -> core::result::Result<usize, bitcoin::io::Error> {
134                let mut mw = crate::util::MeasuredWriter::wrap(w);
135                lightning::encode_tlv_stream!(&mut mw, {
136                    #( #sorted_encode_entries )*
137                });
138                Ok(mw.len())
139            }
140        }
141
142        impl Decodable for #ident {
143            fn consensus_decode<R: bitcoin::io::Read + ?Sized>(
144                r: &mut R,
145            ) -> core::result::Result<Self, bitcoin::consensus::encode::Error> {
146                #(#decode_temp_declarations)*
147                (|| -> core::result::Result<_, _> {
148                    // a sized reader is required, so wrap it in a Take
149                    let mut r = r.take(u64::MAX);
150                    lightning::decode_tlv_stream!(&mut r, {
151                        #( #sorted_decode_entries )*
152                    });
153                    Ok(())
154                })()
155                    .map_err(|_e| bitcoin::consensus::encode::Error::ParseFailed(
156                        "decode_tlv_stream failed"))?;
157                Ok(Self { #(#decode_fields)* })
158            }
159        }
160    };
161
162    output.into()
163}
164
165fn unwrap_option(field_type: &Type) -> Option<&Type> {
166    if let Type::Path(TypePath { path, .. }) = field_type {
167        if path.segments.len() == 1 && path.segments[0].ident == "Option" {
168            if let PathArguments::AngleBracketed(args) = &path.segments[0].arguments {
169                if let Some(GenericArgument::Type(ty)) = args.args.first() {
170                    return Some(ty);
171                }
172            }
173        }
174    }
175    None
176}
177
178#[proc_macro_derive(ReadMessage)]
179pub fn derive_read_message(input: TokenStream) -> TokenStream {
180    let DeriveInput { ident, data, .. } = parse_macro_input!(input);
181    let mut vs = Vec::new();
182    let mut ts = Vec::new();
183    let mut error: Option<Error> = None;
184
185    if let Data::Enum(DataEnum { variants, .. }) = data {
186        for v in variants {
187            if v.ident == "Unknown" {
188                continue;
189            }
190            let vident = v.ident.clone();
191            let field = extract_single_type(&vident, &v.fields);
192            match field {
193                Ok(f) => {
194                    vs.push(vident);
195                    ts.push(f);
196                }
197                Err(e) => match error.as_mut() {
198                    None => error = Some(e),
199                    Some(o) => o.combine(e),
200                },
201            }
202        }
203    } else {
204        unimplemented!()
205    }
206
207    if let Some(error) = error {
208        return error.into_compile_error().into();
209    }
210
211    let output = quote! {
212        impl #ident {
213            fn read_message<R: Read + ?Sized>(mut reader: &mut R, message_type: u16) -> Result<Message> {
214                let message = match message_type {
215                    #(#vs::TYPE => Message::#ts(Decodable::consensus_decode(reader)?)),*,
216                    _ => Message::Unknown(Unknown { message_type }),
217                };
218                Ok(message)
219            }
220
221            fn message_name(message_type: u16) -> &'static str {
222                match message_type {
223                    #(#vs::TYPE => stringify!(#vs)),*,
224                    _ => "Unknown",
225                }
226            }
227
228            pub fn inner(&self) -> alloc::boxed::Box<&dyn SerBolt> {
229                match self {
230                    #(#ident::#vs(inner) => alloc::boxed::Box::new(inner)),*,
231                    _ => alloc::boxed::Box::new(&UNKNOWN_PLACEHOLDER),
232                }
233            }
234        }
235    };
236
237    output.into()
238}
239
240fn extract_single_type(vident: &Ident, fields: &Fields) -> Result<TokenStream2, Error> {
241    let mut fields = fields.iter();
242    let field =
243        fields.next().ok_or_else(|| Error::new(vident.span(), "must have exactly one field"))?;
244    if fields.next().is_some() {
245        return Err(Error::new(vident.span(), "must have exactly one field"));
246    }
247    Ok(field.ty.clone().into_token_stream())
248}