1use proc_macro::TokenStream;
2use quote::{format_ident, quote, ToTokens};
3use syn::{
4 parse::Parser, parse_macro_input, parse_quote, parse_quote_spanned, punctuated::Punctuated,
5 spanned::Spanned,
6};
7
/// Identifiers accepted as the backend option of the attribute macros.
///
/// The selected identifier is spliced into the generated path as
/// `::pyo3_async::<module>::Coroutine`.
const MODULES: [&str; 3] = [
    "asyncio",
    "trio",
    "sniffio",
];
9
/// Evaluates to the `Ok` payload of `$result`.
///
/// On `Err`, the error is turned into its compile-error tokens with
/// `into_compile_error()` and returned from the *enclosing* function, after
/// conversion with `.into()` to that function's return type.
macro_rules! unwrap {
    ($result:expr) => {
        match $result {
            Err(error) => return error.into_compile_error().into(),
            Ok(value) => value,
        }
    };
}
18
19struct Options {
20 module: syn::Path,
21 allow_threads: bool,
22}
23
24fn parse_options(attr: TokenStream) -> syn::Result<Options> {
25 let mut allow_threads = false;
26 let mut module = None;
27 let module_parser = syn::meta::parser(|meta| {
28 if meta.path.is_ident("allow_threads") {
29 allow_threads = true;
30 } else if MODULES.iter().any(|m| meta.path.is_ident(m)) {
31 if module.is_some() {
32 return Err(meta.error("multiple Python async backend specified"));
33 }
34 module = Some(meta.path);
35 } else {
36 return Err(meta.error("invalid option"));
37 }
38 Ok(())
39 });
40 module_parser.parse(attr)?;
41 Ok(Options {
42 module: module.unwrap_or_else(|| parse_quote!(asyncio)),
43 allow_threads,
44 })
45}
46
47fn build_coroutine(
48 path: impl ToTokens,
49 attrs: &mut Vec<syn::Attribute>,
50 sig: &mut syn::Signature,
51 block: &mut syn::Block,
52 options: &Options,
53) -> syn::Result<()> {
54 attrs.retain(|attr| attr.meta.path().is_ident("pyo3"));
55 let mut has_name = false;
56 for attr in attrs.iter() {
57 has_name |= attr
58 .parse_args_with(Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated)?
59 .into_iter()
60 .any(|meta| matches!(meta, syn::Meta::NameValue(nv) if nv.path.is_ident("name")));
61 }
62 if !has_name {
63 let name = format!("{}", &sig.ident);
64 attrs.push(parse_quote!(#[pyo3(name = #name)]));
65 }
66 let ident = sig.ident.clone();
67 sig.ident = format_ident!("async_{ident}");
68 sig.asyncness = None;
69 let module = &options.module;
70 let coro_path = quote!(::pyo3_async::#module::Coroutine);
71 let params = sig.inputs.iter().map(|arg| match arg {
72 syn::FnArg::Receiver(_) => quote!(self),
73 syn::FnArg::Typed(syn::PatType { pat, .. }) => quote!(#pat),
74 });
75 let mut future = quote!(#path(#(#params),*));
76 if matches!(sig.output, syn::ReturnType::Default) {
77 future = quote!(async move {#future.await; pyo3::PyResult::Ok(())})
78 }
79 if options.allow_threads {
80 future = quote!(::pyo3_async::AllowThreads(#future));
81 }
82 block.stmts = vec![parse_quote_spanned! { block.span() =>
84 #[allow(clippy::needless_return)]
85 return #coro_path::from_future(#future);
86 }];
87 sig.output = parse_quote_spanned!(sig.output.span() => -> #coro_path);
88 Ok(())
89}
90
91#[proc_macro_attribute]
125pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream {
126 let options = unwrap!(parse_options(attr));
127 let mut func = parse_macro_input!(input as syn::ItemFn);
128 if func.sig.asyncness.is_none() {
129 return quote!(#[::pyo3::pyfunction] #func).into();
130 }
131 let mut coro = func.clone();
132 unwrap!(build_coroutine(
133 &func.sig.ident,
134 &mut coro.attrs,
135 &mut coro.sig,
136 &mut coro.block,
137 &options
138 ));
139 func.attrs.retain(|attr| !attr.meta.path().is_ident("pyo3"));
140 let expanded = quote! {
141 #func
142 #[::pyo3::pyfunction]
143 #coro
144 };
145 expanded.into()
146}
147
148#[proc_macro_attribute]
212pub fn pymethods(attr: TokenStream, input: TokenStream) -> TokenStream {
213 let options = unwrap!(parse_options(attr));
214 let mut r#impl = parse_macro_input!(input as syn::ItemImpl);
215 let (async_methods, items) = r#impl.items.into_iter().partition::<Vec<_>, _>(
216 |item| matches!(item, syn::ImplItem::Fn(func) if func.sig.asyncness.is_some()),
217 );
218 r#impl.items = items;
219 if async_methods.is_empty() {
220 return quote!(#[::pyo3::pymethods] #r#impl).into();
221 }
222 let mut async_impl = r#impl.clone();
223 async_impl.items = async_methods;
224 async_impl.attrs.clear();
225 for item in &mut async_impl.items {
226 let syn::ImplItem::Fn(method) = item else {
227 unreachable!()
228 };
229 let mut coro = method.clone();
230 let self_ty = &r#impl.self_ty;
231 let method_name = &method.sig.ident;
232 unwrap!(build_coroutine(
233 quote!(#self_ty::#method_name),
234 &mut coro.attrs,
235 &mut coro.sig,
236 &mut coro.block,
237 &options
238 ));
239 method
240 .attrs
241 .retain(|attr| !attr.meta.path().is_ident("pyo3"));
242 method.attrs.retain(|attr| {
243 if ["getter", "classmethod", "staticmethod"]
244 .iter()
245 .any(|m| attr.meta.path().is_ident(m))
246 {
247 coro.attrs.push(attr.clone());
248 return false;
249 }
250 true
251 });
252 r#impl.items.push(syn::ImplItem::Fn(coro));
253 }
254 let expanded = quote! {
255 #[::pyo3::pymethods]
256 #r#impl
257 #async_impl
258 };
259 expanded.into()
260}