use crate::util::token_stream_and_error;
use darling::{ast::NestedMeta, FromMeta};
use proc_macro2::TokenStream;
use quote::quote;
use syn::ItemFn;
/// Which kind of entrypoint the attribute expansion generates.
///
/// The only difference in the expanded code is whether a `#[test]` attribute is emitted.
/// The type is a plain fieldless tag, so it is `Copy`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EntrypointType {
    /// `#[folo::main]`: the expanded function gets no extra attribute.
    Main,
    /// `#[folo::test]`: the expanded function is additionally marked `#[test]`.
    Test,
}
#[derive(Debug, FromMeta)]
struct EntrypointOptions {
global_init_fn: Option<syn::Ident>,
worker_init_fn: Option<syn::Ident>,
}
impl EntrypointOptions {
fn parse(attr: TokenStream) -> syn::Result<Self> {
let attr_args = NestedMeta::parse_meta_list(attr)?;
Ok(EntrypointOptions::from_list(&attr_args)?)
}
}
pub fn entrypoint(
attr: TokenStream,
input: TokenStream,
entrypoint_type: EntrypointType,
) -> TokenStream {
let item_ast = syn::parse2::<ItemFn>(input.clone());
let result = match item_ast {
Ok(item) => core(item, entrypoint_type, attr),
Err(e) => Err(e),
};
match result {
Ok(r) => r,
Err(e) => token_stream_and_error(input, e),
}
}
/// Rewrites an `async fn` into a synchronous function that runs the original body on a
/// freshly built `::folo::runtime` executor.
///
/// The generated function:
/// 1. calls the optional `global_init_fn`;
/// 2. builds an executor, wiring the optional `worker_init_fn` into `ExecutorBuilder::worker_init`;
/// 3. via `spawn_on_any`, schedules a task that `::folo::runtime::spawn`s the original body,
///    awaits it, and then calls `stop()` on the executor;
/// 4. calls `wait()` on the executor. It presumably returns once `stop()` has been called;
///    this is an assumption about folo's executor API.
///
/// If the function declares a return type, the body's value is passed back through an
/// `Arc<Mutex<Option<T>>>` slot and returned after `wait()`.
///
/// # Errors
/// Returns an error if `attr` cannot be parsed into [`EntrypointOptions`], or if the function
/// is not `async`.
fn core(
    mut item: ItemFn,
    entrypoint_type: EntrypointType,
    attr: TokenStream,
) -> Result<TokenStream, syn::Error> {
    let options = EntrypointOptions::parse(attr)?;

    if item.sig.asyncness.is_none() {
        return Err(syn::Error::new_spanned(
            item.sig.fn_token,
            "function must be async to use the #[folo::main] or #[folo::test] attribute",
        ));
    }
    // The emitted function is synchronous; the async body is driven by the executor instead.
    item.sig.asyncness = None;

    let attrs = &item.attrs;
    let vis = &item.vis;
    let sig = &item.sig;
    let body = &item.block;

    let test_attr = match entrypoint_type {
        EntrypointType::Main => quote! {},
        EntrypointType::Test => quote! { #[test] },
    };

    // `Option<TokenStream>` interpolates as nothing when `None`.
    let global_init = options.global_init_fn.map(|ident| quote! { #ident(); });
    let worker_init = options
        .worker_init_fn
        .map(|ident| quote! { .worker_init(move || { #ident(); }) });

    // Three fragments differ depending on whether the function returns a value:
    // - `result_slot` declares the shared slot the body writes into;
    // - `run_body` runs the body inside the spawned task and stores the result, if any;
    // - `take_result` extracts the stored result after the executor has finished.
    let (result_slot, run_body, take_result) = match &sig.output {
        syn::ReturnType::Default => (
            quote! {},
            quote! { (async move #body).await; },
            quote! {},
        ),
        syn::ReturnType::Type(_, ty) => (
            quote! {
                let __entrypoint_result_rx = ::std::sync::Arc::new(::std::sync::Mutex::new(Option::<#ty>::None));
                let __entrypoint_result_tx = ::std::sync::Arc::clone(&__entrypoint_result_rx);
            },
            quote! {
                let __entrypoint_result = (async move #body).await;
                *__entrypoint_result_tx
                    .lock()
                    .expect("poisoned lock") = Some(__entrypoint_result);
            },
            quote! {
                let __entrypoint_result = __entrypoint_result_rx
                    .lock()
                    .expect("poisoned lock")
                    .take()
                    .expect("entrypoint terminated before returning result");
                __entrypoint_result
            },
        ),
    };

    Ok(quote! {
        #(#attrs)*
        #test_attr
        #vis #sig {
            #global_init
            let __entrypoint_executor = ::folo::runtime::ExecutorBuilder::new()
                #worker_init
                .build()
                .unwrap();
            let __entrypoint_executor_clone = ::std::sync::Arc::clone(&__entrypoint_executor);
            #result_slot
            __entrypoint_executor.spawn_on_any(async move {
                let __entrypoint_remote_join: ::folo::runtime::RemoteJoinHandle<_> = ::folo::runtime::spawn(async move {
                    #run_body
                    __entrypoint_executor_clone.stop();
                })
                .into();
                __entrypoint_remote_join.await;
            });
            __entrypoint_executor.wait();
            #take_result
        }
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::contains_compile_error;
    use syn::parse_quote;

    /// Runs the macro expansion and renders it as a string. The expected `quote!` outputs below
    /// are compared against this string.
    fn expand(attr: TokenStream, input: TokenStream, entrypoint_type: EntrypointType) -> String {
        entrypoint(attr, input, entrypoint_type).to_string()
    }

    #[test]
    fn main_returns_default() {
        let input = parse_quote! {
            async fn main() {
                println!("Hello, world!");
                yield_now().await;
            }
        };
        let expected = quote! {
            fn main() {
                let __entrypoint_executor = ::folo::runtime::ExecutorBuilder::new()
                    .build()
                    .unwrap();
                let __entrypoint_executor_clone = ::std::sync::Arc::clone(&__entrypoint_executor);
                __entrypoint_executor.spawn_on_any(async move {
                    let __entrypoint_remote_join: ::folo::runtime::RemoteJoinHandle<_> = ::folo::runtime::spawn(async move {
                        (async move {
                            println!("Hello, world!");
                            yield_now().await;
                        }).await;
                        __entrypoint_executor_clone.stop();
                    })
                    .into();
                    __entrypoint_remote_join.await;
                });
                __entrypoint_executor.wait();
            }
        };
        assert_eq!(
            expand(TokenStream::new(), input, EntrypointType::Main),
            expected.to_string()
        );
    }

    #[test]
    fn main_returns_result() {
        let input = parse_quote! {
            async fn main() -> Result<(), Box<dyn std::error::Error + Send + 'static> > {
                println!("Hello, world!");
                yield_now().await;
                Ok(())
            }
        };
        let expected = quote! {
            fn main() -> Result<(), Box<dyn std::error::Error + Send + 'static> > {
                let __entrypoint_executor = ::folo::runtime::ExecutorBuilder::new()
                    .build()
                    .unwrap();
                let __entrypoint_executor_clone = ::std::sync::Arc::clone(&__entrypoint_executor);
                let __entrypoint_result_rx = ::std::sync::Arc::new(::std::sync::Mutex::new(Option::<Result<(), Box<dyn std::error::Error + Send + 'static> > >::None));
                let __entrypoint_result_tx = ::std::sync::Arc::clone(&__entrypoint_result_rx);
                __entrypoint_executor.spawn_on_any(async move {
                    let __entrypoint_remote_join: ::folo::runtime::RemoteJoinHandle<_> = ::folo::runtime::spawn(async move {
                        let __entrypoint_result = (async move {
                            println!("Hello, world!");
                            yield_now().await;
                            Ok(())
                        }).await;
                        *__entrypoint_result_tx
                            .lock()
                            .expect("poisoned lock") = Some(__entrypoint_result);
                        __entrypoint_executor_clone.stop();
                    })
                    .into();
                    __entrypoint_remote_join.await;
                });
                __entrypoint_executor.wait();
                let __entrypoint_result = __entrypoint_result_rx
                    .lock()
                    .expect("poisoned lock")
                    .take()
                    .expect("entrypoint terminated before returning result");
                __entrypoint_result
            }
        };
        assert_eq!(
            expand(TokenStream::new(), input, EntrypointType::Main),
            expected.to_string()
        );
    }

    #[test]
    fn main_not_async_is_error() {
        let input = parse_quote! {
            fn main() {
                println!("Hello, world!");
            }
        };
        assert!(contains_compile_error(&entrypoint(
            TokenStream::new(),
            input,
            EntrypointType::Main,
        )));
    }

    #[test]
    fn unknown_option_is_error() {
        let attr = parse_quote! {
            bogus_fn = setup_global,
        };
        let input = parse_quote! {
            async fn main() {}
        };
        assert!(contains_compile_error(&entrypoint(
            attr,
            input,
            EntrypointType::Main,
        )));
    }

    #[test]
    fn main_with_init_functions() {
        let attr = parse_quote! {
            global_init_fn = setup_global,
            worker_init_fn = setup_worker,
        };
        let input = parse_quote! {
            async fn main() {
                println!("Hello, world!");
                yield_now().await;
            }
        };
        let expected = quote! {
            fn main() {
                setup_global();
                let __entrypoint_executor = ::folo::runtime::ExecutorBuilder::new()
                    .worker_init(move || { setup_worker(); } )
                    .build()
                    .unwrap();
                let __entrypoint_executor_clone = ::std::sync::Arc::clone(&__entrypoint_executor);
                __entrypoint_executor.spawn_on_any(async move {
                    let __entrypoint_remote_join: ::folo::runtime::RemoteJoinHandle<_> = ::folo::runtime::spawn(async move {
                        (async move {
                            println!("Hello, world!");
                            yield_now().await;
                        }).await;
                        __entrypoint_executor_clone.stop();
                    })
                    .into();
                    __entrypoint_remote_join.await;
                });
                __entrypoint_executor.wait();
            }
        };
        assert_eq!(
            expand(attr, input, EntrypointType::Main),
            expected.to_string()
        );
    }

    #[test]
    fn test_returns_default() {
        let input = parse_quote! {
            async fn my_test() {
                yield_now().await;
                assert_eq!(2 + 2, 4);
            }
        };
        let expected = quote! {
            #[test]
            fn my_test() {
                let __entrypoint_executor = ::folo::runtime::ExecutorBuilder::new()
                    .build()
                    .unwrap();
                let __entrypoint_executor_clone = ::std::sync::Arc::clone(&__entrypoint_executor);
                __entrypoint_executor.spawn_on_any(async move {
                    let __entrypoint_remote_join: ::folo::runtime::RemoteJoinHandle<_> = ::folo::runtime::spawn(async move {
                        (async move {
                            yield_now().await;
                            assert_eq!(2 + 2, 4);
                        }).await;
                        __entrypoint_executor_clone.stop();
                    })
                    .into();
                    __entrypoint_remote_join.await;
                });
                __entrypoint_executor.wait();
            }
        };
        assert_eq!(
            expand(TokenStream::new(), input, EntrypointType::Test),
            expected.to_string()
        );
    }
}