#![forbid(unsafe_code, future_incompatible, rust_2018_idioms)]
#![deny(missing_debug_implementations, nonstandard_style)]
#![recursion_limit = "512"]
mod tokio;
use proc_macro::TokenStream;
use quote::{quote, quote_spanned};
use syn::spanned::Spanned;
/// Turns an `async fn main` into a synchronous `main` driven by the async-std
/// integration of `pyo3_asyncio`.
///
/// The generated `main` calls `pyo3::prepare_freethreaded_python()`, acquires
/// the GIL with `pyo3::Python::with_gil`, and runs the original async body via
/// `pyo3_asyncio::async_std::run`. If that returns an error, the Python
/// exception is printed with `print_and_set_sys_last_vars` and the process then
/// panics through `unwrap`.
///
/// Attributes written on the original function are re-applied to the inner
/// `async fn main`.
///
/// A compile error is emitted if the function is not named `main`, is not
/// `async`, or declares parameters (the wrapper calls it with no arguments).
#[cfg(not(test))] #[proc_macro_attribute]
pub fn async_std_main(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let input = syn::parse_macro_input!(item as syn::ItemFn);

    let ret = &input.sig.output;
    let inputs = &input.sig.inputs;
    let name = &input.sig.ident;
    let body = &input.block;
    let attrs = &input.attrs;
    let vis = &input.vis;

    // `compile_error!(...)` in item position must be terminated with `;`.
    // A trailing `,` would add a spurious parse error next to the real message.
    if name != "main" {
        return TokenStream::from(quote_spanned! { name.span() =>
            compile_error!("only the main function can be tagged with #[async_std::main]");
        });
    }

    if input.sig.asyncness.is_none() {
        return TokenStream::from(quote_spanned! { input.span() =>
            compile_error!("the async keyword is missing from the function declaration");
        });
    }

    // The wrapper below invokes the inner `main()` without arguments, so any
    // declared parameters would otherwise surface as an obscure arity error.
    if !inputs.is_empty() {
        return TokenStream::from(quote_spanned! { inputs.span() =>
            compile_error!("the main function cannot accept arguments");
        });
    }

    let result = quote! {
        #vis fn main() {
            #(#attrs)*
            async fn main(#inputs) #ret {
                #body
            }

            pyo3::prepare_freethreaded_python();

            pyo3::Python::with_gil(|py| {
                pyo3_asyncio::async_std::run(py, main())
                    .map_err(|e| {
                        e.print_and_set_sys_last_vars(py);
                    })
                    .unwrap();
            });
        }
    };

    result.into()
}
#[cfg(not(test))] #[proc_macro_attribute]
pub fn tokio_main(args: TokenStream, item: TokenStream) -> TokenStream {
tokio::main(args, item, true)
}
/// Registers a function as a `pyo3_asyncio` test run through async-std.
///
/// The item is replaced by a same-named function that takes no arguments and
/// returns a pinned, boxed `Send` future resolving to `pyo3::PyResult<()>`.
/// That function is submitted through `pyo3_asyncio::inventory::submit!` as a
/// `pyo3_asyncio::testing::Test` named `<module path>::<function name>`.
///
/// * An `async fn` is called and its future boxed directly.
/// * A synchronous function is run with
///   `pyo3_asyncio::async_std::re_exports::spawn_blocking`. If it declares
///   parameters, it is called with one argument: the value of
///   `pyo3_asyncio::async_std::get_current_loop(py).unwrap().into()`, computed
///   while holding the GIL.
#[cfg(not(test))] #[proc_macro_attribute]
pub fn async_std_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let input = syn::parse_macro_input!(item as syn::ItemFn);

    let sig = &input.sig;
    let name = &input.sig.ident;
    let body = &input.block;
    let vis = &input.vis;

    // Statements ending in the expression that yields the boxed test future.
    // The inner function (same name) shadows the wrapper inside its body.
    let task = if input.sig.asyncness.is_some() {
        quote! {
            Box::pin(#name())
        }
    } else if sig.inputs.is_empty() {
        quote! {
            Box::pin(pyo3_asyncio::async_std::re_exports::spawn_blocking(move || {
                #name()
            }))
        }
    } else {
        // Fully qualify `Python` so the expansion does not depend on the caller
        // having it in scope; the rest of the expansion already uses `pyo3::` paths.
        quote! {
            let event_loop = pyo3::Python::with_gil(|py| {
                pyo3_asyncio::async_std::get_current_loop(py).unwrap().into()
            });
            Box::pin(pyo3_asyncio::async_std::re_exports::spawn_blocking(move || {
                #name(event_loop)
            }))
        }
    };

    let result = quote! {
        #vis fn #name() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
            #sig {
                #body
            }

            #task
        }

        pyo3_asyncio::inventory::submit! {
            pyo3_asyncio::testing::Test {
                name: concat!(std::module_path!(), "::", stringify!(#name)),
                test_fn: &#name
            }
        }
    };

    result.into()
}
/// Registers a function as a `pyo3_asyncio` test run through tokio.
///
/// The item is replaced by a same-named function that takes no arguments and
/// returns a pinned, boxed `Send` future resolving to `pyo3::PyResult<()>`.
/// That function is submitted through `pyo3_asyncio::inventory::submit!` as a
/// `pyo3_asyncio::testing::Test` named `<module path>::<function name>`.
///
/// * An `async fn` is called and its future boxed directly.
/// * A synchronous function is run with
///   `pyo3_asyncio::tokio::get_runtime().spawn_blocking`. If it declares
///   parameters, it is called with one argument: the value of
///   `pyo3_asyncio::tokio::get_current_loop(py).unwrap().into()`, computed
///   while holding the GIL.
///
/// If the blocking task panics, the panic payload (a `&str` or `String`,
/// otherwise "unknown error") is converted into a
/// `pyo3_asyncio::err::RustPanic` error. A join error that is not a panic
/// trips an `assert!`.
#[cfg(not(test))] #[proc_macro_attribute]
pub fn tokio_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let input = syn::parse_macro_input!(item as syn::ItemFn);

    let sig = &input.sig;
    let name = &input.sig.ident;
    let body = &input.block;
    let vis = &input.vis;

    let task = if input.sig.asyncness.is_some() {
        quote! {
            Box::pin(#name())
        }
    } else {
        // `setup` runs before the future is built. `call` is the expression
        // executed on the blocking pool. Only these differ between the
        // zero-argument and event-loop-argument forms.
        let (setup, call) = if sig.inputs.is_empty() {
            (quote! {}, quote! { #name() })
        } else {
            (
                // Fully qualify `Python` so the expansion does not depend on
                // the caller having it in scope.
                quote! {
                    let event_loop = pyo3::Python::with_gil(|py| {
                        pyo3_asyncio::tokio::get_current_loop(py).unwrap().into()
                    });
                },
                quote! { #name(event_loop) },
            )
        };

        quote! {
            #setup
            Box::pin(async move {
                match pyo3_asyncio::tokio::get_runtime().spawn_blocking(move || #call).await {
                    Ok(result) => result,
                    Err(e) => {
                        assert!(e.is_panic());
                        let panic = e.into_panic();
                        let panic_message = if let Some(s) = panic.downcast_ref::<&str>() {
                            s.to_string()
                        } else if let Some(s) = panic.downcast_ref::<String>() {
                            s.clone()
                        } else {
                            "unknown error".into()
                        };
                        Err(pyo3_asyncio::err::RustPanic::new_err(format!("rust future panicked: {}", panic_message)))
                    }
                }
            })
        }
    };

    let result = quote! {
        #vis fn #name() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
            #sig {
                #body
            }

            #task
        }

        pyo3_asyncio::inventory::submit! {
            pyo3_asyncio::testing::Test {
                name: concat!(std::module_path!(), "::", stringify!(#name)),
                test_fn: &#name
            }
        }
    };

    result.into()
}