async_generic/
lib.rs

1#![deny(warnings)]
2#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg, doc_cfg_hide))]
3
4use proc_macro::{TokenStream, TokenTree};
5use proc_macro2::{Ident, Span, TokenStream as TokenStream2, TokenTree as TokenTree2};
6use quote::quote;
7use syn::{
8    parse::{Parse, ParseStream, Result},
9    parse_macro_input, Attribute, Error, ItemFn, Token,
10};
11
12use crate::desugar_if_async::DesugarIfAsync;
13
14mod desugar_if_async;
15
16fn convert_sync_async(
17    input: &mut Item,
18    is_async: bool,
19    alt_sig: Option<TokenStream>,
20) -> TokenStream2 {
21    let item = &mut input.0;
22
23    if is_async {
24        item.sig.asyncness = Some(Token![async](Span::call_site()));
25        item.sig.ident = Ident::new(&format!("{}_async", item.sig.ident), Span::call_site());
26    }
27
28    let tokens = quote!(#item);
29
30    let tokens = if let Some(alt_sig) = alt_sig {
31        let mut found_fn = false;
32        let mut found_args = false;
33
34        let old_tokens = tokens.into_iter().map(|token| match &token {
35            TokenTree2::Ident(i) => {
36                found_fn = found_fn || &i.to_string() == "fn";
37                token
38            }
39            TokenTree2::Group(g) => {
40                if found_fn && !found_args && g.delimiter() == proc_macro2::Delimiter::Parenthesis {
41                    found_args = true;
42                    return TokenTree2::Group(proc_macro2::Group::new(
43                        proc_macro2::Delimiter::Parenthesis,
44                        alt_sig.clone().into(),
45                    ));
46                }
47                token
48            }
49            _ => token,
50        });
51
52        TokenStream2::from_iter(old_tokens)
53    } else {
54        tokens
55    };
56
57    DesugarIfAsync { is_async }.desugar_if_async(tokens)
58}
59
60#[proc_macro_attribute]
61pub fn async_generic(args: TokenStream, input: TokenStream) -> TokenStream {
62    let mut async_signature: Option<TokenStream> = None;
63
64    if !args.to_string().is_empty() {
65        let mut atokens = args.into_iter();
66        loop {
67            if let Some(TokenTree::Ident(i)) = atokens.next() {
68                if i.to_string() != *"async_signature" {
69                    break;
70                }
71            } else {
72                break;
73            }
74
75            if let Some(TokenTree::Group(g)) = atokens.next() {
76                if atokens.next().is_none() && g.delimiter() == proc_macro::Delimiter::Parenthesis {
77                    async_signature = Some(g.stream());
78                }
79            }
80        }
81
82        if async_signature.is_none() {
83            return syn::Error::new(
84                Span::call_site(),
85                "async_generic can only take a async_signature argument",
86            )
87            .to_compile_error()
88            .into();
89        }
90    };
91
92    let input_clone = input.clone();
93    let mut item = parse_macro_input!(input_clone as Item);
94    let sync_tokens = convert_sync_async(&mut item, false, None);
95
96    let mut item = parse_macro_input!(input as Item);
97    let async_tokens = convert_sync_async(&mut item, true, async_signature);
98
99    let mut tokens = sync_tokens;
100    tokens.extend(async_tokens);
101    tokens.into()
102}
103
104struct Item(ItemFn);
105
106impl Parse for Item {
107    fn parse(input: ParseStream) -> Result<Self> {
108        let attrs = input.call(Attribute::parse_outer)?;
109        if let Ok(mut item) = input.parse::<ItemFn>() {
110            item.attrs = attrs;
111            if item.sig.asyncness.is_some() {
112                return Err(Error::new(
113                    Span::call_site(),
114                    "an async_generic function should not be declared as async",
115                ));
116            }
117            Ok(Item(item))
118        } else {
119            Err(Error::new(
120                Span::call_site(),
121                "async_generic can only be used with functions",
122            ))
123        }
124    }
125}