1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
extern crate proc_macro;

#[macro_use]
extern crate quote;

use proc_macro::TokenStream;
use syn::parse;
use syn::punctuated::Punctuated;

/// Attribute macro that wraps a function body with optional "before" and
/// "after" hook calls.
///
/// Accepted argument forms:
/// - `#[hook(path)]`: shorthand for `#[hook(before = path)]`;
/// - `#[hook(before = path)]`, `#[hook(after = path)]`, or both, comma-separated.
///
/// Each hook path is called with no arguments, and its return value (if any)
/// is discarded. With an `after` hook the generated body is roughly:
///
/// ```text
/// {
///     before();
///     let result = { /* original body */ };
///     after();
///     result
/// }
/// ```
///
/// so the wrapped function keeps its original return value. Because the
/// original body is inlined, an early `return`, a `?` error, or a panic inside
/// it leaves the function without running the `after` hook.
#[proc_macro_attribute]
pub fn hook(attr: TokenStream, input: TokenStream) -> TokenStream {
    let hooks = syn::parse_macro_input!(attr as Hooks);
    let mut function = syn::parse_macro_input!(input as syn::ItemFn);

    // Emit the hook as a statement so a non-`()` return value is discarded
    // instead of causing a type mismatch.
    let before = hooks.before.map(|func| quote!(#func();));

    let body = function.block;
    let block: syn::Block = match hooks.after {
        // Bind the body's value so it, not the `after` hook's value, is what
        // the function returns once the hook has run.
        Some(after) => syn::parse_quote!({
            #before
            #[allow(clippy::let_unit_value)]
            let __hook_result = #body;
            #after();
            __hook_result
        }),
        // No `after` hook: the original body is simply the tail expression.
        None => syn::parse_quote!({
            #before
            #body
        }),
    };
    function.block = Box::new(block);

    TokenStream::from(quote!(#function))
}

/// Hook functions parsed from the arguments of `#[hook(...)]`.
///
/// Each entry is the path of a function that `hook` calls with no arguments.
struct Hooks {
    /// Function invoked after the wrapped function's body, if any.
    after: Option<syn::TypePath>,
    /// Function invoked before the wrapped function's body, if any.
    before: Option<syn::TypePath>,
}

mod pk {
    use super::*;

    syn::custom_keyword!(before);
    syn::custom_keyword!(after);

    pub enum Arg {
        Before {
            b: before,
            eq: syn::Token![=],
            func: syn::TypePath,
        },
        After {
            a: after,
            eq: syn::Token![=],
            func: syn::TypePath,
        },
    }

    impl parse::Parse for Arg {
        fn parse(input: parse::ParseStream) -> parse::Result<Self> {
            let lookahead = input.lookahead1();
            if lookahead.peek(before) {
                Ok(Arg::Before {
                    b: input.parse::<before>()?,
                    eq: input.parse::<syn::Token![=]>()?,
                    func: input.parse()?,
                })
            } else {
                Ok(Arg::After {
                    a: input.parse::<after>()?,
                    eq: input.parse::<syn::Token![=]>()?,
                    func: input.parse()?,
                })
            }
        }
    }
}

impl parse::Parse for Hooks {
    fn parse(input: parse::ParseStream) -> parse::Result<Self> {
        let mut hb = None;
        let mut ha = None;
        let parser = Punctuated::<pk::Arg, syn::Token![,]>::parse_terminated;
        if let Ok(args) = parser(input) {
            for arg in args.iter() {
                match arg {
                    pk::Arg::After { func, .. } => ha = Some(func.clone()),
                    pk::Arg::Before { func, .. } => hb = Some(func.clone()),
                }
            }
        }

        if hb.is_none() && ha.is_none() {
            hb = Some(input.parse::<syn::TypePath>()?);
        }

        Ok(Hooks {
            after: ha,
            before: hb,
        })
    }
}