rspack_napi/
threadsafe_function.rs

1use std::{
2  fmt::Debug,
3  marker::PhantomData,
4  sync::{Arc, OnceLock},
5};
6
7use napi::{
8  Env, JsValue, Status, Unknown, ValueType,
9  bindgen_prelude::{FromNapiValue, JsValuesTupleIntoVec, Promise, TypeName, ValidateNapiValue},
10  sys::{self, napi_env},
11  threadsafe_function::{ThreadsafeFunction as RawThreadsafeFunction, ThreadsafeFunctionCallMode},
12};
13#[cfg(not(feature = "browser"))]
14use oneshot::{Receiver, channel};
15#[cfg(feature = "browser")]
16use rspack_browser::oneshot::{Receiver, channel};
17use rspack_error::{Error, Result};
18
19use crate::{JsCallback, NapiErrorToRspackErrorExt};
20
21type ErrorResolver = dyn FnOnce(Env);
22
23static ERROR_RESOLVER: OnceLock<JsCallback<Box<ErrorResolver>>> = OnceLock::new();
24
25pub struct ThreadsafeFunction<T: 'static + JsValuesTupleIntoVec, R> {
26  inner: Arc<RawThreadsafeFunction<T, Unknown<'static>, T, Status, false, true>>,
27  env: napi_env,
28  _data: PhantomData<R>,
29}
30
31impl<T: 'static + JsValuesTupleIntoVec, R> Debug for ThreadsafeFunction<T, R> {
32  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33    f.debug_struct("ThreadsafeFunction").finish_non_exhaustive()
34  }
35}
36
37impl<T: 'static + JsValuesTupleIntoVec, R> Clone for ThreadsafeFunction<T, R> {
38  fn clone(&self) -> Self {
39    Self {
40      inner: self.inner.clone(),
41      env: self.env,
42      _data: self._data,
43    }
44  }
45}
46
47unsafe impl<T: 'static + JsValuesTupleIntoVec, R> Sync for ThreadsafeFunction<T, R> {}
48unsafe impl<T: 'static + JsValuesTupleIntoVec, R> Send for ThreadsafeFunction<T, R> {}
49
50impl<T: 'static + JsValuesTupleIntoVec, R> FromNapiValue for ThreadsafeFunction<T, R> {
51  unsafe fn from_napi_value(env: sys::napi_env, napi_val: sys::napi_value) -> napi::Result<Self> {
52    let inner = unsafe {
53      <RawThreadsafeFunction<T, Unknown, T, Status, false, true> as FromNapiValue>::from_napi_value(
54        env, napi_val,
55      )
56    }?;
57    let _ = ERROR_RESOLVER
58      .get_or_init(|| unsafe { JsCallback::new(env).expect("should initialize error resolver") });
59    Ok(Self {
60      inner: Arc::new(inner),
61      env,
62      _data: PhantomData,
63    })
64  }
65}
66
67impl<T: 'static + JsValuesTupleIntoVec, R> ThreadsafeFunction<T, R> {
68  async fn resolve_error(&self, err: napi::Error) -> Error {
69    let (tx, rx) = tokio::sync::oneshot::channel::<rspack_error::Error>();
70    ERROR_RESOLVER
71      .get()
72      // SAFETY: The error resolver is initialized in `FromNapiValue::from_napi_value` and it's the only way to create a tsfn.
73      .expect("should have error resolver initialized")
74      .call(Box::new(move |env| {
75        let err = err.to_rspack_error(&env);
76        tx.send(err).expect("failed to resolve js error");
77      }));
78    rx.await.expect("failed to resolve js error")
79  }
80
81  fn call_with_return<D: 'static + FromNapiValue>(&self, value: T) -> Receiver<Result<D>> {
82    let (tx, rx) = channel::<Result<D>>();
83    self
84      .inner
85      .call_with_return_value(value, ThreadsafeFunctionCallMode::NonBlocking, {
86        move |r: napi::Result<Unknown>, env| {
87          let r = match r {
88            Err(err) => Err(err.to_rspack_error(&env)),
89            Ok(o) => {
90              let raw_env = env.raw();
91              let return_value = o.raw();
92              unsafe { D::from_napi_value(raw_env, return_value) }
93                .map_err(|e| pretty_type_error(o, e))
94            }
95          };
96          tx.send(r)
97            .unwrap_or_else(|_| panic!("failed to send tsfn value"));
98          Ok(())
99        }
100      });
101    rx
102  }
103
104  async fn call_async<D: 'static + FromNapiValue>(&self, value: T) -> Result<D> {
105    let rx = self.call_with_return(value);
106    #[cfg(feature = "browser")]
107    let ret = tokio::task::unconstrained(rx)
108      .await
109      .expect("failed to receive tsfn value");
110    #[cfg(not(feature = "browser"))]
111    let ret = rx.await.expect("failed to receive tsfn value");
112    ret
113  }
114}
115
116impl<T: 'static + JsValuesTupleIntoVec, R: 'static + FromNapiValue> ThreadsafeFunction<T, R> {
117  /// Call the JS function.
118  pub async fn call_with_sync(&self, value: T) -> Result<R> {
119    self.call_async::<R>(value).await
120  }
121}
122
123impl<T: 'static + JsValuesTupleIntoVec, R: 'static + FromNapiValue>
124  ThreadsafeFunction<T, Promise<R>>
125{
126  /// Call the JS function.
127  /// If `Promise<T>` is returned, it will be awaited and its value `T` will be returned.
128  /// Otherwise, an [napi::Error] is returned.
129  pub async fn call_with_promise(&self, value: T) -> Result<R> {
130    match self.call_async::<Promise<R>>(value).await {
131      Ok(r) => match r.await {
132        Ok(r) => Ok(r),
133        Err(err) => Err(self.resolve_error(err).await),
134      },
135      Err(err) => Err(err),
136    }
137  }
138}
139
140impl<T: 'static + JsValuesTupleIntoVec + JsValuesTupleIntoVec, R> ValidateNapiValue
141  for ThreadsafeFunction<T, R>
142{
143}
144
145impl<T: 'static + JsValuesTupleIntoVec, R> TypeName for ThreadsafeFunction<T, R> {
146  fn type_name() -> &'static str {
147    "ThreadsafeFunction"
148  }
149
150  fn value_type() -> napi::ValueType {
151    ValueType::Function
152  }
153}
154
155fn pretty_type_error(return_value: Unknown, error: napi::Error) -> rspack_error::Error {
156  let expected_type = match error.status {
157    Status::ObjectExpected => "object",
158    Status::StringExpected => "string",
159    Status::FunctionExpected => "function",
160    Status::NumberExpected => "number",
161    Status::BooleanExpected => "boolean",
162    Status::ArrayExpected => "Array",
163    Status::BigintExpected => "bigint",
164    Status::DateExpected => "Date",
165    Status::ArrayBufferExpected => "ArrayBuffer",
166    _ => return rspack_error::error!("{}", error),
167  };
168  let reason = match return_value.get_type() {
169    Ok(return_value_type) => {
170      let return_value_type_str = match return_value_type {
171        ValueType::Undefined => "undefined",
172        ValueType::Null => "null",
173        ValueType::Boolean => "boolean",
174        ValueType::Number => "number",
175        ValueType::String => "string",
176        ValueType::Symbol => "symbol",
177        ValueType::Object => "object",
178        ValueType::Function => "function",
179        ValueType::External => "external",
180        ValueType::BigInt => "bigint",
181        _ => "unknown",
182      };
183      format!(
184        "TypeError: Expected return a '{expected_type}' value, but received `{return_value_type_str}`"
185      )
186    }
187    Err(_) => format!("TypeError: Expected return a '{expected_type}' value"),
188  };
189  rspack_error::error!(reason)
190}