1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::{quote_spanned, ToTokens};
4use syn::{
5 parse::{Parse, ParseStream},
6 parse2,
7 punctuated::Punctuated,
8 spanned::Spanned,
9 Ident, Lifetime, Token,
10};
11
// Refuse to build when the `dont-directly-import-this-crate` feature is enabled, except for
// rustdoc (`doc`) and unit-test (`test`) builds. As the message states, generated functions are
// only sound when this crate's feature flags match those of the `desaturate` crate.
// NOTE(review): presumably `desaturate` depends on this crate with this feature turned off
// (e.g. via `default-features = false`) — confirm against the Cargo manifests.
#[cfg(all(feature = "dont-directly-import-this-crate", not(doc), not(test)))]
compile_error! {"Directly importing the `desaturate-macros` crate may make generated functions unsound, as they require that the feature flags of this crate match with the `desaturate` crate."}
14
/// Returns `T`'s default value; a terse stand-in for `Default::default()` when `T` is inferred.
pub(crate) fn default<T: Default>() -> T {
    <T as Default>::default()
}
18
19mod input_function;
20mod transformer;
21mod visitors;
22use crate::{input_function::*, transformer::*};
23
24#[derive(Default)]
25struct Asyncable {
26 debug_dump: Option<Span>,
27 lifetime: Option<Lifetime>,
28 only_async_attr: Option<Ident>,
29 only_blocking_attr: Option<Ident>,
30 make_blocking: bool,
31 make_async: bool,
32}
33
34impl Parse for Asyncable {
35 fn parse(input: ParseStream) -> syn::Result<Self> {
36 struct Setting {
37 name: Ident,
38 value: Option<(Token![=], syn::Lit)>,
39 }
40 impl Setting {
41 fn span(&self) -> Span {
42 let span = self.name.span();
43 if let Some((token, lit)) = &self.value {
44 span.join(token.span()).unwrap().join(lit.span()).unwrap()
45 } else {
46 span
47 }
48 }
49 }
50 impl Parse for Setting {
51 fn parse(input: ParseStream) -> syn::Result<Self> {
52 let name = input.parse()?;
53 let value = if input.peek(Token![=]) {
54 Some((input.parse()?, input.parse()?))
55 } else {
56 None
57 };
58 Ok(Self { name, value })
59 }
60 }
61 let mut result = Asyncable {
62 make_blocking: cfg!(feature = "generate-blocking"),
63 make_async: cfg!(feature = "generate-async"),
64 ..Asyncable::default()
65 };
66 let mut errors: Vec<syn::Error> = vec![];
67 Punctuated::<Setting, Token![,]>::parse_terminated(input)?
68 .iter()
69 .for_each(|setting| match setting {
70 setting @ Setting { name, value: None } if name == "debug_dump" => {
71 result.debug_dump = Some(setting.span())
72 }
73 Setting {
74 name,
75 value: Some((_eq, syn::Lit::Str(ident))),
76 } if name == "only_async_attr" => match ident.parse() {
77 Ok(ident) => result.only_async_attr = ident,
78 Err(e) => errors.push(e),
79 },
80 Setting {
81 name,
82 value: Some((_eq, syn::Lit::Str(ident))),
83 } if name == "only_blocking_attr" => match ident.parse() {
84 Ok(ident) => result.only_blocking_attr = ident,
85 Err(e) => errors.push(e),
86 },
87 Setting {
88 name,
89 value: Some((_eq, syn::Lit::Str(lifetime))),
90 } if name == "lifetime" => match lifetime.parse() {
91 Ok(lifetime) => result.lifetime = lifetime,
92 Err(e) => errors.push(e),
93 },
94 invalid => errors.push(syn::Error::new(invalid.span(), "Invalid argument")),
95 });
96 let errors = errors
97 .into_iter()
98 .fold(Option::<syn::Error>::None, |prev, err| {
99 Some(prev.map_or(err.clone(), move |mut old_err| {
100 old_err.combine(err);
101 old_err
102 }))
103 });
104 if let Some(errors) = errors {
105 Err(errors)
106 } else {
107 Ok(result)
108 }
109 }
110}
111
112impl Asyncable {
113 fn from_attributes(input: TokenStream2) -> syn::Result<Self> {
114 parse2(input)
115 }
116}
117
118struct PrintFunctionState<'a, 'b: 'a> {
119 state: &'a FunctionState<'b>,
120}
121
122impl ToTokens for PrintFunctionState<'_, '_> {
123 fn to_tokens(&self, tokens: &mut TokenStream2) {
124 let PrintFunctionState {
125 state:
126 state @ FunctionState {
127 options:
128 Asyncable {
129 make_blocking,
130 make_async,
131 ..
132 },
133 function:
134 AsyncFunction {
135 visibility,
136 constness,
137 _asyncness,
138 unsafety,
139 fn_token,
140 ident,
141 generics: _, paren_token,
143 inputs: _, output: _, where_clause,
146 body,
147 identities: _,
148 },
149 ..
150 },
151 } = self;
152 visibility.to_tokens(tokens);
153 constness.to_tokens(tokens);
154 unsafety.to_tokens(tokens);
155 fn_token.to_tokens(tokens);
156 ident.to_tokens(tokens);
157 state.new_generics().to_tokens(tokens);
158 paren_token.surround(tokens, |tokens| {
159 state.simple_input_variables().to_tokens(tokens)
160 });
161 state.new_return_type_tokens().to_tokens(tokens);
162 where_clause.to_tokens(tokens);
163 body.brace_token.surround(tokens, |tokens| {
164 if *make_blocking && *make_async {
165 let async_let = state.async_let_statement();
166 let blocking_let = state.blocking_let_statement();
167 let async_var = state.async_name();
168 let blocking_var = state.blocking_name();
169 let args_var = state.simple_input_variables_tuple();
170 quote_spanned!{body.span()=>
172 #async_let;
173 #blocking_let;
174 ::desaturate::IntoDesaturatedWith::desaturate_with(#async_var, #args_var, #blocking_var)
175 }.to_tokens(tokens);
176 } else if *make_blocking {
177 let blocking_body = state.blocking_function_body();
178 let warning = format!("Tried to await Desaturated from {} when desaturate wasn't compiled with \"async\"", state.function.ident);
179 quote_spanned!{body.span()=>
180 ::desaturate::IntoDesaturated::desaturate(async { unreachable!(#warning) }, move || #blocking_body)
181 }.to_tokens(tokens);
182 } else if *make_async {
183 let async_body = &state.body;
184 let warning = format!("Tried to call Desaturated from {} when desaturate wasn't compiled with \"blocking\"", state.function.ident);
185 quote_spanned!{body.span()=>
186 ::desaturate::IntoDesaturated::desaturate(async move #async_body, || unreachable!(#warning))
187 }.to_tokens(tokens);
188 } else {
189 let async_warning = format!("Tried to await Desaturated from {} when desaturate wasn't compiled with \"async\"", state.function.ident);
190 let blocking_warning = format!("Tried to call Desaturated from {} when desaturate wasn't compiled with \"blocking\"", state.function.ident);
191 quote_spanned!{body.span()=>
192 ::desaturate::IntoDesaturated::desaturate(async { unreachable!(#async_warning) }, || unreachable!(#blocking_warning))
193 }.to_tokens(tokens);
194 }
195 });
196 }
197}
198
199impl Asyncable {
200 fn desaturate(&self, item: TokenStream2) -> syn::Result<TokenStream2> {
201 let function: AsyncFunction = parse2(item)?;
202 let state = FunctionState::new(self, &function);
203 let result = PrintFunctionState { state: &state }.into_token_stream();
204 if self.debug_dump.is_some() {
205 eprintln!("{result}");
206 }
207 Ok(result)
208 }
209}
210
211#[proc_macro_attribute]
212pub fn desaturate(attr: TokenStream, item: TokenStream) -> TokenStream {
213 match Asyncable::from_attributes(attr.into()) {
214 Ok(handler) => handler
215 .desaturate(item.into())
216 .unwrap_or_else(syn::Error::into_compile_error)
217 .into(),
218 Err(e) => e.into_compile_error().into(),
219 }
220}