Skip to main content

easy_macros_add_code/
lib.rs

1use anyhow::bail;
2use proc_macro::TokenStream;
3use quote::quote;
4use syn::{parse::Parse, parse::ParseStream, parse_quote, Block, Ident, Token};
5
6#[derive(Default)]
7struct AddCodeArgs {
8    before: Option<Vec<syn::Stmt>>,
9    after: Option<Vec<syn::Stmt>>,
10}
11
12impl Parse for AddCodeArgs {
13    fn parse(input: ParseStream) -> syn::Result<Self> {
14        let mut args = AddCodeArgs::default();
15
16        while !input.is_empty() {
17            let name: Ident = input.parse()?;
18            input.parse::<Token![=]>()?;
19
20            let block: Block = input.parse()?;
21            let stmts = block.stmts;
22
23            match name.to_string().as_str() {
24                "before" => {
25                    if args.before.is_some() {
26                        return Err(syn::Error::new(name.span(), "duplicate `before`"));
27                    }
28                    args.before = Some(stmts);
29                }
30                "after" => {
31                    if args.after.is_some() {
32                        return Err(syn::Error::new(name.span(), "duplicate `after`"));
33                    }
34                    args.after = Some(stmts);
35                }
36                _ => {
37                    return Err(syn::Error::new(name.span(), "expected `before` or `after`"));
38                }
39            }
40
41            if input.peek(Token![,]) {
42                input.parse::<Token![,]>()?;
43            }
44        }
45
46        if args.before.is_none() && args.after.is_none() {
47            return Err(syn::Error::new(input.span(), "missing `before` or `after`"));
48        }
49
50        Ok(args)
51    }
52}
53
54fn inject_into_block(block: &mut Block, args: &AddCodeArgs) {
55    if let Some(before) = &args.before {
56        let mut before_stmts = before.clone();
57        block.stmts.splice(0..0, before_stmts.drain(..));
58    }
59
60    if let Some(after) = &args.after {
61        let tail_expr = match block.stmts.last() {
62            Some(syn::Stmt::Expr(expr, None)) => Some(expr.clone()),
63            _ => None,
64        };
65
66        if let Some(expr) = tail_expr {
67            block.stmts.pop();
68            block
69                .stmts
70                .push(parse_quote! { let __add_code_result = { #expr }; });
71            block.stmts.extend(after.clone());
72            let tail_expr: syn::Expr = parse_quote! { __add_code_result };
73            block.stmts.push(syn::Stmt::Expr(tail_expr, None));
74        } else {
75            block.stmts.extend(after.clone());
76        }
77    }
78}
79
80#[proc_macro_attribute]
81#[anyhow_result::anyhow_result]
82/// Injects code at the beginning and/or end of a function-like item.
83///
84/// This is handy for keeping [docify](https://crates.io/crates/docify)-generated examples clean while still
85/// inserting setup/teardown or assertions for tests.
86///
87/// # Syntax
88/// ```rust,ignore
89/// #[add_code(before = { /* statements */ })]
90/// #[add_code(after = { /* statements */ })]
91/// #[add_code(before = { ... }, after = { ... })]
92/// fn demo() { /* ... */ }
93/// ```
94///
95/// The statements inside the braces are inserted without the outer `{}`.
96pub fn add_code(attr: TokenStream, item: TokenStream) -> anyhow::Result<TokenStream> {
97    let args = helpers::parse_macro_input!(attr as AddCodeArgs);
98
99    let item_ts: proc_macro2::TokenStream = item.clone().into();
100
101    if let Ok(mut item_fn) = syn::parse2::<syn::ItemFn>(item_ts.clone()) {
102        inject_into_block(&mut item_fn.block, &args);
103        return Ok(quote! { #item_fn }.into());
104    }
105
106    if let Ok(mut impl_fn) = syn::parse2::<syn::ImplItemFn>(item_ts) {
107        inject_into_block(&mut impl_fn.block, &args);
108        return Ok(quote! { #impl_fn }.into());
109    }
110
111    bail!("#[add_code] can only be used on functions or impl methods");
112}