desaturate_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::{quote_spanned, ToTokens};
4use syn::{
5    parse::{Parse, ParseStream},
6    parse2,
7    punctuated::Punctuated,
8    spanned::Spanned,
9    Ident, Lifetime, Token,
10};
11
// Emits a hard compile error when the `dont-directly-import-this-crate` feature is
// enabled, except in doc and test builds. The message explains why: generated code is
// only sound if this crate's feature flags match the `desaturate` crate's.
// NOTE(review): presumably that feature is on by default and the `desaturate` facade
// turns it off when depending on this crate — confirm against Cargo.toml.
#[cfg(all(not(test), not(doc), feature = "dont-directly-import-this-crate"))]
compile_error!("Directly importing the `desaturate-macros` crate may make generated functions unsound, as they require that the feature flags of this crate match with the `desaturate` crate.");

/// Shorthand for `T::default()` that lets the target type be inferred,
/// e.g. in struct-update syntax or `let x: Foo = default();`.
pub(crate) fn default<T>() -> T
where
    T: Default,
{
    Default::default()
}
18
19mod input_function;
20mod transformer;
21mod visitors;
22use crate::{input_function::*, transformer::*};
23
24#[derive(Default)]
25struct Asyncable {
26    debug_dump: Option<Span>,
27    lifetime: Option<Lifetime>,
28    only_async_attr: Option<Ident>,
29    only_blocking_attr: Option<Ident>,
30    make_blocking: bool,
31    make_async: bool,
32}
33
34impl Parse for Asyncable {
35    fn parse(input: ParseStream) -> syn::Result<Self> {
36        struct Setting {
37            name: Ident,
38            value: Option<(Token![=], syn::Lit)>,
39        }
40        impl Setting {
41            fn span(&self) -> Span {
42                let span = self.name.span();
43                if let Some((token, lit)) = &self.value {
44                    span.join(token.span()).unwrap().join(lit.span()).unwrap()
45                } else {
46                    span
47                }
48            }
49        }
50        impl Parse for Setting {
51            fn parse(input: ParseStream) -> syn::Result<Self> {
52                let name = input.parse()?;
53                let value = if input.peek(Token![=]) {
54                    Some((input.parse()?, input.parse()?))
55                } else {
56                    None
57                };
58                Ok(Self { name, value })
59            }
60        }
61        let mut result = Asyncable {
62            make_blocking: cfg!(feature = "generate-blocking"),
63            make_async: cfg!(feature = "generate-async"),
64            ..Asyncable::default()
65        };
66        let mut errors: Vec<syn::Error> = vec![];
67        Punctuated::<Setting, Token![,]>::parse_terminated(input)?
68            .iter()
69            .for_each(|setting| match setting {
70                setting @ Setting { name, value: None } if name == "debug_dump" => {
71                    result.debug_dump = Some(setting.span())
72                }
73                Setting {
74                    name,
75                    value: Some((_eq, syn::Lit::Str(ident))),
76                } if name == "only_async_attr" => match ident.parse() {
77                    Ok(ident) => result.only_async_attr = ident,
78                    Err(e) => errors.push(e),
79                },
80                Setting {
81                    name,
82                    value: Some((_eq, syn::Lit::Str(ident))),
83                } if name == "only_blocking_attr" => match ident.parse() {
84                    Ok(ident) => result.only_blocking_attr = ident,
85                    Err(e) => errors.push(e),
86                },
87                Setting {
88                    name,
89                    value: Some((_eq, syn::Lit::Str(lifetime))),
90                } if name == "lifetime" => match lifetime.parse() {
91                    Ok(lifetime) => result.lifetime = lifetime,
92                    Err(e) => errors.push(e),
93                },
94                invalid => errors.push(syn::Error::new(invalid.span(), "Invalid argument")),
95            });
96        let errors = errors
97            .into_iter()
98            .fold(Option::<syn::Error>::None, |prev, err| {
99                Some(prev.map_or(err.clone(), move |mut old_err| {
100                    old_err.combine(err);
101                    old_err
102                }))
103            });
104        if let Some(errors) = errors {
105            Err(errors)
106        } else {
107            Ok(result)
108        }
109    }
110}
111
112impl Asyncable {
113    fn from_attributes(input: TokenStream2) -> syn::Result<Self> {
114        parse2(input)
115    }
116}
117
118struct PrintFunctionState<'a, 'b: 'a> {
119    state: &'a FunctionState<'b>,
120}
121
122impl ToTokens for PrintFunctionState<'_, '_> {
123    fn to_tokens(&self, tokens: &mut TokenStream2) {
124        let PrintFunctionState {
125            state:
126                state @ FunctionState {
127                    options:
128                        Asyncable {
129                            make_blocking,
130                            make_async,
131                            ..
132                        },
133                    function:
134                        AsyncFunction {
135                            visibility,
136                            constness,
137                            _asyncness,
138                            unsafety,
139                            fn_token,
140                            ident,
141                            generics: _, // replaced with new_generics()
142                            paren_token,
143                            inputs: _, // replaced with simple_input_variables()
144                            output: _, // replaced with new_return_type_tokens()
145                            where_clause,
146                            body,
147                            identities: _,
148                        },
149                    ..
150                },
151        } = self;
152        visibility.to_tokens(tokens);
153        constness.to_tokens(tokens);
154        unsafety.to_tokens(tokens);
155        fn_token.to_tokens(tokens);
156        ident.to_tokens(tokens);
157        state.new_generics().to_tokens(tokens);
158        paren_token.surround(tokens, |tokens| {
159            state.simple_input_variables().to_tokens(tokens)
160        });
161        state.new_return_type_tokens().to_tokens(tokens);
162        where_clause.to_tokens(tokens);
163        body.brace_token.surround(tokens, |tokens| {
164            if *make_blocking && *make_async {
165                let async_let = state.async_let_statement();
166                let blocking_let = state.blocking_let_statement();
167                let async_var = state.async_name();
168                let blocking_var = state.blocking_name();
169                let args_var = state.simple_input_variables_tuple();
170                // TODO: Add parantesis to args
171                quote_spanned!{body.span()=>
172                    #async_let;
173                    #blocking_let;
174                    ::desaturate::IntoDesaturatedWith::desaturate_with(#async_var, #args_var, #blocking_var)
175                }.to_tokens(tokens);
176            } else if *make_blocking {
177                let blocking_body = state.blocking_function_body();
178                let warning = format!("Tried to await Desaturated from {} when desaturate wasn't compiled with \"async\"", state.function.ident);
179                quote_spanned!{body.span()=>
180                    ::desaturate::IntoDesaturated::desaturate(async { unreachable!(#warning) }, move || #blocking_body)
181                }.to_tokens(tokens);
182            } else if *make_async {
183                let async_body = &state.body;
184                let warning = format!("Tried to call Desaturated from {} when desaturate wasn't compiled with \"blocking\"", state.function.ident);
185                quote_spanned!{body.span()=>
186                    ::desaturate::IntoDesaturated::desaturate(async move #async_body, || unreachable!(#warning))
187                }.to_tokens(tokens);
188            } else {
189                let async_warning = format!("Tried to await Desaturated from {} when desaturate wasn't compiled with \"async\"", state.function.ident);
190                let blocking_warning = format!("Tried to call Desaturated from {} when desaturate wasn't compiled with \"blocking\"", state.function.ident);
191                quote_spanned!{body.span()=>
192                    ::desaturate::IntoDesaturated::desaturate(async { unreachable!(#async_warning) }, || unreachable!(#blocking_warning))
193                }.to_tokens(tokens);
194            }
195        });
196    }
197}
198
199impl Asyncable {
200    fn desaturate(&self, item: TokenStream2) -> syn::Result<TokenStream2> {
201        let function: AsyncFunction = parse2(item)?;
202        let state = FunctionState::new(self, &function);
203        let result = PrintFunctionState { state: &state }.into_token_stream();
204        if self.debug_dump.is_some() {
205            eprintln!("{result}");
206        }
207        Ok(result)
208    }
209}
210
211#[proc_macro_attribute]
212pub fn desaturate(attr: TokenStream, item: TokenStream) -> TokenStream {
213    match Asyncable::from_attributes(attr.into()) {
214        Ok(handler) => handler
215            .desaturate(item.into())
216            .unwrap_or_else(syn::Error::into_compile_error)
217            .into(),
218        Err(e) => e.into_compile_error().into(),
219    }
220}