1#![crate_name = "lightning_macros"]
11
12#![cfg_attr(not(test), no_std)]
15#![deny(missing_docs)]
16#![forbid(unsafe_code)]
17#![deny(rustdoc::broken_intra_doc_links)]
18#![deny(rustdoc::private_intra_doc_links)]
19#![cfg_attr(docsrs, feature(doc_cfg))]
20
21extern crate alloc;
22
23use alloc::string::ToString;
24use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
25use proc_macro2::TokenStream as TokenStream2;
26use quote::quote;
27use syn::spanned::Spanned;
28use syn::{parse, ImplItemFn, Token};
29use syn::{parse_macro_input, Item};
30
31fn add_async_method(mut parsed: ImplItemFn) -> TokenStream {
32 let output = quote! {
33 #[cfg(not(feature = "async-interface"))]
34 #parsed
35 };
36
37 parsed.sig.asyncness = Some(Token));
38
39 let output = quote! {
40 #output
41
42 #[cfg(feature = "async-interface")]
43 #parsed
44 };
45
46 output.into()
47}
48
49#[proc_macro_attribute]
51pub fn maybe_async(_attr: TokenStream, item: TokenStream) -> TokenStream {
52 if let Ok(parsed) = parse(item) {
53 add_async_method(parsed)
54 } else {
55 (quote! {
56 compile_error!("#[maybe_async] can only be used on methods")
57 })
58 .into()
59 }
60}
61
62#[proc_macro]
64pub fn maybe_await(expr: TokenStream) -> TokenStream {
65 let expr: proc_macro2::TokenStream = expr.into();
66 let quoted = quote! {
67 {
68 #[cfg(not(feature = "async-interface"))]
69 {
70 #expr
71 }
72
73 #[cfg(feature = "async-interface")]
74 {
75 #expr.await
76 }
77 }
78 };
79
80 quoted.into()
81}
82
83fn expect_ident(token: &TokenTree, expected_name: Option<&str>) {
84 if let TokenTree::Ident(id) = &token {
85 if let Some(exp) = expected_name {
86 assert_eq!(id.to_string(), exp, "Expected ident {}, got {:?}", exp, token);
87 }
88 } else {
89 panic!("Expected ident {:?}, got {:?}", expected_name, token);
90 }
91}
92
93fn expect_punct(token: &TokenTree, expected: char) {
94 if let TokenTree::Punct(p) = &token {
95 assert_eq!(p.as_char(), expected, "Expected punctuation {}, got {}", expected, p);
96 } else {
97 panic!("Expected punctuation {}, got {:?}", expected, token);
98 }
99}
100
101fn token_to_stream(token: TokenTree) -> proc_macro::TokenStream {
102 proc_macro::TokenStream::from(token)
103}
104
/// Rewrites a match-arm pattern's field list, dropping any `ref name: (legacy, ..)`
/// entry so fields marked `legacy` are not bound in the pattern.
///
/// A braced group is parsed as a sequence of `ref <name>: <ty_info>,` entries
/// terminated by a `..` rest pattern; any other group is passed through
/// untouched.
fn process_fields(group: Group) -> proc_macro::TokenStream {
	let mut computed_fields = proc_macro::TokenStream::new();
	if group.delimiter() == Delimiter::Brace {
		let mut fields_stream = group.stream().into_iter().peekable();

		let mut new_fields = proc_macro::TokenStream::new();
		loop {
			// A punct at the start of an entry can only be the `..` rest
			// pattern, which must be the last thing in the field list.
			let next_tok = fields_stream.peek();
			if let Some(TokenTree::Punct(_)) = next_tok {
				let dot1 = fields_stream.next().unwrap();
				expect_punct(&dot1, '.');
				let dot2 = fields_stream.next().expect("Missing second trailing .");
				expect_punct(&dot2, '.');
				let trailing_dots = [dot1, dot2];
				new_fields.extend(trailing_dots.into_iter().map(token_to_stream));
				assert!(fields_stream.peek().is_none());
				break;
			}

			// Each field entry must look like `ref <name>: <ty_info>,`.
			let ref_ident = fields_stream.next().unwrap();
			expect_ident(&ref_ident, Some("ref"));
			let field_name_ident = fields_stream.next().unwrap();
			let co = fields_stream.next().unwrap();
			expect_punct(&co, ':');
			let ty_info = fields_stream.next().unwrap();
			let com = fields_stream.next().unwrap();
			expect_punct(&com, ',');

			// Skip (i.e. do not re-emit) fields whose type-info group starts
			// with the `legacy` ident.
			if let TokenTree::Group(group) = ty_info {
				let first_group_tok = group.stream().into_iter().next().unwrap();
				if let TokenTree::Ident(ident) = first_group_tok {
					if ident.to_string() == "legacy" {
						continue;
					}
				}
			}

			// Re-emit the binding without its type info: `ref <name>,`.
			let field = [ref_ident, field_name_ident, com];
			new_fields.extend(field.into_iter().map(token_to_stream));
		}
		let fields_group = Group::new(Delimiter::Brace, new_fields);
		computed_fields.extend(token_to_stream(TokenTree::Group(fields_group)));
	} else {
		// Non-brace groups (e.g. tuple patterns) are forwarded unchanged.
		computed_fields.extend(token_to_stream(TokenTree::Group(group)));
	}
	computed_fields
}
157
158#[proc_macro]
179pub fn skip_legacy_fields(expr: TokenStream) -> TokenStream {
180 let mut stream = expr.into_iter();
181 let mut res = TokenStream::new();
182
183 let match_ident = stream.next().unwrap();
185 expect_ident(&match_ident, Some("match"));
186 res.extend(proc_macro::TokenStream::from(match_ident));
187
188 let self_ident = stream.next().unwrap();
189 expect_ident(&self_ident, Some("self"));
190 res.extend(proc_macro::TokenStream::from(self_ident));
191
192 let arms = stream.next().unwrap();
193 if let TokenTree::Group(group) = arms {
194 let mut new_arms = TokenStream::new();
195
196 let mut arm_stream = group.stream().into_iter().peekable();
197 while arm_stream.peek().is_some() {
198 let enum_ident = arm_stream.next().unwrap();
201 let co1 = arm_stream.next().unwrap();
202 expect_punct(&co1, ':');
203 let co2 = arm_stream.next().unwrap();
204 expect_punct(&co2, ':');
205 let variant_ident = arm_stream.next().unwrap();
206 let fields = arm_stream.next().unwrap();
207 let eq = arm_stream.next().unwrap();
208 expect_punct(&eq, '=');
209 let gt = arm_stream.next().unwrap();
210 expect_punct(>, '>');
211 let init = arm_stream.next().unwrap();
212
213 let next_tok = arm_stream.peek();
214 if let Some(TokenTree::Punct(_)) = next_tok {
215 expect_punct(next_tok.unwrap(), ',');
216 arm_stream.next();
217 }
218
219 let computed_fields = if let TokenTree::Group(group) = fields {
220 process_fields(group)
221 } else {
222 panic!("Expected a group for the fields in a match arm");
223 };
224
225 let arm_pfx = [enum_ident, co1, co2, variant_ident];
226 new_arms.extend(arm_pfx.into_iter().map(token_to_stream));
227 new_arms.extend(computed_fields);
228 let arm_sfx = [eq, gt, init];
229 new_arms.extend(arm_sfx.into_iter().map(token_to_stream));
230 }
231
232 let new_arm_group = Group::new(Delimiter::Brace, new_arms);
233 res.extend(token_to_stream(TokenTree::Group(new_arm_group)));
234 } else {
235 panic!("Expected `match self {{..}}` and nothing else");
236 }
237
238 assert!(stream.next().is_none(), "Expected `match self {{..}}` and nothing else");
239
240 res
241}
242
243#[proc_macro]
257pub fn drop_legacy_field_definition(expr: TokenStream) -> TokenStream {
258 let mut st = if let Ok(parsed) = parse::<syn::Expr>(expr) {
259 if let syn::Expr::Struct(st) = parsed {
260 st
261 } else {
262 return (quote! {
263 compile_error!("drop_legacy_field_definition!() can only be used on struct expressions")
264 })
265 .into();
266 }
267 } else {
268 return (quote! {
269 compile_error!("drop_legacy_field_definition!() can only be used on expressions")
270 })
271 .into();
272 };
273 assert!(st.attrs.is_empty());
274 assert!(st.qself.is_none());
275 assert!(st.dot2_token.is_none());
276 assert!(st.rest.is_none());
277 let mut new_fields = syn::punctuated::Punctuated::new();
278 core::mem::swap(&mut new_fields, &mut st.fields);
279 for field in new_fields {
280 if let syn::Expr::Macro(syn::ExprMacro { mac, .. }) = &field.expr {
281 let macro_name = mac.path.segments.last().unwrap().ident.to_string();
282 let is_init = macro_name == "_init_tlv_based_struct_field";
283 let ty_tokens = mac.tokens.clone().into_iter().skip(2).next();
285 if let Some(proc_macro2::TokenTree::Group(group)) = ty_tokens {
286 let first_token = group.stream().into_iter().next();
287 if let Some(proc_macro2::TokenTree::Ident(ident)) = first_token {
288 if is_init && ident == "legacy" {
289 continue;
290 }
291 }
292 }
293 }
294 st.fields.push(field);
295 }
296 let out = syn::Expr::Struct(st);
297 quote! { #out }.into()
298}
299
300#[proc_macro_attribute]
318pub fn xtest(attrs: TokenStream, item: TokenStream) -> TokenStream {
319 let attrs = parse_macro_input!(attrs as TokenStream2);
320 let input = parse_macro_input!(item as Item);
321
322 let expanded = match input {
323 Item::Fn(item_fn) => {
324 let (cfg_attr, submit_attr) = if attrs.is_empty() {
325 (quote! { #[cfg_attr(test, test)] }, quote! { #[cfg(not(test))] })
326 } else {
327 (
328 quote! { #[cfg_attr(test, test)] #[cfg(any(test, #attrs))] },
329 quote! { #[cfg(all(not(test), #attrs))] },
330 )
331 };
332
333 if !item_fn.sig.inputs.is_empty()
335 || !matches!(item_fn.sig.output, syn::ReturnType::Default)
336 {
337 return syn::Error::new_spanned(
338 item_fn.sig,
339 "xtest functions must not take arguments and must return nothing",
340 )
341 .to_compile_error()
342 .into();
343 }
344
345 let should_panic =
347 item_fn.attrs.iter().any(|attr| attr.path().is_ident("should_panic"));
348
349 let fn_name = &item_fn.sig.ident;
350 let fn_name_str = fn_name.to_string();
351 quote! {
352 #cfg_attr
353 #item_fn
354
355 #submit_attr
357 inventory::submit! {
358 crate::XTestItem {
359 test_fn: #fn_name,
360 test_name: #fn_name_str,
361 should_panic: #should_panic,
362 }
363 }
364 }
365 },
366 _ => {
367 return syn::Error::new_spanned(
368 input,
369 "xtest can only be applied to functions or modules",
370 )
371 .to_compile_error()
372 .into();
373 },
374 };
375
376 TokenStream::from(expanded)
377}
378
379#[proc_macro]
383pub fn xtest_inventory(_input: TokenStream) -> TokenStream {
384 let expanded = quote! {
385 pub struct XTestItem {
387 pub test_fn: fn(),
388 pub test_name: &'static str,
389 pub should_panic: bool,
390 }
391
392 inventory::collect!(XTestItem);
393
394 pub fn get_xtests() -> Vec<&'static XTestItem> {
395 inventory::iter::<XTestItem>
396 .into_iter()
397 .collect()
398 }
399 };
400
401 TokenStream::from(expanded)
402}