1extern crate proc_macro;
2
3#[macro_use]
4extern crate quote;
5
6use proc_macro::TokenStream;
7use syn::parse;
8use syn::punctuated::Punctuated;
9
10#[proc_macro_attribute]
11pub fn hook(attr: TokenStream, input: TokenStream) -> TokenStream {
12 let hooks = syn::parse_macro_input!(attr as Hooks);
13 let mut function = syn::parse_macro_input!(input as syn::ItemFn);
14
15 let before: Box<syn::Block> = if let Some(func) = hooks.before {
16 Box::new(syn::parse_quote!({ #func() }))
17 } else {
18 Box::new(syn::parse_quote!({}))
19 };
20 let after: Box<syn::Block> = if let Some(func) = hooks.after {
21 Box::new(syn::parse_quote!({ #func() }))
22 } else {
23 Box::new(syn::parse_quote!({}))
24 };
25
26 let body = function.block;
27 function.block = Box::new(syn::parse_quote!({
28 #before { #body } #after
29 }));
30 TokenStream::from(quote!(#function))
31}
32
33struct Hooks {
34 before: Option<syn::TypePath>,
35 after: Option<syn::TypePath>,
36}
37
38mod pk {
39 use super::*;
40
41 syn::custom_keyword!(before);
42 syn::custom_keyword!(after);
43
44 pub enum Arg {
45 Before {
46 b: before,
47 eq: syn::Token![=],
48 func: syn::TypePath,
49 },
50 After {
51 a: after,
52 eq: syn::Token![=],
53 func: syn::TypePath,
54 },
55 }
56
57 impl parse::Parse for Arg {
58 fn parse(input: parse::ParseStream) -> parse::Result<Self> {
59 let lookahead = input.lookahead1();
60 if lookahead.peek(before) {
61 Ok(Arg::Before {
62 b: input.parse::<before>()?,
63 eq: input.parse::<syn::Token![=]>()?,
64 func: input.parse()?,
65 })
66 } else {
67 Ok(Arg::After {
68 a: input.parse::<after>()?,
69 eq: input.parse::<syn::Token![=]>()?,
70 func: input.parse()?,
71 })
72 }
73 }
74 }
75}
76
77impl parse::Parse for Hooks {
78 fn parse(input: parse::ParseStream) -> parse::Result<Self> {
79 let mut hb = None;
80 let mut ha = None;
81 let parser = Punctuated::<pk::Arg, syn::Token![,]>::parse_terminated;
82 if let Ok(args) = parser(input) {
83 for arg in args.iter() {
84 match arg {
85 pk::Arg::After { func, .. } => ha = Some(func.clone()),
86 pk::Arg::Before { func, .. } => hb = Some(func.clone()),
87 }
88 }
89 }
90
91 if hb.is_none() && ha.is_none() {
92 hb = Some(input.parse::<syn::TypePath>()?);
93 }
94
95 Ok(Hooks {
96 after: ha,
97 before: hb,
98 })
99 }
100}