use std::{
future::Future,
marker::PhantomData,
sync::{LazyLock, OnceLock, RwLock},
};
use tokio::runtime::Runtime;
use crate::{
bindgen_runtime::ToNapiValue, sys, Env, Error, JsDeferred, JsUnknown, NapiValue, Result,
};
fn create_runtime() -> Option<Runtime> {
#[cfg(not(target_family = "wasm"))]
{
let runtime = tokio::runtime::Runtime::new().expect("Create tokio runtime failed");
Some(runtime)
}
#[cfg(target_family = "wasm")]
{
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.ok()
}
}
/// Global runtime slot, initialised lazily on first access.
///
/// If a runtime was registered in `USER_DEFINED_RT` before the first access,
/// that value is moved into the slot; otherwise the slot is filled with the
/// result of `create_runtime()`. The slot holds an `Option` so the runtime can
/// be absent (the accessor functions in this file panic in that case).
pub(crate) static RT: LazyLock<RwLock<Option<Runtime>>> = LazyLock::new(|| {
    // NOTE(review): this takes a mutable reference to a `static mut`; it is
    // only sound if nothing else touches `USER_DEFINED_RT` concurrently —
    // confirm against callers.
    let initial = match unsafe { USER_DEFINED_RT.take() } {
        Some(user_defined_rt) => user_defined_rt,
        None => create_runtime(),
    };
    RwLock::new(initial)
});
/// Holding cell for a runtime supplied through `create_custom_tokio_runtime`.
///
/// The initializer of `RT` takes the value out of this cell the first time
/// `RT` is accessed.
// NOTE(review): declared `static mut`, so every access is `unsafe` and relies
// on callers not racing each other — confirm that assumption holds.
static mut USER_DEFINED_RT: OnceLock<Option<Runtime>> = OnceLock::new();
/// Registers `rt` as the runtime to be used by the global `RT` slot.
///
/// The runtime is stored with `OnceLock::get_or_init`, so if a value is
/// already stored the new `rt` is dropped unused. The stored value is only
/// picked up when `RT` is initialised, so this must be called before the
/// global runtime is first accessed to have any effect.
pub fn create_custom_tokio_runtime(rt: Runtime) {
    // NOTE(review): touches a `static mut`; presumably expected to be called
    // once during start-up before any other thread uses the runtime — confirm.
    let _ = unsafe { USER_DEFINED_RT.get_or_init(move || Some(rt)) };
}
/// Number of `ensure_runtime` calls that have not yet been balanced by a
/// `drop_runtime` call. When it drops back to zero the global runtime is
/// removed from `RT`.
#[cfg(not(any(target_os = "macos", target_family = "wasm")))]
static RT_REFERENCE_COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
/// Makes sure the global runtime exists and records one more user of it.
///
/// While holding the write lock on `RT`, the slot is refilled with
/// `create_runtime()` if it is currently empty, and `RT_REFERENCE_COUNT` is
/// incremented. Each call is expected to be balanced by one `drop_runtime`.
///
/// # Panics
///
/// Panics if the `RT` lock is poisoned.
#[cfg(not(any(target_os = "macos", target_family = "wasm")))]
pub(crate) fn ensure_runtime() {
    use std::sync::atomic::Ordering;

    let mut slot = RT.write().unwrap();
    let runtime_missing = slot.is_none();
    if runtime_missing {
        *slot = create_runtime();
    }
    // Incremented while the write lock is still held.
    RT_REFERENCE_COUNT.fetch_add(1, Ordering::Relaxed);
}
/// Releases one reference to the global runtime and shuts the runtime down
/// when the last reference goes away.
///
/// Decrements `RT_REFERENCE_COUNT`; if this call released the last reference,
/// the write lock on `RT` is taken and the runtime is removed (and thereby
/// dropped).
///
/// The count is re-checked *after* the write lock is acquired. The decrement
/// happens outside the lock, so a concurrent `ensure_runtime` (which
/// increments under the lock) may have re-acquired the still-present runtime
/// between the decrement and the lock acquisition. In that case the runtime
/// must stay alive; without the re-check it would be torn down while the
/// count is non-zero and later `spawn`/`block_on` calls would panic.
///
/// # Safety
///
/// `_arg` is never read. The `extern "C"` signature suggests this is
/// registered as a native callback — NOTE(review): confirm the registration
/// site pairs it one-to-one with `ensure_runtime`.
///
/// # Panics
///
/// Panics if the `RT` lock is poisoned.
#[cfg(not(any(target_os = "macos", target_family = "wasm")))]
pub(crate) unsafe extern "C" fn drop_runtime(_arg: *mut std::ffi::c_void) {
    use std::sync::atomic::Ordering;

    if RT_REFERENCE_COUNT.fetch_sub(1, Ordering::AcqRel) == 1 {
        let mut rt = RT.write().unwrap();
        // `ensure_runtime` only increments while holding this lock, so the
        // count cannot grow while we hold it; zero means nobody re-acquired.
        if RT_REFERENCE_COUNT.load(Ordering::Acquire) == 0 {
            rt.take();
        }
    }
}
/// Spawns `fut` onto the global Tokio runtime and returns its join handle.
///
/// # Panics
///
/// Panics if the `RT` lock is poisoned, or with "Tokio runtime is not created"
/// if the global runtime slot is empty.
pub fn spawn<F>(fut: F) -> tokio::task::JoinHandle<F::Output>
where
    F: 'static + Send + Future<Output = ()>,
{
    let slot = RT.read().unwrap();
    let runtime = slot.as_ref().expect("Tokio runtime is not created");
    runtime.spawn(fut)
}
/// Runs `fut` to completion on the global Tokio runtime, blocking the calling
/// thread, and returns the future's output.
///
/// The read lock on `RT` is held for the whole duration of the call.
///
/// # Panics
///
/// Panics if the `RT` lock is poisoned, or with "Tokio runtime is not created"
/// if the global runtime slot is empty.
pub fn block_on<F: Future>(fut: F) -> F::Output {
    let slot = RT.read().unwrap();
    let runtime = slot.as_ref().expect("Tokio runtime is not created");
    runtime.block_on(fut)
}
/// Hands `func` to the global runtime's `spawn_blocking` and returns the join
/// handle for its result.
///
/// # Panics
///
/// Panics if the `RT` lock is poisoned, or with "Tokio runtime is not created"
/// if the global runtime slot is empty.
pub fn spawn_blocking<F, R>(func: F) -> tokio::task::JoinHandle<R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    let slot = RT.read().unwrap();
    let runtime = slot.as_ref().expect("Tokio runtime is not created");
    runtime.spawn_blocking(func)
}
/// Calls `f` with the global Tokio runtime entered as the current runtime
/// context and returns whatever `f` returns.
///
/// The runtime context is left (and then the read lock on `RT` released) once
/// `f` has returned.
///
/// # Panics
///
/// Panics if the `RT` lock is poisoned, or with "Tokio runtime is not created"
/// if the global runtime slot is empty.
pub fn within_runtime_if_available<F: FnOnce() -> T, T>(f: F) -> T {
    let slot = RT.read().unwrap();
    let runtime = slot.as_ref().expect("Tokio runtime is not created");
    // Locals drop in reverse declaration order: the enter guard goes first,
    // then the read lock, exactly after `f` has produced its value.
    let _entered = runtime.enter();
    f()
}
/// Wrapper around a resolver closure that lets it be moved into a `Send`
/// future even though the closure type `R` itself is not required to be
/// `Send`.
///
/// `Data` is only tracked through `PhantomData`; no value of that type is
/// stored.
struct SendableResolver<Data, R>
where
    Data: 'static + Send,
    R: 'static + FnOnce(sys::napi_env, Data) -> Result<sys::napi_value>,
{
    inner: R,
    _data: PhantomData<Data>,
}

// NOTE(review): `Send` is asserted unconditionally, without an `R: Send`
// bound. Presumably this is sound because the wrapped closure is only ever
// invoked on the thread that owns the `napi_env` — confirm against callers.
unsafe impl<Data, R> Send for SendableResolver<Data, R>
where
    Data: 'static + Send,
    R: 'static + FnOnce(sys::napi_env, Data) -> Result<sys::napi_value>,
{
}

impl<Data, R> SendableResolver<Data, R>
where
    Data: 'static + Send,
    R: 'static + FnOnce(sys::napi_env, Data) -> Result<sys::napi_value>,
{
    /// Wraps the resolver closure `inner`.
    fn new(inner: R) -> Self {
        SendableResolver {
            inner,
            _data: PhantomData,
        }
    }

    /// Consumes the wrapper and invokes the resolver with `env` and `data`,
    /// returning its result unchanged.
    fn resolve(self, env: sys::napi_env, data: Data) -> Result<sys::napi_value> {
        let Self {
            inner: resolver, ..
        } = self;
        resolver(env, data)
    }
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
pub fn execute_tokio_future<
Data: 'static + Send,
Fut: 'static + Send + Future<Output = std::result::Result<Data, impl Into<Error>>>,
Resolver: 'static + FnOnce(sys::napi_env, Data) -> Result<sys::napi_value>,
>(
env: sys::napi_env,
fut: Fut,
resolver: Resolver,
) -> Result<sys::napi_value> {
let (deferred, promise) = JsDeferred::new(env)?;
#[cfg(not(target_family = "wasm"))]
let deferred_for_panic = deferred.clone();
let sendable_resolver = SendableResolver::new(resolver);
let inner = async move {
match fut.await {
Ok(v) => deferred.resolve(move |env| {
sendable_resolver
.resolve(env.raw(), v)
.map(|v| unsafe { JsUnknown::from_raw_unchecked(env.raw(), v) })
}),
Err(e) => deferred.reject(e.into()),
}
};
#[cfg(not(target_family = "wasm"))]
{
let jh = spawn(inner);
spawn(async move {
if let Err(err) = jh.await {
if let Ok(reason) = err.try_into_panic() {
if let Some(s) = reason.downcast_ref::<&str>() {
deferred_for_panic.reject(Error::new(crate::Status::GenericFailure, s));
} else {
deferred_for_panic.reject(Error::new(
crate::Status::GenericFailure,
"Panic in async function",
));
}
}
}
});
}
#[cfg(target_family = "wasm")]
{
std::thread::spawn(|| {
block_on(inner);
});
}
Ok(promise.0.value)
}
/// Builder that turns a future yielding `Result<V>` into an [`AsyncBlock`],
/// optionally running a dispose callback when the future's value is resolved.
pub struct AsyncBlockBuilder<V, F, Dispose>
where
    V: ToNapiValue + Send + 'static,
    F: Future<Output = Result<V>> + Send + 'static,
    Dispose: FnOnce(Env) + 'static,
{
    inner: F,
    dispose: Option<Dispose>,
}

impl<V, F, Dispose> AsyncBlockBuilder<V, F, Dispose>
where
    V: ToNapiValue + Send + 'static,
    F: Future<Output = Result<V>> + Send + 'static,
    Dispose: FnOnce(Env),
{
    /// Starts a builder for the future `inner`, with no dispose callback.
    pub fn with(inner: F) -> Self {
        AsyncBlockBuilder {
            inner,
            dispose: None,
        }
    }

    /// Sets the callback to run (with the `Env`) just before the future's
    /// value is converted to a JavaScript value. Replaces any previous one.
    pub fn with_dispose(self, dispose: Dispose) -> Self {
        Self {
            dispose: Some(dispose),
            ..self
        }
    }

    /// Schedules the future through `execute_tokio_future` and wraps the
    /// resulting promise value in an [`AsyncBlock`].
    ///
    /// When the future yields `Ok`, the dispose callback (if any) is called
    /// first and the value is then converted with `V::to_napi_value`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `execute_tokio_future`.
    pub fn build(self, env: Env) -> Result<AsyncBlock<V>> {
        let Self {
            inner: future,
            dispose,
        } = self;
        let promise = execute_tokio_future(env.0, future, move |raw_env, value| unsafe {
            if let Some(dispose_fn) = dispose {
                dispose_fn(Env::from_raw(raw_env));
            }
            V::to_napi_value(raw_env, value)
        })?;
        Ok(AsyncBlock {
            inner: promise,
            _phantom: PhantomData,
        })
    }
}
/// Handle to a promise produced by [`AsyncBlockBuilder::build`].
///
/// Stores the raw `napi_value` of the promise; `T` records the type the
/// promise resolves to and is not stored.
pub struct AsyncBlock<T>
where
    T: ToNapiValue + Send + 'static,
{
    inner: sys::napi_value,
    _phantom: PhantomData<T>,
}

impl<T> ToNapiValue for AsyncBlock<T>
where
    T: ToNapiValue + Send + 'static,
{
    /// Returns the stored promise value as-is; the environment is not used.
    unsafe fn to_napi_value(_: napi_sys::napi_env, val: Self) -> Result<napi_sys::napi_value> {
        let AsyncBlock { inner: promise, .. } = val;
        Ok(promise)
    }
}