1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
extern crate proc_macro;

use proc_macro::TokenStream;
use quote::quote;
use syn::{
    parse::{Parse, ParseStream},
    parse_macro_input,
    punctuated::Punctuated,
    FnArg, Ident, ItemFn, Pat, PatIdent, Result, Token, Type, TypePath, TypeReference,
};

/// Comma-separated list of identifiers.
///
/// Used as the parse target for the attribute arguments of the hook macros in
/// this file (e.g. the `a, b, c` in `#[pre_hooks(a, b, c)]`).
#[derive(Debug, Clone)]
struct IdentList {
    /// The parsed identifiers, in source order, together with their separators.
    values: Punctuated<Ident, Token![,]>,
}

impl Parse for IdentList {
    fn parse(input: ParseStream) -> Result<Self> {
        Ok(Self {
            values: input.parse_terminated(Ident::parse)?,
        })
    }
}

/// Returns the binding name of a parameter whose type is `WidgetContext`.
///
/// `ty` is the declared parameter type and `pat` is the parameter pattern.
/// The type matches when the last segment of its path is `WidgetContext`
/// (so both `WidgetContext` and `some::path::WidgetContext` are accepted).
/// References, parentheses and invisible groups (the latter are produced when
/// a type is substituted through `macro_rules!`) are looked through, so
/// `&mut WidgetContext` and `(&WidgetContext)` match as well.
///
/// Returns `None` when the type does not match or when the parameter is not
/// bound by a plain identifier pattern (e.g. it is destructured).
fn unpack_context(ty: &Type, pat: &Pat) -> Option<Ident> {
    match ty {
        Type::Path(TypePath { path, .. }) => {
            // Only the final segment is compared, so any module prefix is accepted.
            let segment = path.segments.last()?;
            if segment.ident != "WidgetContext" {
                return None;
            }
            match pat {
                Pat::Ident(PatIdent { ident, .. }) => Some(ident.clone()),
                _ => None,
            }
        }
        // Transparent wrappers: recurse into the wrapped type.
        Type::Reference(TypeReference { elem, .. }) => unpack_context(elem, pat),
        Type::Paren(inner) => unpack_context(&inner.elem, pat),
        Type::Group(inner) => unpack_context(&inner.elem, pat),
        _ => None,
    }
}

/// Extracts the context binding from a single function argument, if it is one.
///
/// Typed arguments are delegated to `unpack_context` with their type and
/// pattern; a receiver (`self`) argument never qualifies and yields `None`.
fn is_arg_context(arg: &FnArg) -> Option<Ident> {
    match arg {
        FnArg::Typed(typed) => unpack_context(&typed.ty, &typed.pat),
        FnArg::Receiver(_) => None,
    }
}

#[proc_macro_attribute]
pub fn pre_hooks(attr: TokenStream, input: TokenStream) -> TokenStream {
    let ItemFn {
        attrs,
        vis,
        sig,
        block,
    } = parse_macro_input!(input as ItemFn);
    let context = sig
        .inputs
        .iter()
        .find_map(is_arg_context)
        .unwrap_or_else(|| panic!("Could not find function context argument!"));
    let list = parse_macro_input!(attr as IdentList);
    let hooks = list
        .values
        .into_iter()
        .map(|v| quote! { #context.use_hook(#v); });

    let tokens = quote! {
        #(#attrs)*
        #vis #sig {
            #(#hooks)*
            #block
        }
    };
    tokens.into()
}

#[proc_macro_attribute]
pub fn post_hooks(attr: TokenStream, input: TokenStream) -> TokenStream {
    let ItemFn {
        attrs,
        vis,
        sig,
        block,
    } = parse_macro_input!(input as ItemFn);
    let context = sig
        .inputs
        .iter()
        .find_map(is_arg_context)
        .unwrap_or_else(|| panic!("Could not find function context argument!"));
    let list = parse_macro_input!(attr as IdentList);
    let hooks = list
        .values
        .into_iter()
        .map(|v| quote! { #context.use_hook(#v); });

    let tokens = quote! {
        #(#attrs)*
        #vis #sig {
            let result = {
                #block
            };
            #(#hooks)*
            result
        }
    };
    tokens.into()
}