1use proc_macro::TokenStream;
2use proc_macro2::{Ident, TokenStream as TokenStream2};
3use quote::{format_ident, quote, ToTokens};
4use syn::{parse_macro_input, Data, DataEnum, DeriveInput, Error, Fields, Lit, Type};
5
6#[proc_macro_derive(SerBolt, attributes(message_id))]
8pub fn derive_ser_bolt(input: TokenStream) -> TokenStream {
9 let input1 = input.clone();
10 let DeriveInput { ident, attrs, .. } = parse_macro_input!(input1);
11 let message_id = attrs
12 .into_iter()
13 .filter(|a| a.path.is_ident("message_id"))
14 .next()
15 .map(|a| a.tokens)
16 .unwrap_or_else(|| {
17 Error::new(ident.span(), "missing message_id attribute").into_compile_error()
18 });
19
20 let output = quote! {
21 impl SerBolt for #ident {
22 fn as_vec(&self) -> Vec<u8> {
23 let message_type = Self::TYPE;
24 let mut buf = message_type.to_be_bytes().to_vec();
25 let mut val_buf = to_vec(&self).expect("serialize");
26 buf.append(&mut val_buf);
27 buf
28 }
29
30 fn name(&self) -> &'static str {
31 stringify!(#ident)
32 }
33 }
34
35 impl DeBolt for #ident {
36 const TYPE: u16 = #message_id;
37 fn from_vec(mut ser: Vec<u8>) -> Result<Self> {
38 let mut cursor = serde_bolt::io::Cursor::new(&ser);
39 let message_type = cursor.read_u16_be()?;
40 if message_type != Self::TYPE {
41 return Err(Error::UnexpectedType(message_type));
42 }
43 let res = Decodable::consensus_decode(&mut cursor)?;
44 if cursor.position() as usize != ser.len() {
45 return Err(Error::TrailingBytes(cursor.position() as usize - ser.len(), Self::TYPE));
46 }
47 Ok(res)
48 }
49 }
50 };
51 output.into()
52}
53
54#[proc_macro_derive(SerBoltTlvOptions, attributes(tlv_tag))]
55pub fn derive_ser_bolt_tlv(input: TokenStream) -> TokenStream {
56 let input = parse_macro_input!(input as DeriveInput);
57 let ident = &input.ident;
58
59 let mut encode_entries: Vec<(u64, proc_macro2::TokenStream)> = Vec::new();
60 let mut decode_entries: Vec<(u64, proc_macro2::TokenStream)> = Vec::new();
61 let mut decode_temp_declarations: Vec<proc_macro2::TokenStream> = Vec::new();
62 let mut decode_fields: Vec<proc_macro2::TokenStream> = Vec::new();
63
64 if let Data::Struct(data_struct) = &input.data {
66 if let Fields::Named(fields_named) = &data_struct.fields {
67 for field in fields_named.named.iter() {
68 let field_name = field.ident.as_ref().unwrap();
69 let field_type = &field.ty;
70 let var_name = format_ident!("{}", field_name);
71
72 if let Some(attr) = field.attrs.iter().find(|a| a.path.is_ident("tlv_tag")) {
73 match attr.parse_meta() {
74 Ok(syn::Meta::NameValue(meta_name_value)) => {
75 if let Lit::Int(lit_int) = meta_name_value.lit {
76 let tlv_tag = lit_int
77 .base10_parse::<u64>()
78 .expect("tlv_tag should be a valid u64");
79 encode_entries.push((
80 tlv_tag,
81 quote! {
82 (#tlv_tag, self.#var_name.as_ref().map(|f| crate::model::SerBoltTlvWriteWrap(f)), option),
83 },
84 ));
85 decode_entries.push((
86 tlv_tag,
87 quote! {
88 (#tlv_tag, #var_name, option),
89 },
90 ));
91 let inner_type =
92 unwrap_option(field_type).expect("Option type expected");
93 decode_temp_declarations.push(quote! {
94 let mut #var_name: Option<crate::model::SerBoltTlvReadWrap<#inner_type>> = None;
95 });
96 decode_fields.push(quote! {
97 #var_name: #var_name.map(|w| w.0),
98 });
99 } else {
100 eprintln!("Warning: `tlv_tag` attribute value must be an integer.");
101 }
102 }
103 _ => eprintln!("Failed to parse `tlv_tag` attribute."),
104 }
105 } else {
106 eprintln!("Warning: Missing `tlv_tag` attribute for field `{}`.", field_name);
107 }
108 }
109 }
110 }
111
112 encode_entries.sort_by_key(|entry| entry.0);
114 decode_entries.sort_by_key(|entry| entry.0);
115 let sorted_encode_entries: Vec<_> = encode_entries.iter().map(|(_tag, ts)| ts).collect();
116 let sorted_decode_entries: Vec<_> = decode_entries.iter().map(|(_tag, ts)| ts).collect();
117
118 let output = quote! {
120 impl Encodable for #ident {
121 fn consensus_encode<W: io::Write + ?Sized>(
122 &self,
123 w: &mut W,
124 ) -> core::result::Result<usize, io::Error> {
125 let mut mw = crate::util::MeasuredWriter::wrap(w);
126 lightning::encode_tlv_stream!(&mut mw, {
127 #( #sorted_encode_entries )*
128 });
129 Ok(mw.len())
130 }
131 }
132
133 impl Decodable for #ident {
134 fn consensus_decode<R: io::Read + ?Sized>(
135 r: &mut R,
136 ) -> core::result::Result<Self, bitcoin::consensus::encode::Error> {
137 #(#decode_temp_declarations)*
138 (|| -> core::result::Result<_, _> {
139 lightning::decode_tlv_stream!(r, {
140 #( #sorted_decode_entries )*
141 });
142 Ok(())
143 })()
144 .map_err(|_e| bitcoin::consensus::encode::Error::ParseFailed(
145 "decode_tlv_stream failed"))?;
146 Ok(Self { #(#decode_fields)* })
147 }
148 }
149 };
150
151 output.into()
152}
153
154fn unwrap_option(field_type: &Type) -> Option<&Type> {
155 if let syn::Type::Path(syn::TypePath { path, .. }) = &field_type {
156 if path.segments.len() == 1 && path.segments[0].ident == "Option" {
157 if let syn::PathArguments::AngleBracketed(args) = &path.segments[0].arguments {
158 if let Some(syn::GenericArgument::Type(ty)) = args.args.first() {
159 return Some(ty);
160 }
161 }
162 }
163 }
164 None
165}
166
167#[proc_macro_derive(ReadMessage)]
168pub fn derive_read_message(input: TokenStream) -> TokenStream {
169 let DeriveInput { ident, data, .. } = parse_macro_input!(input);
170 let mut vs = Vec::new();
171 let mut ts = Vec::new();
172 let mut error: Option<Error> = None;
173
174 if let Data::Enum(DataEnum { variants, .. }) = data {
175 for v in variants {
176 if v.ident == "Unknown" {
177 continue;
178 }
179 let vident = v.ident.clone();
180 let field = extract_single_type(&vident, &v.fields);
181 match field {
182 Ok(f) => {
183 vs.push(vident);
184 ts.push(f);
185 }
186 Err(e) => match error.as_mut() {
187 None => error = Some(e),
188 Some(o) => o.combine(e),
189 },
190 }
191 }
192 } else {
193 unimplemented!()
194 }
195
196 if let Some(error) = error {
197 return error.into_compile_error().into();
198 }
199
200 let output = quote! {
201 impl #ident {
202 fn read_message<R: Read + ?Sized>(mut reader: &mut R, message_type: u16) -> Result<Message> {
203 let message = match message_type {
204 #(#vs::TYPE => Message::#ts(Decodable::consensus_decode(reader)?)),*,
205 _ => Message::Unknown(Unknown { message_type }),
206 };
207 Ok(message)
208 }
209
210 fn message_name(message_type: u16) -> &'static str {
211 match message_type {
212 #(#vs::TYPE => stringify!(#vs)),*,
213 _ => "Unknown",
214 }
215 }
216
217 pub fn inner(&self) -> alloc::boxed::Box<&dyn SerBolt> {
218 match self {
219 #(#ident::#vs(inner) => alloc::boxed::Box::new(inner)),*,
220 _ => alloc::boxed::Box::new(&UNKNOWN_PLACEHOLDER),
221 }
222 }
223 }
224 };
225
226 output.into()
227}
228
229fn extract_single_type(vident: &Ident, fields: &Fields) -> Result<TokenStream2, Error> {
230 let mut fields = fields.iter();
231 let field =
232 fields.next().ok_or_else(|| Error::new(vident.span(), "must have exactly one field"))?;
233 if fields.next().is_some() {
234 return Err(Error::new(vident.span(), "must have exactly one field"));
235 }
236 Ok(field.ty.clone().into_token_stream())
237}