js_function_promisify/
callback.rs

1use core::cell::RefCell;
2use js_sys::Function;
3use std::fmt::Debug;
4use std::future::Future;
5use std::rc::Rc;
6use std::task::Poll;
7use std::task::Waker;
8use wasm_bindgen::prelude::Closure;
9use wasm_bindgen::JsValue;
10
11/// A `Callback<F>` is a wrapper around a `wasm_bindgen::prelude::Closure<F>` which supports TODO:
12#[derive(Debug)]
13pub struct Callback<F: 'static + ?Sized> {
14  inner: Rc<RefCell<CallbackInner<F>>>,
15}
16
17impl<F: 'static + ?Sized> Callback<F> {
18  pub fn new<X>(closure: X) -> Callback<F>
19  where
20    Self: From<X>,
21  {
22    Self::from(closure)
23  }
24
25  pub fn as_function(&self) -> Function {
26    let js_func: JsValue = self
27      .inner
28      .borrow()
29      .cb
30      .as_ref()
31      .unwrap()
32      .as_ref()
33      .as_ref()
34      .into();
35    let func: Function = js_func.into();
36    func
37  }
38
39  pub fn as_closure(&self) -> Rc<Closure<F>> {
40    Rc::clone(self.inner.borrow().cb.as_ref().unwrap())
41  }
42}
43
44/// The Default impl for Callback creates a single-arg callback, whose Result is always Ok.
45impl Default for Callback<dyn FnMut(JsValue)> {
46  fn default() -> Self {
47    Self::from(|data| Ok(data))
48  }
49}
50
51impl Callback<dyn FnMut(JsValue, JsValue)> {
52  /// Creates a node-style callback with the args `(err, data)`. If err is null or undefined,
53  /// the Result is Ok(data). Otherwise, it is Err(err).
54  pub fn default_node() -> Self {
55    Self::from(|err: JsValue, data: JsValue| {
56      if err.is_null() || err.is_undefined() {
57        return Ok(data);
58      }
59      Err(err)
60    })
61  }
62}
63
64/// Standard Future impl for Callback<T>
65impl<F: 'static + ?Sized> Future for Callback<F> {
66  type Output = Result<JsValue, JsValue>;
67
68  fn poll(
69    self: std::pin::Pin<&mut Self>,
70    cx: &mut std::task::Context<'_>,
71  ) -> std::task::Poll<Self::Output> {
72    let mut inner = self.inner.borrow_mut();
73    if let Some(val) = inner.result.take() {
74      return Poll::Ready(val);
75    }
76    inner.task = Some(cx.waker().clone());
77    Poll::Pending
78  }
79}
80
/// A utility macro for generating every possible implementation of `From<A> for Callback`.
///
/// Invoked once below as `from_impl!(a0 a1 a2 a3 a4 a5 a6)`, it emits one
/// `impl<A> From<A> for Callback<dyn FnMut(JsValue, ...)>` per arity, from 7 arguments down
/// to 0. Each impl wraps the handler `A` (an `FnOnce` returning `Result<JsValue, JsValue>`)
/// in a `Closure::once` that passes the handler's result to `CallbackInner::finish`.
macro_rules! from_impl {
  // The main arm of this macro. Generates a single From impl for Callback.
  // a - The parameter types of the handler `A` (and of the `dyn FnMut` closure type).
  // alist - The identifiers used as the closure's argument names, one per type in `a`.
  (($($a:ty),*), ($($alist:ident),*)) => {
    impl<A> From<A> for Callback<dyn FnMut($($a,)*)>
    where
      A: 'static + FnOnce($($a,)*) -> Result<JsValue, JsValue>,
    {
      fn from(cb: A) -> Self {
        // Shared state; `state` is a second handle that is moved into the JS-facing closure.
        let inner = CallbackInner::new();
        let state = Rc::clone(&inner);
        // When JS calls the closure, run the handler (at most once, since `A: FnOnce`) and
        // record its result via `finish`, which also wakes any awaiting task.
        let closure = Closure::once(move |$($alist),*| CallbackInner::finish(&state, cb($($alist),*)));
        let ptr = Rc::new(closure);
        // The state now owns the closure while the closure owns `state`: an `Rc` cycle that
        // `CallbackInner::finish` breaks by taking `cb` once the callback fires.
        // NOTE(review): as far as this file shows, if the callback never fires the cycle is
        // never broken and the state stays alive — confirm this is intended.
        inner.borrow_mut().cb = Some(ptr);
        Callback { inner }
      }
    }
  };
  // Adapter into the main arm. The recursive arm below passes a single parenthesized,
  // comma-terminated identifier list such as `(a0, a1,)` (or `()`), which lands here. Each
  // identifier is mapped to the type `JsValue` (via `@rep`) to build the type list, and the
  // identifiers themselves become the argument list.
  (($($a:ident,)*)) => {
    from_impl!(($(from_impl!(@rep $a JsValue)),*), ($($a),*));
  };
  // For a list of identifiers, recursively generates a From impl for that list and every list with less args.
  ($head:ident $($tail:tt)*) => {
    // Generate a From impl for the full set of arguments.
    from_impl!(($head, $($tail,)*));
    // Recurse inwards, generating the same definitions with one less argument.
    from_impl!($($tail)*);
  };
  // Utility for replacing anything with a type: discards `$_t` and expands to `$sub`.
  (@rep $_t:tt $sub:ty) => {
    $sub
  };
  // End of recursion: an empty list still produces the zero-argument impl.
  () => {
    from_impl!(());
  };
}

from_impl!(a0 a1 a2 a3 a4 a5 a6); // Generate From impls for every argument count from 0 up to 7.
123
124#[derive(Debug)]
125pub struct CallbackInner<F: 'static + ?Sized> {
126  cb: Option<Rc<Closure<F>>>,
127  result: Option<Result<JsValue, JsValue>>,
128  task: Option<Waker>,
129}
130
131impl<F: 'static + ?Sized> CallbackInner<F> {
132  pub fn new() -> Rc<RefCell<CallbackInner<F>>> {
133    Rc::new(RefCell::new(CallbackInner {
134      cb: None,
135      task: None,
136      result: None,
137    }))
138  }
139
140  pub fn finish(state: &RefCell<CallbackInner<F>>, val: Result<JsValue, JsValue>) {
141    let task = {
142      let mut state = state.borrow_mut();
143      debug_assert!(state.result.is_none());
144      debug_assert!(state.cb.is_some());
145      drop(state.cb.take());
146      state.result = Some(val);
147      state.task.take()
148    };
149    if let Some(task) = task {
150      task.wake()
151    }
152  }
153}
154
#[cfg(test)]
mod tests {
  use crate::Callback;
  use js_sys::Function;
  use std::rc::Rc;
  use wasm_bindgen::prelude::*;
  use wasm_bindgen::JsCast;
  use wasm_bindgen_test::*;
  use web_sys::{window, IdbOpenDbRequest};

  wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

  /// Compile-only check that every expected `From` impl exists: handlers taking 0 through 7
  /// arguments must all be accepted by `Callback::new`.
  #[wasm_bindgen_test]
  #[rustfmt::skip]
  fn should_compile_with_any_args() {
    let _r = Callback::new(|| Ok("".into()));
    let _r = Callback::new(|_a| Ok("".into()));
    let _r = Callback::new(|_a, _b| Ok("".into()));
    let _r = Callback::new(|_a, _b, _c| Ok("".into()));
    let _r = Callback::new(|_a, _b, _c, _d| Ok("".into()));
    let _r = Callback::new(|_a, _b, _c, _d, _e| Ok("".into()));
    let _r = Callback::new(|_a, _b, _c, _d, _e, _f| Ok("".into()));
    let _r = Callback::new(|_a, _b, _c, _d, _e, _f, _g| Ok("".into()));
  }

  /// The shared inner state must be freed once the callback has fired and the future has
  /// been awaited.
  #[wasm_bindgen_test]
  async fn inner_dropped_after_await() {
    let future = Callback::default();
    let req: IdbOpenDbRequest = window()
      .expect("window not available")
      .indexed_db()
      .unwrap()
      .expect("idb not available")
      .open("my_db")
      .expect("Failed to get idb request");
    req.set_onerror(Some(&future.as_function()));
    let inner_ref = {
      let weak_ref = Rc::downgrade(&future.inner);
      req.set_onsuccess(Some(future.as_closure().as_ref().as_ref().unchecked_ref()));
      assert!(weak_ref.upgrade().is_some()); // inner state still alive
      weak_ref
    };
    assert!(inner_ref.upgrade().is_some()); // inner state still alive
    future.await.unwrap();
    assert!(inner_ref.upgrade().is_none()); // inner state freed
  }

  /// The wrapped closure must be released once the callback has fired and the future has
  /// been awaited.
  #[wasm_bindgen_test]
  async fn closure_dropped_after_await() {
    let future = Callback::default();
    let req: IdbOpenDbRequest = window()
      .expect("window not available")
      .indexed_db()
      .unwrap()
      .expect("idb not available")
      .open("my_db")
      .expect("Failed to get idb request");
    req.set_onerror(Some(future.as_closure().as_ref().as_ref().unchecked_ref()));
    let resolve_ref = {
      let weak_ref = Rc::downgrade(&future.as_closure());
      req.set_onsuccess(Some(future.as_closure().as_ref().as_ref().unchecked_ref()));
      assert!(weak_ref.upgrade().is_some()); // closure still alive
      weak_ref
    };
    assert!(resolve_ref.upgrade().is_some()); // closure still alive
    future.await.unwrap();
    assert!(resolve_ref.upgrade().is_none()); // closure released
  }

  #[wasm_bindgen(
    inline_js = "export function extern_node_success_null(cb) { cb(null, 'success') }; 
    export function extern_node_success_undefined(cb) { cb(undefined, 'success') };
    export function extern_node_failure(cb) { cb('failure', 'success') };"
  )]
  extern "C" {
    fn extern_node_success_null(cb: &Function);
    fn extern_node_success_undefined(cb: &Function);
    fn extern_node_failure(cb: &Function);
  }

  #[wasm_bindgen_test]
  async fn node_ok_if_arg0_null() {
    let future = Callback::default_node();
    extern_node_success_null(future.as_function().as_ref());
    let result = future.await;
    assert!(result.is_ok()); // Assert is `Ok`
    assert_eq!(result.unwrap(), "success");
  }

  #[wasm_bindgen_test]
  async fn node_ok_if_arg0_undefined() {
    let future = Callback::default_node();
    extern_node_success_undefined(future.as_function().as_ref());
    let result = future.await;
    assert!(result.is_ok()); // Assert is `Ok`
    assert_eq!(result.unwrap(), "success");
  }

  #[wasm_bindgen_test]
  async fn node_err_if_arg0_defined() {
    let future = Callback::default_node();
    extern_node_failure(future.as_function().as_ref());
    let result = future.await;
    assert!(result.is_err()); // Assert is `Err`
    assert_eq!(result.unwrap_err(), "failure");
  }
}