// lightning_macros/lib.rs
// This file is Copyright its original authors, visible in version control
// history.
//
// This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
// or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
// You may not use this file except in accordance with one or both of these
// licenses.

#![crate_name = "lightning_macros"]

//! Proc macros used by LDK

#![cfg_attr(not(test), no_std)]
#![deny(missing_docs)]
#![forbid(unsafe_code)]
#![deny(rustdoc::broken_intra_doc_links)]
#![deny(rustdoc::private_intra_doc_links)]
#![cfg_attr(docsrs, feature(doc_cfg))]

extern crate alloc;

use alloc::string::ToString;
use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::spanned::Spanned;
use syn::{parse, ImplItemFn, Token};
use syn::{parse_macro_input, Item};
31fn add_async_method(mut parsed: ImplItemFn) -> TokenStream {
32	let output = quote! {
33		#[cfg(not(feature = "async-interface"))]
34		#parsed
35	};
36
37	parsed.sig.asyncness = Some(Token![async](parsed.span()));
38
39	let output = quote! {
40		#output
41
42		#[cfg(feature = "async-interface")]
43		#parsed
44	};
45
46	output.into()
47}
48
49/// Makes a method `async`, if the `async-interface` feature is enabled.
50#[proc_macro_attribute]
51pub fn maybe_async(_attr: TokenStream, item: TokenStream) -> TokenStream {
52	if let Ok(parsed) = parse(item) {
53		add_async_method(parsed)
54	} else {
55		(quote! {
56			compile_error!("#[maybe_async] can only be used on methods")
57		})
58		.into()
59	}
60}
61
62/// Awaits, if the `async-interface` feature is enabled.
63#[proc_macro]
64pub fn maybe_await(expr: TokenStream) -> TokenStream {
65	let expr: proc_macro2::TokenStream = expr.into();
66	let quoted = quote! {
67		{
68			#[cfg(not(feature = "async-interface"))]
69			{
70				#expr
71			}
72
73			#[cfg(feature = "async-interface")]
74			{
75				#expr.await
76			}
77		}
78	};
79
80	quoted.into()
81}
82
83fn expect_ident(token: &TokenTree, expected_name: Option<&str>) {
84	if let TokenTree::Ident(id) = &token {
85		if let Some(exp) = expected_name {
86			assert_eq!(id.to_string(), exp, "Expected ident {}, got {:?}", exp, token);
87		}
88	} else {
89		panic!("Expected ident {:?}, got {:?}", expected_name, token);
90	}
91}
92
93fn expect_punct(token: &TokenTree, expected: char) {
94	if let TokenTree::Punct(p) = &token {
95		assert_eq!(p.as_char(), expected, "Expected punctuation {}, got {}", expected, p);
96	} else {
97		panic!("Expected punctuation {}, got {:?}", expected, token);
98	}
99}
100
101fn token_to_stream(token: TokenTree) -> proc_macro::TokenStream {
102	proc_macro::TokenStream::from(token)
103}
104
/// Processes a list of fields in a variant definition (see the docs for [`skip_legacy_fields!`])
///
/// For a brace-delimited field list, rebuilds it with every `legacy`-typed field removed and
/// the `: ty_info` annotation stripped from the fields that are kept. Any other group (e.g. a
/// parenthesized tuple-variant pattern) is passed through unchanged.
fn process_fields(group: Group) -> proc_macro::TokenStream {
	let mut computed_fields = proc_macro::TokenStream::new();
	if group.delimiter() == Delimiter::Brace {
		let mut fields_stream = group.stream().into_iter().peekable();

		let mut new_fields = proc_macro::TokenStream::new();
		loop {
			// The field list should end with .., at which point we break
			let next_tok = fields_stream.peek();
			if let Some(TokenTree::Punct(_)) = next_tok {
				let dot1 = fields_stream.next().unwrap();
				expect_punct(&dot1, '.');
				let dot2 = fields_stream.next().expect("Missing second trailing .");
				expect_punct(&dot2, '.');
				// The trailing .. is kept in the rebuilt pattern and must be the last tokens
				let trailing_dots = [dot1, dot2];
				new_fields.extend(trailing_dots.into_iter().map(token_to_stream));
				assert!(fields_stream.peek().is_none());
				break;
			}

			// Fields should take the form `ref field_name: ty_info` where `ty_info`
			// may be a single ident or may be a group. We skip the field if `ty_info`
			// is a group where the first token is the ident `legacy`.
			let ref_ident = fields_stream.next().unwrap();
			expect_ident(&ref_ident, Some("ref"));
			let field_name_ident = fields_stream.next().unwrap();
			let co = fields_stream.next().unwrap();
			expect_punct(&co, ':');
			let ty_info = fields_stream.next().unwrap();
			let com = fields_stream.next().unwrap();
			expect_punct(&com, ',');

			if let TokenTree::Group(group) = ty_info {
				let first_group_tok = group.stream().into_iter().next().unwrap();
				if let TokenTree::Ident(ident) = first_group_tok {
					if ident.to_string() == "legacy" {
						// Drop this field entirely from the rebuilt field list
						continue;
					}
				}
			}

			// Keep the field, but without the `: ty_info` annotation (ty_info is dropped here)
			let field = [ref_ident, field_name_ident, com];
			new_fields.extend(field.into_iter().map(token_to_stream));
		}
		let fields_group = Group::new(Delimiter::Brace, new_fields);
		computed_fields.extend(token_to_stream(TokenTree::Group(fields_group)));
	} else {
		// Non-brace groups are forwarded untouched
		computed_fields.extend(token_to_stream(TokenTree::Group(group)));
	}
	computed_fields
}

/// Scans a match statement for legacy fields which should be skipped.
///
/// This is used internally in LDK's TLV serialization logic and is not expected to be used by
/// other crates.
///
/// Wraps a `match self {..}` statement and scans the fields in the match patterns (in the form
/// `ref $field_name: $field_ty`) for types marked `legacy`, skipping those fields.
///
/// Specifically, it expects input like the following, simply dropping `field3` and the
/// `: $field_ty` after each field name.
/// ```ignore
/// match self {
///		Enum::Variant {
///			ref field1: option,
///			ref field2: (option, explicit_type: u64),
///			ref field3: (legacy, u64, {}, {}), // will be skipped
///			..
///		} => expression
///	}
/// ```
#[proc_macro]
pub fn skip_legacy_fields(expr: TokenStream) -> TokenStream {
	let mut stream = expr.into_iter();
	let mut res = TokenStream::new();

	// First expect `match self` followed by a `{}` group...
	let match_ident = stream.next().unwrap();
	expect_ident(&match_ident, Some("match"));
	res.extend(proc_macro::TokenStream::from(match_ident));

	let self_ident = stream.next().unwrap();
	expect_ident(&self_ident, Some("self"));
	res.extend(proc_macro::TokenStream::from(self_ident));

	let arms = stream.next().unwrap();
	if let TokenTree::Group(group) = arms {
		let mut new_arms = TokenStream::new();

		let mut arm_stream = group.stream().into_iter().peekable();
		while arm_stream.peek().is_some() {
			// Each arm should contain Enum::Variant { fields } => init
			// We explicitly check the :s, =, and >, as well as an optional trailing ,
			let enum_ident = arm_stream.next().unwrap();
			let co1 = arm_stream.next().unwrap();
			expect_punct(&co1, ':');
			let co2 = arm_stream.next().unwrap();
			expect_punct(&co2, ':');
			let variant_ident = arm_stream.next().unwrap();
			let fields = arm_stream.next().unwrap();
			let eq = arm_stream.next().unwrap();
			expect_punct(&eq, '=');
			let gt = arm_stream.next().unwrap();
			expect_punct(&gt, '>');
			let init = arm_stream.next().unwrap();

			// Consume an optional trailing comma after the arm (it is dropped, not re-emitted)
			let next_tok = arm_stream.peek();
			if let Some(TokenTree::Punct(_)) = next_tok {
				expect_punct(next_tok.unwrap(), ',');
				arm_stream.next();
			}

			// Rewrite the pattern's field list, dropping `legacy` fields (see process_fields)
			let computed_fields = if let TokenTree::Group(group) = fields {
				process_fields(group)
			} else {
				panic!("Expected a group for the fields in a match arm");
			};

			// Re-emit the arm as `Enum::Variant { rewritten fields } => init`
			let arm_pfx = [enum_ident, co1, co2, variant_ident];
			new_arms.extend(arm_pfx.into_iter().map(token_to_stream));
			new_arms.extend(computed_fields);
			let arm_sfx = [eq, gt, init];
			new_arms.extend(arm_sfx.into_iter().map(token_to_stream));
		}

		let new_arm_group = Group::new(Delimiter::Brace, new_arms);
		res.extend(token_to_stream(TokenTree::Group(new_arm_group)));
	} else {
		panic!("Expected `match self {{..}}` and nothing else");
	}

	assert!(stream.next().is_none(), "Expected `match self {{..}}` and nothing else");

	res
}

243/// Scans an enum definition for fields initialized with `legacy` types and drops them.
244///
245/// This is used internally in LDK's TLV serialization logic and is not expected to be used by
246/// other crates.
247///
248/// Is expected to wrap a struct definition like
249/// ```ignore
250/// drop_legacy_field_definition!(Self {
251/// 	field1: $crate::_init_tlv_based_struct_field!(field1, option),
252/// 	field2: $crate::_init_tlv_based_struct_field!(field2, (legacy, u64, {})),
253/// })
254/// ```
255/// and will drop fields defined like `field2` with a type starting with `legacy`.
256#[proc_macro]
257pub fn drop_legacy_field_definition(expr: TokenStream) -> TokenStream {
258	let mut st = if let Ok(parsed) = parse::<syn::Expr>(expr) {
259		if let syn::Expr::Struct(st) = parsed {
260			st
261		} else {
262			return (quote! {
263				compile_error!("drop_legacy_field_definition!() can only be used on struct expressions")
264			})
265			.into();
266		}
267	} else {
268		return (quote! {
269			compile_error!("drop_legacy_field_definition!() can only be used on expressions")
270		})
271		.into();
272	};
273	assert!(st.attrs.is_empty());
274	assert!(st.qself.is_none());
275	assert!(st.dot2_token.is_none());
276	assert!(st.rest.is_none());
277	let mut new_fields = syn::punctuated::Punctuated::new();
278	core::mem::swap(&mut new_fields, &mut st.fields);
279	for field in new_fields {
280		if let syn::Expr::Macro(syn::ExprMacro { mac, .. }) = &field.expr {
281			let macro_name = mac.path.segments.last().unwrap().ident.to_string();
282			let is_init = macro_name == "_init_tlv_based_struct_field";
283			// Skip `field_name` and `:`, giving us just the type's group
284			let ty_tokens = mac.tokens.clone().into_iter().skip(2).next();
285			if let Some(proc_macro2::TokenTree::Group(group)) = ty_tokens {
286				let first_token = group.stream().into_iter().next();
287				if let Some(proc_macro2::TokenTree::Ident(ident)) = first_token {
288					if is_init && ident == "legacy" {
289						continue;
290					}
291				}
292			}
293		}
294		st.fields.push(field);
295	}
296	let out = syn::Expr::Struct(st);
297	quote! { #out }.into()
298}
299
300/// An exposed test.  This is a test that will run locally and also be
301/// made available to other crates that want to run it in their own context.
302///
303/// For example:
304/// ```rust
305/// use lightning_macros::xtest;
306///
307/// fn f1() {}
308///
309/// #[xtest(feature = "_externalize_tests")]
310/// pub fn test_f1() {
311///     f1();
312/// }
313/// ```
314///
315/// Which will include the module if we are testing or the `_test_utils` feature
316/// is on.
317#[proc_macro_attribute]
318pub fn xtest(attrs: TokenStream, item: TokenStream) -> TokenStream {
319	let attrs = parse_macro_input!(attrs as TokenStream2);
320	let input = parse_macro_input!(item as Item);
321
322	let expanded = match input {
323		Item::Fn(item_fn) => {
324			let (cfg_attr, submit_attr) = if attrs.is_empty() {
325				(quote! { #[cfg_attr(test, test)] }, quote! { #[cfg(not(test))] })
326			} else {
327				(
328					quote! { #[cfg_attr(test, test)] #[cfg(any(test, #attrs))] },
329					quote! { #[cfg(all(not(test), #attrs))] },
330				)
331			};
332
333			// Check that the function doesn't take args and returns nothing
334			if !item_fn.sig.inputs.is_empty()
335				|| !matches!(item_fn.sig.output, syn::ReturnType::Default)
336			{
337				return syn::Error::new_spanned(
338					item_fn.sig,
339					"xtest functions must not take arguments and must return nothing",
340				)
341				.to_compile_error()
342				.into();
343			}
344
345			// Check for #[should_panic] attribute
346			let should_panic =
347				item_fn.attrs.iter().any(|attr| attr.path().is_ident("should_panic"));
348
349			let fn_name = &item_fn.sig.ident;
350			let fn_name_str = fn_name.to_string();
351			quote! {
352				#cfg_attr
353				#item_fn
354
355				// We submit the test to the inventory only if we're not actually testing
356				#submit_attr
357				inventory::submit! {
358					crate::XTestItem {
359						test_fn: #fn_name,
360						test_name: #fn_name_str,
361						should_panic: #should_panic,
362					}
363				}
364			}
365		},
366		_ => {
367			return syn::Error::new_spanned(
368				input,
369				"xtest can only be applied to functions or modules",
370			)
371			.to_compile_error()
372			.into();
373		},
374	};
375
376	TokenStream::from(expanded)
377}
378
379/// Collects all externalized tests marked with `#[xtest]`
380/// into a vector of `XTestItem`s.  This vector can be
381/// retrieved by calling `get_xtests()`.
382#[proc_macro]
383pub fn xtest_inventory(_input: TokenStream) -> TokenStream {
384	let expanded = quote! {
385		/// An externalized test item, including the test function, name, and whether it is marked with `#[should_panic]`.
386		pub struct XTestItem {
387			pub test_fn: fn(),
388			pub test_name: &'static str,
389			pub should_panic: bool,
390		}
391
392		inventory::collect!(XTestItem);
393
394		pub fn get_xtests() -> Vec<&'static XTestItem> {
395			inventory::iter::<XTestItem>
396				.into_iter()
397				.collect()
398		}
399	};
400
401	TokenStream::from(expanded)
402}