arpa_log_impl/
lib.rs

1#![feature(box_patterns)]
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, TokenStream as TokenStream2};
5use quote::quote;
6use syn::{
7    fold::{self, Fold},
8    parse_macro_input, parse_quote,
9    punctuated::Punctuated,
10    spanned::Spanned,
11    token::{Comma, Semi},
12    AttributeArgs, Expr, FnArg, GenericParam, ItemFn, Lit, NestedMeta, Pat, PatType, Stmt,
13};
14
15struct FunctionLogVisitor {
16    name: Ident,
17    show_input: bool,
18    ignore_input_args: Vec<String>,
19    show_return: bool,
20    /// support async function by `#[async_trait]`
21    async_trait: bool,
22    /// we don't want add log before returning in sub block or sub closure/async block
23    current_block_count: i32,
24    current_closure_or_async_block_count: i32,
25    /// deal with no explicitly return stmt for () as return type
26    has_return_stmt: bool,
27}
28
/// Expands to a `compile_error!` invocation converted (via `.into()`) to the
/// token-stream type expected at the use site.
///
/// * `macro_error!("msg", span)` attaches the error to `span`.
/// * `macro_error!("msg")` reports the error without an explicit span.
macro_rules! macro_error {
    // Spanned form: the diagnostic points at the offending tokens.
    ($msg:literal, $span:expr) => {
        quote::quote_spanned! { $span =>
            compile_error!($msg);
        }
        .into()
    };

    // Unspanned form.
    ($msg:literal) => {
        quote::quote! {
            compile_error!($msg);
        }
        .into()
    };
}
44
45#[proc_macro_attribute]
46pub fn log_function(attr: TokenStream, input: TokenStream) -> TokenStream {
47    // ItemFn seems OK for impl function perhaps for they both have sig and block.
48    let fn_decl = parse_macro_input!(input as ItemFn);
49    let fn_ident = fn_decl.sig.ident.clone();
50    let fn_args = fn_decl.sig.inputs.clone();
51    let fn_sig = &fn_decl.sig;
52    let fn_stmts = &fn_decl.block.stmts;
53
54    let fn_async_trait = fn_decl.sig.generics.params.iter().any(|p| match p {
55        GenericParam::Lifetime(x) => x.lifetime.ident == "async_trait",
56        _ => false,
57    });
58
59    let args = parse_macro_input!(attr as AttributeArgs);
60
61    let mut show_input = false;
62    let mut show_return = false;
63    let mut ignore_input_args = vec![];
64
65    for arg in args {
66        match arg {
67            NestedMeta::Lit(Lit::Str(x)) if x.token().to_string() == "\"show-input\"" => {
68                show_input = true;
69            }
70            NestedMeta::Lit(Lit::Str(x)) if x.token().to_string() == "\"show-return\"" => {
71                show_return = true;
72            }
73            NestedMeta::Lit(Lit::Str(x)) if x.token().to_string().starts_with("\"except") => {
74                ignore_input_args = x
75                    .token()
76                    .to_string()
77                    .trim_end_matches('\"')
78                    .split_whitespace()
79                    .skip(1)
80                    .filter_map(|word| word.parse().ok())
81                    .collect();
82            }
83            _ => {
84                return macro_error!("unknown logging options", arg.span());
85            }
86        }
87    }
88
89    let mut visitor = FunctionLogVisitor {
90        name: fn_ident.clone(),
91        show_input,
92        ignore_input_args,
93        show_return,
94        async_trait: fn_async_trait,
95        current_block_count: 0,
96        current_closure_or_async_block_count: 0,
97        has_return_stmt: false,
98    };
99
100    let args_text = visitor.generate_args_text(fn_args);
101
102    // Use a syntax tree traversal to transform the function body.
103    let stmts: Punctuated<Stmt, Semi> = fn_stmts
104        .iter()
105        .map(|stmt| visitor.fold_stmt(stmt.to_owned()))
106        .collect();
107
108    let post_code = if visitor.has_return_stmt {
109        TokenStream2::new()
110    } else {
111        let log = visitor.generate_log();
112        quote! {
113            let __res = "nothing";
114            #log
115            return;
116        }
117    };
118
119    quote! {
120        #fn_sig {
121            log_mdc::insert("fn_name", stringify!(#fn_ident));
122            #args_text
123            #stmts
124            #post_code
125        }
126    }
127    .into()
128}
129
130impl FunctionLogVisitor {
131    fn generate_args_text(&self, fn_args: Punctuated<FnArg, Comma>) -> TokenStream2 {
132        let args = quote! {
133            let mut __args: Vec<String> = vec![];
134        };
135
136        fn_args
137            .iter()
138            .filter_map(|arg| match arg {
139                FnArg::Typed(PatType {
140                    attrs: _,
141                    pat: box Pat::Ident(p),
142                    colon_token: _,
143                    ty: _,
144                }) => {
145                    let ident = &p.ident;
146                    let arg_text = if self.show_input
147                        && !self.ignore_input_args.contains(&ident.to_string())
148                    {
149                        quote! {
150                            __args.push(format!("{}: {:?}", stringify!(#ident), #ident));
151                        }
152                    } else {
153                        quote! {
154                            __args.push(format!("{}: ignored", stringify!(#ident)));
155                        }
156                    };
157                    Some(arg_text)
158                }
159                _ => None,
160            })
161            .fold(args, |mut args, arg| {
162                args.extend(arg);
163                args
164            })
165    }
166
167    fn generate_log(&self) -> TokenStream2 {
168        let fn_ident = &self.name;
169
170        let return_text = if self.show_return {
171            quote! {
172                &format!("{:?}", __res)
173            }
174        } else {
175            quote! {
176                &format!("{:?}", "ignored")
177            }
178        };
179        quote! {
180            let ___args: Vec<&str> = __args.iter().map(|arg| arg as &str).collect();
181            let __log = LogModel{
182                fn_name: stringify!(#fn_ident),
183                fn_args: &___args,
184                fn_return: #return_text,
185            };
186            debug!(target: stringify!(#fn_ident), "{:?}", __log);
187            log_mdc::remove("fn_name");
188        }
189    }
190
191    fn handle_expr_try(&mut self, e: syn::ExprTry) -> TokenStream2 {
192        let expr = fold::fold_expr(self, *e.expr);
193        let log = self.generate_log();
194        quote!(
195            match #expr {
196                Ok(v) => v,
197                Err(e) => {
198                    let __res = Err(e.into());
199                    #log
200                    return __res;
201                }
202            }
203        )
204    }
205
206    fn insert_log_and_fold_expr_stmt(&mut self, e: Expr) -> Stmt {
207        if !self.async_trait && self.current_block_count == 0 {
208            self.has_return_stmt = true;
209            let log = self.generate_log();
210            let expr = fold::fold_expr(self, e);
211            parse_quote!({
212                let __res = #expr;
213                #log
214                __res
215            })
216        } else {
217            fold::fold_stmt(self, Stmt::Expr(e))
218        }
219    }
220}
221
222impl Fold for FunctionLogVisitor {
223    fn fold_block(&mut self, i: syn::Block) -> syn::Block {
224        self.current_block_count += 1;
225        let res = fold::fold_block(self, i);
226        self.current_block_count -= 1;
227        res
228    }
229
230    fn fold_expr(&mut self, e: Expr) -> Expr {
231        match e {
232            Expr::Block(_) => {
233                self.current_block_count += 1;
234                let res = fold::fold_expr(self, e);
235                self.current_block_count -= 1;
236                res
237            }
238            Expr::Return(e) => {
239                if self.current_closure_or_async_block_count == 0 {
240                    self.has_return_stmt = true;
241                    let log = self.generate_log();
242                    if let Some(v) = e.expr {
243                        let expr = fold::fold_expr(self, *v);
244                        parse_quote!({
245                            let __res = #expr;
246                            #log
247                            return __res;
248                        })
249                    } else {
250                        parse_quote!({
251                            let __res = "nothing";
252                            #log
253                            return;
254                        })
255                    }
256                } else {
257                    fold::fold_expr(self, Expr::Return(e))
258                }
259            }
260            Expr::Try(e) => {
261                let expr_try = self.handle_expr_try(e);
262                parse_quote!(
263                    #expr_try
264                )
265            }
266            // clone __args before move block
267            Expr::Async(e) => {
268                if e.capture.is_some() {
269                    self.current_closure_or_async_block_count += 1;
270                    let expr = fold::fold_expr_async(self, e);
271                    self.current_closure_or_async_block_count -= 1;
272                    parse_quote!({
273                        let __args = __args.clone();
274                        #expr
275                    })
276                } else {
277                    fold::fold_expr(self, Expr::Async(e))
278                }
279            }
280            // clone __args before move block
281            Expr::Closure(e) => {
282                if e.capture.is_some() {
283                    self.current_closure_or_async_block_count += 1;
284                    let expr = fold::fold_expr_closure(self, e);
285                    self.current_closure_or_async_block_count -= 1;
286                    parse_quote!({
287                        let __args = __args.clone();
288                        #expr
289                    })
290                } else {
291                    fold::fold_expr(self, Expr::Closure(e))
292                }
293            }
294            _ => fold::fold_expr(self, e),
295        }
296    }
297
298    fn fold_stmt(&mut self, s: Stmt) -> Stmt {
299        match s {
300            Stmt::Expr(e) => match e {
301                // ignore log on Box::pin in async_trait attribute macro
302                Expr::Call(c) => match *c.func.clone() {
303                    Expr::Path(p) => {
304                        let first = p.path.segments.first();
305                        let last = p.path.segments.last();
306                        match (self.async_trait, first, last) {
307                            (true, Some(f), Some(l)) if f.ident == "Box" && l.ident == "pin" => {
308                                fold::fold_stmt(self, Stmt::Expr(Expr::Call(c)))
309                            }
310                            _ => self.insert_log_and_fold_expr_stmt(Expr::Call(c)),
311                        }
312                    }
313                    _ => self.insert_log_and_fold_expr_stmt(Expr::Call(c)),
314                },
315                // log on __ret in async_trait attribute macro
316                Expr::Path(p) => {
317                    let ident = p.path.get_ident();
318                    if self.async_trait && ident.is_some() && *ident.unwrap() == "__ret" {
319                        self.has_return_stmt = true;
320                        let log = self.generate_log();
321                        parse_quote!({
322                            let __res = __ret;
323                            #log
324                            __res
325                        })
326                    } else {
327                        self.insert_log_and_fold_expr_stmt(Expr::Path(p))
328                    }
329                }
330                // These exprs should be common for return value.
331                Expr::Array(_)
332                | Expr::Await(_)
333                | Expr::Binary(_)
334                | Expr::Closure(_)
335                | Expr::Cast(_)
336                | Expr::Field(_)
337                | Expr::Index(_)
338                | Expr::If(_)
339                | Expr::Lit(_)
340                | Expr::Macro(_)
341                | Expr::MethodCall(_)
342                | Expr::Match(_)
343                | Expr::Paren(_)
344                | Expr::Range(_)
345                | Expr::Return(_)
346                | Expr::Reference(_)
347                | Expr::Repeat(_)
348                | Expr::Struct(_)
349                | Expr::Tuple(_)
350                | Expr::Unary(_) => self.insert_log_and_fold_expr_stmt(e),
351                Expr::Block(_) => {
352                    self.current_block_count += 1;
353                    let res = fold::fold_stmt(self, Stmt::Expr(e));
354                    self.current_block_count -= 1;
355                    res
356                }
357
358                _ => fold::fold_stmt(self, Stmt::Expr(e)),
359            },
360            Stmt::Semi(e, semi) => match e {
361                Expr::Try(e) => {
362                    let expr_try = self.handle_expr_try(e);
363                    parse_quote!(
364                        #expr_try;
365                    )
366                }
367                _ => fold::fold_stmt(self, Stmt::Semi(e, semi)),
368            },
369            _ => fold::fold_stmt(self, s),
370        }
371    }
372}