1use proc_macro::TokenStream;
2use proc_macro2::{Ident, TokenStream as TokenStream2};
3use quote::{format_ident, quote, ToTokens};
4use syn::{
5 parse_macro_input, Data, DataEnum, DeriveInput, Error, Expr, Fields, GenericArgument, Lit,
6 LitInt, Meta, PathArguments, Type, TypePath,
7};
8
9#[proc_macro_derive(SerBolt, attributes(message_id))]
11pub fn derive_ser_bolt(input: TokenStream) -> TokenStream {
12 let input1 = input.clone();
13 let DeriveInput { ident, attrs, .. } = parse_macro_input!(input1);
14 let message_id = attrs
15 .into_iter()
16 .find(|a| a.path().is_ident("message_id"))
17 .map(|a| {
18 let lit: LitInt = a.parse_args().expect("expected integer literal for message_id");
19 lit.to_token_stream()
20 })
21 .unwrap_or_else(|| {
22 Error::new(ident.span(), "missing message_id attribute").into_compile_error()
23 });
24
25 let output = quote! {
26 impl SerBolt for #ident {
27 fn as_vec(&self) -> Vec<u8> {
28 let message_type = Self::TYPE;
29 let mut buf = message_type.to_be_bytes().to_vec();
30 let mut val_buf = to_vec(&self).expect("serialize");
31 buf.append(&mut val_buf);
32 buf
33 }
34
35 fn name(&self) -> &'static str {
36 stringify!(#ident)
37 }
38 }
39
40 impl DeBolt for #ident {
41 const TYPE: u16 = #message_id;
42 fn from_vec(mut ser: Vec<u8>) -> Result<Self> {
43 let mut cursor = serde_bolt::io::Cursor::new(&ser);
44 let message_type = cursor.read_u16_be()?;
45 if message_type != Self::TYPE {
46 return Err(Error::UnexpectedType(message_type));
47 }
48 let res = Decodable::consensus_decode(&mut cursor)?;
49 if cursor.position() as usize != ser.len() {
50 return Err(Error::TrailingBytes(cursor.position() as usize - ser.len(), Self::TYPE));
51 }
52 Ok(res)
53 }
54 }
55 };
56 output.into()
57}
58
59#[proc_macro_derive(SerBoltTlvOptions, attributes(tlv_tag))]
60pub fn derive_ser_bolt_tlv(input: TokenStream) -> TokenStream {
61 let input = parse_macro_input!(input as DeriveInput);
62 let ident = &input.ident;
63
64 let mut encode_entries: Vec<(u64, proc_macro2::TokenStream)> = Vec::new();
65 let mut decode_entries: Vec<(u64, proc_macro2::TokenStream)> = Vec::new();
66 let mut decode_temp_declarations: Vec<proc_macro2::TokenStream> = Vec::new();
67 let mut decode_fields: Vec<proc_macro2::TokenStream> = Vec::new();
68
69 if let Data::Struct(data_struct) = &input.data {
71 if let Fields::Named(fields_named) = &data_struct.fields {
72 for field in fields_named.named.iter() {
73 let field_name = field.ident.as_ref().unwrap();
74 let field_type = &field.ty;
75 let var_name = format_ident!("{}", field_name);
76
77 if let Some(attr) = field.attrs.iter().find(|a| a.path().is_ident("tlv_tag")) {
78 match &attr.meta {
79 Meta::NameValue(name_value) => {
80 if let Expr::Lit(expr_lit) = &name_value.value {
81 if let Lit::Int(lit_int) = &expr_lit.lit {
82 let tlv_tag = lit_int
83 .base10_parse::<u64>()
84 .expect("tlv_tag should be a valid u64");
85 encode_entries.push((
86 tlv_tag,
87 quote! {
88 (#tlv_tag, self.#var_name.as_ref().map(|f| crate::model::SerBoltTlvWriteWrap(f)), option),
89 },
90 ));
91 decode_entries.push((
92 tlv_tag,
93 quote! {
94 (#tlv_tag, #var_name, option),
95 },
96 ));
97 let inner_type =
98 unwrap_option(field_type).expect("Option type expected");
99 decode_temp_declarations.push(quote! {
100 let mut #var_name: Option<crate::model::SerBoltTlvReadWrap<#inner_type>> = None;
101 });
102 decode_fields.push(quote! {
103 #var_name: #var_name.map(|w| w.0),
104 });
105 } else {
106 eprintln!("Warning: `tlv_tag` attribute value must be an integer literal.");
107 }
108 } else {
109 eprintln!("Warning: `tlv_tag` attribute value is not a literal expression.");
110 }
111 }
112 _ => eprintln!("Failed to parse `tlv_tag` attribute."),
113 }
114 } else {
115 eprintln!("Warning: Missing `tlv_tag` attribute for field `{}`.", field_name);
116 }
117 }
118 }
119 }
120
121 encode_entries.sort_by_key(|entry| entry.0);
123 decode_entries.sort_by_key(|entry| entry.0);
124 let sorted_encode_entries: Vec<_> = encode_entries.iter().map(|(_tag, ts)| ts).collect();
125 let sorted_decode_entries: Vec<_> = decode_entries.iter().map(|(_tag, ts)| ts).collect();
126
127 let output = quote! {
129 impl Encodable for #ident {
130 fn consensus_encode<W: bitcoin::io::Write + ?Sized>(
131 &self,
132 w: &mut W,
133 ) -> core::result::Result<usize, bitcoin::io::Error> {
134 let mut mw = crate::util::MeasuredWriter::wrap(w);
135 lightning::encode_tlv_stream!(&mut mw, {
136 #( #sorted_encode_entries )*
137 });
138 Ok(mw.len())
139 }
140 }
141
142 impl Decodable for #ident {
143 fn consensus_decode<R: bitcoin::io::Read + ?Sized>(
144 r: &mut R,
145 ) -> core::result::Result<Self, bitcoin::consensus::encode::Error> {
146 #(#decode_temp_declarations)*
147 (|| -> core::result::Result<_, _> {
148 let mut r = r.take(u64::MAX);
150 lightning::decode_tlv_stream!(&mut r, {
151 #( #sorted_decode_entries )*
152 });
153 Ok(())
154 })()
155 .map_err(|_e| bitcoin::consensus::encode::Error::ParseFailed(
156 "decode_tlv_stream failed"))?;
157 Ok(Self { #(#decode_fields)* })
158 }
159 }
160 };
161
162 output.into()
163}
164
165fn unwrap_option(field_type: &Type) -> Option<&Type> {
166 if let Type::Path(TypePath { path, .. }) = field_type {
167 if path.segments.len() == 1 && path.segments[0].ident == "Option" {
168 if let PathArguments::AngleBracketed(args) = &path.segments[0].arguments {
169 if let Some(GenericArgument::Type(ty)) = args.args.first() {
170 return Some(ty);
171 }
172 }
173 }
174 }
175 None
176}
177
178#[proc_macro_derive(ReadMessage)]
179pub fn derive_read_message(input: TokenStream) -> TokenStream {
180 let DeriveInput { ident, data, .. } = parse_macro_input!(input);
181 let mut vs = Vec::new();
182 let mut ts = Vec::new();
183 let mut error: Option<Error> = None;
184
185 if let Data::Enum(DataEnum { variants, .. }) = data {
186 for v in variants {
187 if v.ident == "Unknown" {
188 continue;
189 }
190 let vident = v.ident.clone();
191 let field = extract_single_type(&vident, &v.fields);
192 match field {
193 Ok(f) => {
194 vs.push(vident);
195 ts.push(f);
196 }
197 Err(e) => match error.as_mut() {
198 None => error = Some(e),
199 Some(o) => o.combine(e),
200 },
201 }
202 }
203 } else {
204 unimplemented!()
205 }
206
207 if let Some(error) = error {
208 return error.into_compile_error().into();
209 }
210
211 let output = quote! {
212 impl #ident {
213 fn read_message<R: Read + ?Sized>(mut reader: &mut R, message_type: u16) -> Result<Message> {
214 let message = match message_type {
215 #(#vs::TYPE => Message::#ts(Decodable::consensus_decode(reader)?)),*,
216 _ => Message::Unknown(Unknown { message_type }),
217 };
218 Ok(message)
219 }
220
221 fn message_name(message_type: u16) -> &'static str {
222 match message_type {
223 #(#vs::TYPE => stringify!(#vs)),*,
224 _ => "Unknown",
225 }
226 }
227
228 pub fn inner(&self) -> alloc::boxed::Box<&dyn SerBolt> {
229 match self {
230 #(#ident::#vs(inner) => alloc::boxed::Box::new(inner)),*,
231 _ => alloc::boxed::Box::new(&UNKNOWN_PLACEHOLDER),
232 }
233 }
234 }
235 };
236
237 output.into()
238}
239
240fn extract_single_type(vident: &Ident, fields: &Fields) -> Result<TokenStream2, Error> {
241 let mut fields = fields.iter();
242 let field =
243 fields.next().ok_or_else(|| Error::new(vident.span(), "must have exactly one field"))?;
244 if fields.next().is_some() {
245 return Err(Error::new(vident.span(), "must have exactly one field"));
246 }
247 Ok(field.ty.clone().into_token_stream())
248}