racetrack_proc_macro/
lib.rs

1extern crate proc_macro;
2#[macro_use]
3extern crate syn;
4
5use proc_macro2::{Span, TokenStream};
6use quote::{quote, quote_spanned, ToTokens};
7use syn::{
8    punctuated::Punctuated, spanned::Spanned, AttributeArgs, Expr, ExprAssign, ExprClosure, FnArg,
9    ImplItem, ImplItemMethod, Index, Item, ItemFn, ItemImpl, Lit, Local, Meta, MetaNameValue,
10    NestedMeta, Pat, PatIdent, PatType, Stmt
11};
12
13#[inline]
14fn unsupported() -> TokenStream {
15    quote! {
16        compile_error!("Unsupported attribute target. 'track_with' only supports functions, impl blocks and closures.");
17    }
18}
19
20/// Track the target with the tracker specified in the arguments.
21/// Requires one argument containing the path to the tracker.
22///
23/// # Arguments
24///
25/// * `tracked_path` - The path to the tracker. This must be the first unnamed argument. Required.
26/// * `exclude` - A comma separated list of methods to exclude. This only does something on impl blocks.
27/// * `include_receiver` - Include the receiver (self). If false, the tracker must be available in the scope of the relevant method.
28///     If no receiver was found and this is true, the method will be skipped. Defaults to true.
29/// * `namespace` - Override the namespace of the tracked item. Tracked key will be namespace::function_name.
30///     Defaults to the struct name for impl blocks, None for functions and closures.
31///
32/// # Example
33///
34/// ```
35/// # use std::sync::Arc;
36/// use racetrack::{Tracker, track_with};
37///
38/// struct TrackedStruct(Arc<Tracker>);
39///
40/// #[track_with(0, namespace = "Tracked")]
41/// impl TrackedStruct {
42///     fn tracked_fn(&self, arg: String) {}
43/// }
44/// ```
45#[proc_macro_attribute]
46pub fn track_with(
47    args: proc_macro::TokenStream,
48    item_tokens: proc_macro::TokenStream
49) -> proc_macro::TokenStream {
50    let args = syn::parse_macro_input!(args as AttributeArgs);
51    let args = parse_args(args);
52    //println!("{:?}", args);
53
54    let item = syn::parse::<Item>(item_tokens.clone());
55
56    let tokens = match item {
57        Ok(Item::Fn(fun)) => track_function(&args, fun),
58        Ok(Item::Impl(item)) => track_impl(&args, item),
59        Ok(Item::Struct(_)) => quote! {
60            compile_error!("Structs aren't a supported attribute target. To track methods, put this attribute on an impl block.")
61        },
62        Err(_) => {
63            if let Ok(stmt) = syn::parse::<Stmt>(item_tokens.clone()) {
64                let tokens = match stmt {
65                    Stmt::Local(Local {
66                        pat, init, attrs, ..
67                    }) => {
68                        if let Some(Expr::Closure(closure)) = init.map(|expr| *expr.1) {
69                            let name = quote!(#pat).to_string();
70                            let closure = track_closure(&args, closure, name);
71                            quote! {
72                                #(#attrs)*
73                                let #pat = #closure;
74                            }
75                        } else {
76                            unsupported()
77                        }
78                    }
79                    Stmt::Expr(Expr::Assign(ExprAssign { left, right, .. })) => {
80                        if let Expr::Closure(closure) = *right {
81                            let name = quote!(#left).to_string();
82                            let closure = track_closure(&args, closure, name);
83                            quote!(#left = #closure)
84                        } else {
85                            unsupported()
86                        }
87                    }
88                    _ => unsupported()
89                };
90                tokens.into()
91            } else {
92                unsupported()
93            }
94        }
95        _ => unsupported()
96    };
97
98    tokens.into()
99}
100
101/// Arguments that can be passed to the proc macro
102#[derive(Debug)]
103struct Arguments {
104    /// The path to the tracker. This must be the first unnamed argument.
105    tracker_path: TokenStream,
106    /// A comma separated list of methods to exclude. This only does something on impl blocks.
107    exclude: Vec<String>,
108    /// Include the receiver (self). If false, the tracker must be available in the scope of the relevant method.
109    /// If no receiver was found and this is true, the method will be skipped. Defaults to true.
110    include_receiver: bool,
111    /// Override the namespace of the tracked item. Tracked key will be namespace::function_name.
112    /// Defaults to the struct name for impl blocks, None for functions and closures.
113    namespace: Option<String>
114}
115
116fn parse_args(mut args: AttributeArgs) -> Arguments {
117    args.reverse();
118    let tracker_path = {
119        if args.len() == 0 {
120            quote_spanned! {
121                Span::call_site() =>
122                compile_error!("Invalid number of arguments. Expected one argument with the path of the tracker.");
123            }
124        } else {
125            //println!("{:#?}", args);
126            let arg = args.pop().unwrap();
127            if let NestedMeta::Meta(Meta::Path(path)) = arg {
128                quote!(#path)
129            } else if let NestedMeta::Lit(Lit::Int(int)) = arg {
130                // Tuple struct ident
131                let value = int.base10_parse::<usize>().unwrap();
132                let index: Index = value.into();
133                quote!(#index)
134            } else {
135                quote_spanned! {
136                    arg.span() =>
137                    compile_error!("Invalid argument. Should be path of tracker.");
138                }
139            }
140        }
141    };
142    let mut arguments = Arguments {
143        tracker_path,
144        exclude: Vec::new(),
145        include_receiver: true,
146        namespace: None
147    };
148    while let Some(next) = args.pop() {
149        if let NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) = next {
150            if let Some(key) = path.segments.first().map(|path| path.ident.to_string()) {
151                match key.as_str() {
152                    "exclude" => {
153                        if let Lit::Str(str) = lit {
154                            let token = str.value();
155                            let value: Vec<_> =
156                                token.split(",").map(|s| s.trim().to_string()).collect();
157                            arguments.exclude = value;
158                        } else {
159                            panic!("Invalid value for exclude config. Should be comma separated string.");
160                        }
161                    }
162                    "include_receiver" => {
163                        if let Lit::Bool(bool) = lit {
164                            arguments.include_receiver = bool.value;
165                        } else {
166                            panic!("Invalid value for include_receiver config. Should be boolean.");
167                        }
168                    }
169                    "namespace" => {
170                        if let Lit::Str(str) = lit {
171                            arguments.namespace = Some(str.value());
172                        } else {
173                            panic!("Invalid value for namespace config. Should be a string.");
174                        }
175                    }
176                    _ => {
177                        panic!("Unexpected config entry in track_with attribute.");
178                    }
179                }
180            } else {
181                panic!("Invalid config entry in track_with attribute.");
182            }
183        } else {
184            panic!("Unexpected argument in track_with attribute.");
185        }
186    }
187    //println!("{:?}", arguments);
188    arguments
189}
190
191fn track_impl(args: &Arguments, item: ItemImpl) -> TokenStream {
192    //println!("{:#?}", item);
193    let ItemImpl {
194        attrs,
195        defaultness,
196        unsafety,
197        generics,
198        trait_,
199        self_ty,
200        items,
201        ..
202    } = item;
203    let namespace = args
204        .namespace
205        .as_ref()
206        .map(|s| s.clone())
207        .unwrap_or_else(|| quote!(#self_ty).to_string());
208    let trait_ = trait_.map(|(bang, trait_, for_)| quote!(#bang#trait_ #for_));
209
210    let items = items.iter().map(|item| {
211        if let ImplItem::Method(method) = item {
212            track_method(&args, method, &namespace)
213        } else {
214            quote!(#item)
215        }
216    });
217
218    let tokens = quote! {
219        #(#attrs)*
220        #defaultness #unsafety impl #generics #trait_ #self_ty {
221            #(#items)*
222        }
223    };
224
225    //println!("{}", tokens);
226    tokens
227}
228
229fn track_method(args: &Arguments, method: &ImplItemMethod, namespace: &str) -> TokenStream {
230    let name = method.sig.ident.to_string();
231    if args.exclude.contains(&name) {
232        return quote!(#method);
233    }
234    let name = format!("{}::{}", namespace, name);
235
236    let ImplItemMethod {
237        attrs,
238        vis,
239        defaultness,
240        sig,
241        block
242    } = method;
243
244    let receiver = sig.inputs.iter().find_map(|arg| {
245        if let FnArg::Receiver(recv) = arg {
246            Some(recv)
247        } else {
248            None
249        }
250    });
251
252    if args.include_receiver && receiver.is_none() {
253        // Skip static methods since the tracker path won't be valid
254        return quote!(#method);
255    }
256
257    let inputs_cloned = cloned_inputs(&sig.inputs);
258    let result_cloned = quote_spanned! {
259        sig.output.span() =>
260        returned.to_owned()
261    };
262    let statements = &block.stmts;
263    let tracker_path = &args.tracker_path;
264    let tracker_path = if args.include_receiver {
265        quote!(self.#tracker_path)
266    } else {
267        tracker_path.clone()
268    };
269
270    let body = quote_spanned! {
271        block.span() =>
272        let args = (#(#inputs_cloned),*);
273        let returned = {
274            #(#statements)*
275        };
276        #tracker_path.log_call(#name, ::racetrack::CallInfo {
277            arguments: Some(Box::new(args)),
278            returned: Some(Box::new(#result_cloned))
279        });
280        returned
281    };
282
283    let attrs = spanned_vec(attrs);
284    let vis = spanned(vis);
285    let defaultness = spanned_opt(defaultness.as_ref());
286    let sig = spanned(sig);
287
288    let tokens = quote! {
289        #(#attrs)*
290        #vis #defaultness #sig {
291            #body
292        }
293    };
294
295    tokens
296}
297
298fn track_function(args: &Arguments, fun: ItemFn) -> TokenStream {
299    //println!("{:#?}", fun);
300    let attrs = fun.attrs;
301    let visibility = fun.vis;
302    let signature = fun.sig;
303    let name = if let Some(ref namespace) = args.namespace {
304        format!("{}::{}", namespace, signature.ident.to_string())
305    } else {
306        signature.ident.to_string()
307    };
308    let arg_idents = cloned_inputs(&signature.inputs);
309    let returned_clone = quote_spanned! {
310        signature.output.span() =>
311        returned.to_owned()
312    };
313    let block = &fun.block;
314    let statements = &fun.block.stmts;
315    let tracker_path = &args.tracker_path;
316    let body = quote_spanned! {
317        block.span() =>
318            let args = (#(#arg_idents),*);
319            let returned = {
320                #(#statements)*
321            };
322            #tracker_path.log_call(#name, ::racetrack::CallInfo {
323                arguments: Some(Box::new(args)),
324                returned: Some(Box::new(#returned_clone))
325            });
326            returned
327    };
328
329    let tokens = quote! {
330        #(#attrs)*
331        #visibility #signature {
332            #body
333        }
334    };
335
336    //println!("{}", tokens);
337    tokens
338}
339
340fn track_closure(args: &Arguments, closure: ExprClosure, name: String) -> TokenStream {
341    let ExprClosure {
342        attrs,
343        asyncness,
344        movability,
345        capture,
346        inputs,
347        output,
348        body,
349        ..
350    } = closure;
351    let tracker_path = &args.tracker_path;
352    let attrs = spanned_vec(&attrs);
353    let asyncness = spanned_opt(asyncness);
354    let movability = spanned_opt(movability);
355    let capture = spanned_opt(capture);
356    let cloned_inputs = cloned_inputs_pat(&inputs);
357    let cloned_return = quote_spanned! {
358        output.span() =>
359        returned.to_owned()
360    };
361    let inputs: Vec<_> = inputs.iter().map(|input| {
362        quote_spanned! {
363            input.span() =>
364            #input
365        }
366    }).collect();
367    let arguments = &inputs;
368    let body_outer = quote_spanned! {
369        body.span() =>
370        let args = (#(#cloned_inputs),*);
371        let returned = inner(#(#arguments)*);
372        tracker.log_call(#name, ::racetrack::CallInfo {
373            arguments: Some(Box::new(args)),
374            returned: Some(Box::new(#cloned_return))
375        });
376        returned
377    };
378
379    let tokens = quote! {
380        {
381            let inner = #(#attrs)*
382            #asyncness #movability #capture |#(#arguments)*| #output {
383                #body
384            };
385            let tracker = #tracker_path.clone();
386            #asyncness #movability move |#(#arguments)*| #output {
387                #body_outer
388            }
389        }
390    };
391    tokens
392}
393
394fn spanned(item: impl ToTokens + Spanned) -> TokenStream {
395    quote_spanned! {
396        item.span() =>
397        #item
398    }
399}
400
401fn spanned_vec<T: ToTokens + Spanned>(item: &Vec<T>) -> Vec<TokenStream> {
402    item.iter()
403        .map(|item| {
404            quote_spanned! {
405                item.span() =>
406                #item
407            }
408        })
409        .collect()
410}
411
412fn spanned_opt<T: ToTokens + Spanned>(item: Option<T>) -> TokenStream {
413    item.map(|item| {
414        quote_spanned! {
415            item.span() =>
416            #item
417        }
418    })
419    .unwrap_or_else(|| quote!())
420}
421
422fn cloned_inputs<'a>(inputs: &Punctuated<FnArg, Token![,]>) -> Vec<TokenStream> {
423    inputs
424        .iter()
425        .filter_map(|arg| {
426            if let FnArg::Typed(PatType { ref pat, .. }) = arg {
427                Some(pat)
428            } else {
429                None
430            }
431        })
432        .filter_map(|arg| {
433            if let &Pat::Ident(PatIdent { ref ident, .. }) = &**arg {
434                Some(ident)
435            } else {
436                None
437            }
438        })
439        .map(|ident| {
440            quote_spanned! {
441                ident.span() =>
442                #ident.to_owned()
443            }
444        })
445        .collect()
446}
447
448fn cloned_inputs_pat<'a>(inputs: &Punctuated<Pat, Token![,]>) -> Vec<TokenStream> {
449    //println!("{:?}", inputs);
450    inputs
451        .iter()
452        .filter_map(|arg| {
453            if let Pat::Ident(PatIdent { ref ident, .. }) = arg {
454                Some(ident)
455            } else if let Pat::Type(PatType { pat, .. }) = arg {
456                if let Pat::Ident(PatIdent { ident, .. }) = &**pat {
457                    Some(ident)
458                } else {
459                    None
460                }
461            } else {
462                None
463            }
464        })
465        .map(|ident| {
466            quote_spanned! {
467                ident.span() =>
468                #ident.to_owned()
469            }
470        })
471        .collect()
472}