pyo3_async_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote, ToTokens};
3use syn::{
4    parse::Parser, parse_macro_input, parse_quote, parse_quote_spanned, punctuated::Punctuated,
5    spanned::Spanned,
6};
7
/// Names of the Python async backends accepted as bare macro options.
///
/// Each name is spliced into `::pyo3_async::<name>::Coroutine` by the generated wrappers;
/// `asyncio` is used when no backend option is given.
const MODULES: [&str; 3] = [
    "asyncio", //
    "trio",
    "sniffio",
];
9
/// Evaluates to the `Ok` value of a `Result`, or makes the enclosing function return early.
///
/// On `Err`, the error's `into_compile_error()` output (a `compile_error!` invocation for
/// `syn::Error`) is converted with `.into()` and returned. Proc-macro entry points therefore
/// report errors as compiler diagnostics.
macro_rules! unwrap {
    ($result:expr) => {
        match $result {
            Err(error) => return error.into_compile_error().into(),
            Ok(value) => value,
        }
    };
}
18
19struct Options {
20    module: syn::Path,
21    allow_threads: bool,
22}
23
24fn parse_options(attr: TokenStream) -> syn::Result<Options> {
25    let mut allow_threads = false;
26    let mut module = None;
27    let module_parser = syn::meta::parser(|meta| {
28        if meta.path.is_ident("allow_threads") {
29            allow_threads = true;
30        } else if MODULES.iter().any(|m| meta.path.is_ident(m)) {
31            if module.is_some() {
32                return Err(meta.error("multiple Python async backend specified"));
33            }
34            module = Some(meta.path);
35        } else {
36            return Err(meta.error("invalid option"));
37        }
38        Ok(())
39    });
40    module_parser.parse(attr)?;
41    Ok(Options {
42        module: module.unwrap_or_else(|| parse_quote!(asyncio)),
43        allow_threads,
44    })
45}
46
47fn build_coroutine(
48    path: impl ToTokens,
49    attrs: &mut Vec<syn::Attribute>,
50    sig: &mut syn::Signature,
51    block: &mut syn::Block,
52    options: &Options,
53) -> syn::Result<()> {
54    attrs.retain(|attr| attr.meta.path().is_ident("pyo3"));
55    let mut has_name = false;
56    for attr in attrs.iter() {
57        has_name |= attr
58            .parse_args_with(Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated)?
59            .into_iter()
60            .any(|meta| matches!(meta, syn::Meta::NameValue(nv) if nv.path.is_ident("name")));
61    }
62    if !has_name {
63        let name = format!("{}", &sig.ident);
64        attrs.push(parse_quote!(#[pyo3(name = #name)]));
65    }
66    let ident = sig.ident.clone();
67    sig.ident = format_ident!("async_{ident}");
68    sig.asyncness = None;
69    let module = &options.module;
70    let coro_path = quote!(::pyo3_async::#module::Coroutine);
71    let params = sig.inputs.iter().map(|arg| match arg {
72        syn::FnArg::Receiver(_) => quote!(self),
73        syn::FnArg::Typed(syn::PatType { pat, .. }) => quote!(#pat),
74    });
75    let mut future = quote!(#path(#(#params),*));
76    if matches!(sig.output, syn::ReturnType::Default) {
77        future = quote!(async move {#future.await; pyo3::PyResult::Ok(())})
78    }
79    if options.allow_threads {
80        future = quote!(::pyo3_async::AllowThreads(#future));
81    }
82    // return statement because `parse_quote_spanned` doesn't work otherwise
83    block.stmts = vec![parse_quote_spanned! { block.span() =>
84        #[allow(clippy::needless_return)]
85        return #coro_path::from_future(#future);
86    }];
87    sig.output = parse_quote_spanned!(sig.output.span() => -> #coro_path);
88    Ok(())
89}
90
91/// [`pyo3::pyfunction`] with async support.
92///
93/// Generate a additional function prefixed by `async_`, decorated by [`pyo3::pyfunction`] and
94/// `#[pyo3(name = ...)]`.
95///
96/// Python async backend can be specified using macro argument (default to `asyncio`).
97/// If `allow_threads` is passed in arguments, GIL will be released for future polling (see
98/// [`AllowThreads`])
99///
100/// # Example
101///
102/// ```rust
103/// #[pyo3_async::pyfunction(allow_threads)]
104/// pub async fn print(s: String) {
105///     println!("{s}");
106/// }
107/// ```
108/// generates
109/// ```rust
110/// pub async fn print(s: String) {
111///     println!("{s}");
112/// }
113/// #[::pyo3::pyfunction]
114/// #[pyo3(name = "print")]
115/// pub fn async_print(s: String) -> ::pyo3_async::asyncio::Coroutine {
116///     ::pyo3_async::asyncio::Coroutine::from_future(::pyo3_async::AllowThreads(
117///         async move { print(s).await; Ok(()) }
118///     ))
119/// }
120/// ```
121///
122/// [`pyo3::pyfunction`]: https://docs.rs/pyo3/latest/pyo3/attr.pyfunction.html
123/// [`AllowThreads`]: https://docs.rs/pyo3-async/latest/pyo3_async/struct.AllowThreads.html
124#[proc_macro_attribute]
125pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream {
126    let options = unwrap!(parse_options(attr));
127    let mut func = parse_macro_input!(input as syn::ItemFn);
128    if func.sig.asyncness.is_none() {
129        return quote!(#[::pyo3::pyfunction] #func).into();
130    }
131    let mut coro = func.clone();
132    unwrap!(build_coroutine(
133        &func.sig.ident,
134        &mut coro.attrs,
135        &mut coro.sig,
136        &mut coro.block,
137        &options
138    ));
139    func.attrs.retain(|attr| !attr.meta.path().is_ident("pyo3"));
140    let expanded = quote! {
141        #func
142        #[::pyo3::pyfunction]
143        #coro
144    };
145    expanded.into()
146}
147
148/// [`pyo3::pymethods`] with async support.
149///
150/// For each async methods, generate a additional function prefixed by `async_`, decorated with
151/// `#[pyo3(name = ...)]`. Original async methods are kept in a separate impl, while the original
152/// impl is decorated with [`pyo3::pymethods`].
153///
154/// Python async backend can be specified using macro argument (default to `asyncio`).
155/// If `allow_threads` is passed in arguments, GIL will be released for future polling (see
156/// [`AllowThreads`])
157///
158/// # Example
159///
160/// ```rust
161/// #[pyo3::pyclass]
162/// struct Counter(usize);
163///
164/// #[pyo3_async::pymethods(trio)]
165/// impl Counter {
166///     fn incr_sync(&mut self) -> usize {
167///         self.0 += 1;
168///         self.0
169///     }
170///
171///     // Arguments needs to implement `Send + 'static`, so `self` must be passed using `Py<Self>`
172///     async fn incr_async(self_: pyo3::Py<Self>) -> pyo3::PyResult<usize> {
173///         pyo3::Python::with_gil(|gil| {
174///             let mut this = self_.borrow_mut(gil);
175///             this.0 += 1;
176///             Ok(this.0)
177///         })
178///     }
179/// }
180/// ```
181/// generates
182/// ```rust
183/// #[pyo3::pyclass]
184/// struct Counter(usize);
185///
186/// #[::pyo3::pymethods]
187/// impl Counter {
188///     fn incr_sync(&mut self) -> usize {
189///         self.0 += 1;
190///         self.0
191///     }
192///
193///     #[pyo3(name = "incr_async")]
194///     fn async_incr_async(self_: pyo3::Py<Self>) -> ::pyo3_async::trio::Coroutine {
195///         ::pyo3_async::trio::Coroutine::from_future(Counter::incr_async(self_))
196///     }
197/// }
198/// impl Counter {
199///     async fn incr_async(self_: pyo3::Py<Self>) -> pyo3::PyResult<usize> {
200///         pyo3::Python::with_gil(|gil| {
201///             let mut this = self_.borrow_mut(gil);
202///             this.0 += 1;
203///             Ok(this.0)
204///         })
205///     }
206/// }
207/// ```
208///
209/// [`pyo3::pymethods`]: https://docs.rs/pyo3/latest/pyo3/attr.pymethods.html
210/// [`AllowThreads`]: https://docs.rs/pyo3-async/latest/pyo3_async/struct.AllowThreads.html
211#[proc_macro_attribute]
212pub fn pymethods(attr: TokenStream, input: TokenStream) -> TokenStream {
213    let options = unwrap!(parse_options(attr));
214    let mut r#impl = parse_macro_input!(input as syn::ItemImpl);
215    let (async_methods, items) = r#impl.items.into_iter().partition::<Vec<_>, _>(
216        |item| matches!(item, syn::ImplItem::Fn(func) if func.sig.asyncness.is_some()),
217    );
218    r#impl.items = items;
219    if async_methods.is_empty() {
220        return quote!(#[::pyo3::pymethods] #r#impl).into();
221    }
222    let mut async_impl = r#impl.clone();
223    async_impl.items = async_methods;
224    async_impl.attrs.clear();
225    for item in &mut async_impl.items {
226        let syn::ImplItem::Fn(method) = item else {
227            unreachable!()
228        };
229        let mut coro = method.clone();
230        let self_ty = &r#impl.self_ty;
231        let method_name = &method.sig.ident;
232        unwrap!(build_coroutine(
233            quote!(#self_ty::#method_name),
234            &mut coro.attrs,
235            &mut coro.sig,
236            &mut coro.block,
237            &options
238        ));
239        method
240            .attrs
241            .retain(|attr| !attr.meta.path().is_ident("pyo3"));
242        method.attrs.retain(|attr| {
243            if ["getter", "classmethod", "staticmethod"]
244                .iter()
245                .any(|m| attr.meta.path().is_ident(m))
246            {
247                coro.attrs.push(attr.clone());
248                return false;
249            }
250            true
251        });
252        r#impl.items.push(syn::ImplItem::Fn(coro));
253    }
254    let expanded = quote! {
255        #[::pyo3::pymethods]
256        #r#impl
257        #async_impl
258    };
259    expanded.into()
260}