use std::{
error, fmt,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use crate::{
context::{internal::Env, Context, Cx},
result::{NeonResult, ResultExt, Throw},
sys::{self, tsfn::ThreadsafeFunction},
};
#[cfg(feature = "futures")]
use {
std::future::Future,
std::pin::Pin,
std::task::{self, Poll},
tokio::sync::oneshot,
};
#[cfg(not(feature = "futures"))]
/// Blocking stand-in used when the `futures` feature is disabled.
///
/// It exposes the names the rest of this file expects from a `oneshot`
/// module (`channel`, `Receiver::blocking_recv`, `error::RecvError`), built
/// on top of `std::sync::mpsc`.
mod oneshot {
    use std::sync::mpsc::{self, sync_channel, SyncSender};

    /// Error types, reachable as `oneshot::error::*`.
    pub(super) mod error {
        pub use super::mpsc::RecvError;
    }

    /// Receiving half; wraps a standard-library `mpsc::Receiver`.
    pub(super) struct Receiver<T>(mpsc::Receiver<T>);

    impl<T> Receiver<T> {
        /// Blocks the calling thread until a value arrives.
        ///
        /// Returns `Err(RecvError)` if the sender is dropped without sending.
        pub(super) fn blocking_recv(self) -> Result<T, error::RecvError> {
            let Self(inner) = self;
            inner.recv()
        }
    }

    /// Creates a connected sender/receiver pair with room for one buffered
    /// value, so a single `send` completes without waiting for the receiver.
    pub(super) fn channel<T>() -> (SyncSender<T>, Receiver<T>) {
        let (sender, receiver) = sync_channel(1);
        (sender, Receiver(receiver))
    }
}
/// Boxed, sendable, one-shot closure queued on the channel's `ThreadsafeFunction`.
///
/// `ChannelState::callback` invokes it with the raw `sys::Env` it receives.
type Callback = Box<dyn FnOnce(sys::Env) + Send + 'static>;
/// Handle for queueing closures through a shared `ThreadsafeFunction`.
///
/// Each handle tracks on its own whether it currently holds a reference on
/// the shared state (`has_ref`); clones share the same `ChannelState`.
#[cfg_attr(docsrs, doc(cfg(feature = "napi-4")))]
pub struct Channel {
// State shared by this handle and all of its clones.
state: Arc<ChannelState>,
// `true` while this handle counts toward `ChannelState::ref_count`.
has_ref: bool,
}
impl fmt::Debug for Channel {
    /// Writes the fixed text `Channel`; no internal state is printed.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "Channel")
    }
}
impl Channel {
    /// Builds a channel backed by a freshly created `ChannelState`.
    ///
    /// The returned handle starts out referenced (`has_ref()` is `true`).
    pub fn new<'a, C: Context<'a>>(cx: &mut C) -> Self {
        let state = Arc::new(ChannelState::new(cx));
        Self {
            state,
            has_ref: true,
        }
    }

    /// Gives up this handle's reference on the shared state.
    ///
    /// Does nothing when the handle is already unreferenced. Returns `self`
    /// so calls can be chained.
    pub fn unref<'a, C: Context<'a>>(&mut self, cx: &mut C) -> &mut Self {
        if self.has_ref {
            self.has_ref = false;
            self.state.unref(cx);
        }
        self
    }

    /// Takes a reference on the shared state for this handle.
    ///
    /// Does nothing when the handle is already referenced. Returns `self`
    /// so calls can be chained.
    pub fn reference<'a, C: Context<'a>>(&mut self, cx: &mut C) -> &mut Self {
        if !self.has_ref {
            self.has_ref = true;
            self.state.reference(cx);
        }
        self
    }

    /// Queues `f` and returns a handle to its eventual result.
    ///
    /// # Panics
    ///
    /// Panics when [`Channel::try_send`] returns an error.
    pub fn send<T, F>(&self, f: F) -> JoinHandle<T>
    where
        T: Send + 'static,
        F: FnOnce(Cx) -> NeonResult<T> + Send + 'static,
    {
        self.try_send(f).unwrap()
    }

    /// Queues `f` on the channel's `ThreadsafeFunction`.
    ///
    /// When the queued callback runs, `f` is called with a `Cx` and its
    /// result is forwarded to the returned [`JoinHandle`].
    ///
    /// # Errors
    ///
    /// Returns [`SendError`] when the `ThreadsafeFunction` call fails.
    pub fn try_send<T, F>(&self, f: F) -> Result<JoinHandle<T>, SendError>
    where
        T: Send + 'static,
        F: FnOnce(Cx) -> NeonResult<T> + Send + 'static,
    {
        let (sender, receiver) = oneshot::channel();
        let callback = Box::new(move |env| {
            let env = Env::from(env);
            Cx::with_context(env, move |cx| {
                let outcome = f(cx).map_err(Into::into);
                // A failed send is deliberately ignored: nobody is waiting
                // on the other end of the channel any more.
                let _ = sender.send(outcome);
            });
        });

        match self.state.tsfn.call(callback, None) {
            Ok(_) => Ok(JoinHandle { rx: receiver }),
            Err(_) => Err(SendError),
        }
    }

    /// Reports whether this handle currently holds a reference.
    pub fn has_ref(&self) -> bool {
        self.has_ref
    }
}
impl Clone for Channel {
    /// Returns a handle sharing the same `ChannelState`.
    ///
    /// The clone inherits `has_ref`. When the original is referenced, the
    /// shared `ref_count` is bumped so the clone owns a reference of its own.
    fn clone(&self) -> Self {
        let state = Arc::clone(&self.state);
        if self.has_ref {
            // Only the counter moves here; the `tsfn` itself is not touched.
            state.ref_count.fetch_add(1, Ordering::Relaxed);
        }
        Self {
            state,
            has_ref: self.has_ref,
        }
    }
}
impl Drop for Channel {
    /// Releases this handle's reference, if it holds one.
    ///
    /// Nothing is done for an unreferenced handle, nor when this handle is
    /// the only owner of the shared state. Otherwise a closure is queued
    /// that calls `ChannelState::unref` with a context.
    fn drop(&mut self) {
        if self.has_ref && Arc::strong_count(&self.state) != 1 {
            let shared = Arc::clone(&self.state);
            // A failure to queue the closure is deliberately ignored.
            let _ = self.try_send(move |mut cx| {
                shared.unref(&mut cx);
                Ok(())
            });
        }
    }
}
/// Handle to the result of a closure queued with `Channel::try_send`.
pub struct JoinHandle<T> {
// Receives `Ok(value)` on success, or `Err(SendThrow)` when the closure
// returned an error.
rx: oneshot::Receiver<Result<T, SendThrow>>,
}
impl<T> JoinHandle<T> {
pub fn join(self) -> Result<T, JoinError> {
Ok(self.rx.blocking_recv()??)
}
}
#[cfg(feature = "futures")]
#[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
impl<T> Future for JoinHandle<T> {
    type Output = Result<T, JoinError>;

    /// Polls the inner receiver and, once it is ready, flattens both error
    /// layers (receive failure and closure failure) into a [`JoinError`].
    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Self::Output> {
        Pin::new(&mut self.rx)
            .poll(cx)
            .map(|received| match received {
                Ok(outcome) => outcome.map_err(JoinError::from),
                Err(closed) => Err(JoinError::from(closed)),
            })
    }
}
/// Error produced when joining a `JoinHandle` does not yield a value.
///
/// The reason is kept in the private `JoinErrorType`.
#[derive(Debug)]
pub struct JoinError(JoinErrorType);
/// Reason a join failed.
#[derive(Debug)]
enum JoinErrorType {
// Built from the receiver's `RecvError`; reported as
// "Closure panicked before returning".
Panic,
// Built from `SendThrow`; reported as "Closure threw an exception".
Throw,
}
impl JoinError {
fn as_str(&self) -> &str {
match &self.0 {
JoinErrorType::Panic => "Closure panicked before returning",
JoinErrorType::Throw => "Closure threw an exception",
}
}
}
impl fmt::Display for JoinError {
    /// Writes the message returned by `JoinError::as_str`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = self.as_str();
        formatter.write_str(message)
    }
}
// Empty impl: every `Error` method keeps its default behavior.
impl error::Error for JoinError {}
impl From<oneshot::error::RecvError> for JoinError {
fn from(_: oneshot::error::RecvError) -> Self {
JoinError(JoinErrorType::Panic)
}
}
/// Data-less marker that stands in for a `Throw` inside the result sent to a
/// `JoinHandle`.
// NOTE(review): presumably needed because `Throw` itself cannot cross
// threads — confirm against the definition of `Throw`.
pub(crate) struct SendThrow(());
impl From<SendThrow> for JoinError {
fn from(_: SendThrow) -> Self {
JoinError(JoinErrorType::Throw)
}
}
impl From<Throw> for SendThrow {
fn from(_: Throw) -> SendThrow {
SendThrow(())
}
}
impl<T> ResultExt<T> for Result<T, JoinError> {
    /// Passes a successful value through unchanged; on failure, delegates to
    /// `cx.throw_error` with the `JoinError` message.
    fn or_throw<'a, C: Context<'a>>(self, cx: &mut C) -> NeonResult<T> {
        match self {
            Ok(value) => Ok(value),
            Err(failure) => cx.throw_error(failure.as_str()),
        }
    }
}
/// Error returned by `Channel::try_send` when the call on the underlying
/// `ThreadsafeFunction` fails.
#[cfg_attr(docsrs, doc(cfg(feature = "napi-4")))]
pub struct SendError;
impl fmt::Display for SendError {
    /// Writes the fixed text `SendError`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("SendError")
    }
}
impl fmt::Debug for SendError {
    /// Debug output is identical to the `Display` output.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, formatter)
    }
}
// Empty impl: every `Error` method keeps its default behavior.
impl error::Error for SendError {}
/// State shared by a `Channel` and all of its clones.
struct ChannelState {
// Function that queued `Callback`s are sent through.
tsfn: ThreadsafeFunction<Callback>,
// Number of referenced `Channel` handles. Starts at 1; bumped by
// `reference` and by cloning a referenced `Channel`, lowered by `unref`.
ref_count: AtomicUsize,
}
impl ChannelState {
    /// Creates the `ThreadsafeFunction` for `cx`'s environment, with
    /// `Self::callback` as its trampoline, and starts the count at one.
    fn new<'a, C: Context<'a>>(cx: &mut C) -> Self {
        // NOTE(review): soundness depends on the contract of
        // `ThreadsafeFunction::new`, which is not visible here — confirm.
        let tsfn = unsafe { ThreadsafeFunction::new(cx.env().to_raw(), Self::callback) };
        let ref_count = AtomicUsize::new(1);
        Self { tsfn, ref_count }
    }

    /// Increments the count; the `tsfn` is referenced only on the
    /// zero-to-one transition.
    fn reference<'a, C: Context<'a>>(&self, cx: &mut C) {
        let previous = self.ref_count.fetch_add(1, Ordering::Relaxed);
        if previous == 0 {
            unsafe {
                self.tsfn.reference(cx.env().to_raw());
            }
        }
    }

    /// Decrements the count; the `tsfn` is unreferenced only on the
    /// one-to-zero transition.
    fn unref<'a, C: Context<'a>>(&self, cx: &mut C) {
        let previous = self.ref_count.fetch_sub(1, Ordering::Relaxed);
        if previous == 1 {
            unsafe {
                self.tsfn.unref(cx.env().to_raw());
            }
        }
    }

    /// Trampoline given to the `ThreadsafeFunction`.
    ///
    /// With an environment, the queued closure is run with it. Without one,
    /// the closure is dropped uncalled and `IS_RUNNING` is set to `false`.
    fn callback(env: Option<sys::Env>, callback: Callback) {
        match env {
            Some(env) => callback(env),
            // NOTE(review): presumably this happens while the environment
            // is shutting down — confirm against `ThreadsafeFunction`.
            None => crate::context::internal::IS_RUNNING.with(|running| {
                *running.borrow_mut() = false;
            }),
        }
    }
}