1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
//! `unanchored` attribute macro's implementation. **Consider using the `anchored` crate instead**.
#![feature(proc_macro_quote)]

#[macro_use]
extern crate proc_macro_error;

use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::token::Comma;
use syn::{
    parse_macro_input, FnArg, Generics, Ident, ItemFn, Pat, Token, TypeParamBound, WhereClause,
    WherePredicate,
};

/// Add this to async function / method to ensure no `Anchored` struct is captured into async
/// generator's state.
#[proc_macro_attribute]
#[proc_macro_error]
pub fn unanchored(attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut result = TokenStream::new();
    result.extend(attr);
    result.extend(item.clone());

    let future = parse_macro_input!(item as ItemFn);

    if future.sig.asyncness.is_none() {
        let msg = "this isn't an async function";
        return syn::Error::new_spanned(future.sig.ident, msg)
            .to_compile_error()
            .into();
    }

    let syn::ItemFn {
        attrs: _attrs,
        vis: _vis,
        block: _block,
        sig,
    } = future;

    let syn::Signature {
        inputs,
        unsafety,
        abi,
        ident,
        generics:
            Generics {
                params: generic_param,
                where_clause,
                ..
            },
        ..
    } = sig;

    let unanchored_ident = format_ident!("__assert_unanchored_{}", ident);

    let (func_params, has_self) = extract_params(&inputs);
    let where_clause = rewrite_where(where_clause);

    let assert_fn: TokenStream = if has_self {
        quote!(
            #unsafety #abi fn #unanchored_ident<#generic_param> (#inputs) #where_clause{
                let future = self.#ident(#(#func_params),*);
                fn assert<UnanchoredFutureType: anchored::Unanchored>(_: UnanchoredFutureType) {}
                assert(future);
            }
        )
    } else {
        quote!(
            #unsafety #abi fn #unanchored_ident<#generic_param> (#inputs) #where_clause{
                let future = #ident(#(#func_params),*);
                fn assert<UnanchoredFutureType: anchored::Unanchored>(_: UnanchoredFutureType) {}
                assert(future);
            }
        )
    }
    .into();

    result.extend(assert_fn);
    result
}

/// Extract parameter `Ident`s from `FnArg`, and tell if this contains `self`.
///
/// E.g.,
/// ```rust, ignore
/// fn(ident_1: Type1, ident_2: Type2)
/// ```
/// returns
/// ```rust, ignore
/// vec![Ident("ident_1"), Ident("ident_2")]
/// ```
fn extract_params(inputs: &Punctuated<FnArg, Comma>) -> (Vec<Ident>, bool) {
    let mut has_self = false;
    let idents: Vec<Ident> = inputs
        .iter()
        .filter_map(|arg| match arg {
            FnArg::Receiver(_) => {
                has_self = true;
                None
            }
            FnArg::Typed(pat_type) => {
                if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
                    if pat_ident.ident == "self" {
                        has_self = true;
                        None
                    } else {
                        Some(pat_ident.ident.clone())
                    }
                } else {
                    None
                }
            }
        })
        .collect();

    (idents, has_self)
}

/// Add `anchored::Unanchored` to every type predicate in a "where" clause.
///
/// E.g.,
/// ```rust, ignore
/// -> where T: Sync,
/// ```
/// to
/// ```rust, ignore
/// -> where T: Sync + anchored::Unanchored,
/// ```
///
/// Only the where clause is rewritten: bounds written inline on a generic parameter
/// (`<T: Sync>`) are not touched. Non-type predicates (e.g. lifetime predicates) are kept as-is.
/// `None` is returned unchanged.
fn rewrite_where(where_clause: Option<WhereClause>) -> Option<WhereClause> {
    let mut where_clause = where_clause?;
    // Inserted `+` tokens reuse the where clause's span.
    let span = where_clause.span();
    let unanchored = syn::parse_str::<TypeParamBound>("anchored::Unanchored")
        .expect("`anchored::Unanchored` is a valid trait bound");

    for pred in where_clause.predicates.iter_mut() {
        if let WherePredicate::Type(ty) = pred {
            // `push_punct` panics when the bound list is empty (`where T:`) or already ends
            // in `+` (`where T: Sync +`), so only insert a separator when one is needed.
            if !ty.bounds.empty_or_trailing() {
                ty.bounds.push_punct(Token![+](span));
            }
            ty.bounds.push_value(unanchored.clone());
        }
    }

    Some(where_clause)
}