use futures::channel::oneshot::{self, Sender};
use futures::{select, Future, FutureExt};
use std::rc::Rc;
use std::sync::Mutex;
use std::task::{Poll, Waker};
/// Starts `future` with `wasm_bindgen_futures::spawn_local` and returns a
/// [`JoinHandle`] that can be awaited for its output or used to abort it.
///
/// The spawned task races `future` against a one-shot cancellation channel
/// whose sending half is owned by the returned handle. If the future wins, its
/// output is stored in the state shared with the handle; if the channel
/// resolves first, the future is dropped without producing a value.
///
/// NOTE(review): the cancellation arm uses a `_` pattern, so it also matches
/// the error a oneshot receiver yields when its sender is dropped. Dropping
/// the returned handle without calling `abort` therefore appears to stop the
/// task as well — confirm this is intended.
pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
where
    F: Future + 'static,
    F::Output: 'static,
{
    let (abort_tx, mut abort_rx) = oneshot::channel();
    let handle = JoinHandle::new(abort_tx);
    // The task keeps its own view of the shared state so it can publish the output.
    let task_handle = handle.clone();

    let task = async move {
        select! {
            output = future.fuse() => task_handle.set_result(output),
            _ = abort_rx => (),
        }
    };
    wasm_bindgen_futures::spawn_local(task);

    handle
}
/// Error produced when awaiting a `JoinHandle` whose task did not yield a value.
///
/// In this file the only place that constructs it is `JoinHandle::abort`, so it
/// currently always represents cancellation. The type is `#[non_exhaustive]`
/// so that further failure kinds can be added later.
#[derive(Debug)]
#[non_exhaustive]
pub struct JoinError {}

impl JoinError {
    /// Returns `true` if the task failed because it was cancelled.
    ///
    /// Cancellation is the only failure this type models, so this is always
    /// `true`.
    pub fn is_cancelled(&self) -> bool {
        true
    }
}

impl std::fmt::Display for JoinError {
    /// Writes a short human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("task was cancelled")
    }
}

// Implementing `Error` lets callers propagate a `JoinError` with `?` into
// `Box<dyn std::error::Error>` and similar error containers.
impl std::error::Error for JoinError {}
/// Handle to a task started by `spawn`.
///
/// Awaiting the handle yields the task's output, or a `JoinError` if the task
/// was aborted first.
#[derive(Debug)]
pub struct JoinHandle<T> {
    /// Completion state shared between this handle and the copy held by the task.
    state: Rc<Mutex<State<T>>>,
    /// Sending half of the cancellation channel. It becomes `None` once
    /// `abort` has consumed it, and is always `None` in copies made by the
    /// private `clone`.
    cancel_sender: Mutex<Option<Sender<()>>>,
}

impl<T> JoinHandle<T> {
    /// Builds a handle with empty state that owns `cancel_sender`.
    fn new(cancel_sender: Sender<()>) -> Self {
        let state = Rc::new(Mutex::new(State::new()));
        let cancel_sender = Mutex::new(Some(cancel_sender));
        Self {
            state,
            cancel_sender,
        }
    }

    /// Cancels the task.
    ///
    /// A cancellation error is recorded as the task's result unless a result
    /// is already stored, and the cancellation signal is then sent to the
    /// task. Calling this again sends nothing, because the sender has already
    /// been taken.
    pub fn abort(&self) {
        let cancelled = Err(JoinError {});
        self.state.lock().unwrap().set_result(cancelled);

        let mut sender_slot = self.cancel_sender.lock().unwrap();
        if let Some(sender) = sender_slot.take() {
            // A send error only means the receiving side is already gone.
            let _ = sender.send(());
        }
    }

    /// Returns `true` while a result (output or cancellation) is stored in the
    /// shared state.
    pub fn is_finished(&self) -> bool {
        let state = self.state.lock().unwrap();
        state.is_finished()
    }

    /// Publishes the task's output; has no effect if a result is already stored.
    fn set_result(&self, value: T) {
        let mut state = self.state.lock().unwrap();
        state.set_result(Ok(value));
    }

    /// Creates a second handle onto the same shared state that cannot cancel
    /// the task (it carries no sender).
    fn clone(&self) -> Self {
        Self {
            state: Rc::clone(&self.state),
            cancel_sender: Mutex::new(None),
        }
    }
}
/// Completion state for one task: its eventual result and the waker to notify.
#[derive(Debug)]
struct State<T> {
    /// The task's outcome, once one has been recorded.
    result: Option<Result<T, JoinError>>,
    /// Waker to notify when a result is recorded.
    waker: Option<Waker>,
}

impl<T> State<T> {
    /// Creates a state with no result and no registered waker.
    fn new() -> Self {
        Self {
            result: None,
            waker: None,
        }
    }

    /// Returns `true` while a result is stored.
    fn is_finished(&self) -> bool {
        self.result.is_some()
    }

    /// Records `value` and wakes the registered waker.
    ///
    /// The first stored result wins: if one is already present, `value` is
    /// discarded and nobody is woken.
    fn set_result(&mut self, value: Result<T, JoinError>) {
        if self.result.is_some() {
            return;
        }
        self.result = Some(value);
        self.wake();
    }

    /// Wakes the registered waker, if any, and clears the registration.
    fn wake(&mut self) {
        if let Some(registered) = self.waker.take() {
            registered.wake();
        }
    }

    /// Registers `waker`, skipping the clone when the stored waker would
    /// already wake the same task.
    fn update_waker(&mut self, waker: &Waker) {
        let unchanged = matches!(&self.waker, Some(current) if waker.will_wake(current));
        if !unchanged {
            self.waker = Some(waker.clone());
        }
    }
}
impl<T> Future for JoinHandle<T> {
    /// The task's output, or a `JoinError` if it was aborted.
    type Output = Result<T, JoinError>;

    /// Resolves with the stored result if there is one; otherwise registers
    /// the caller's waker and reports `Pending`.
    ///
    /// The result is moved out of the shared state when it is returned, so
    /// `is_finished` reports `false` afterwards.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let mut shared = self.state.lock().unwrap();
        match shared.result.take() {
            Some(outcome) => Poll::Ready(outcome),
            None => {
                shared.update_waker(cx.waker());
                Poll::Pending
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{sleep, spawn};
    use std::time::Duration;
    use wasm_bindgen_test::wasm_bindgen_test;

    /// Two immediately-ready tasks are expected to have finished after a
    /// one-second sleep and to yield their values when awaited.
    #[wasm_bindgen_test]
    async fn test_spawn() {
        let first = spawn(async { 1 });
        let second = spawn(async { 2 });

        // Yield to the executor so both tasks get a chance to run.
        sleep(Duration::from_secs(1)).await;

        assert!(first.is_finished());
        assert!(second.is_finished());

        let first_output = first.await.unwrap();
        assert_eq!(first_output, 1);
        let second_output = second.await.unwrap();
        assert_eq!(second_output, 2);
    }

    /// Aborting a task that is still sleeping makes awaiting its handle fail
    /// with a cancellation error.
    #[wasm_bindgen_test]
    async fn test_abort() {
        let sleeper = spawn(async {
            sleep(Duration::from_secs(10)).await;
            1
        });

        sleeper.abort();

        let error = sleeper.await.unwrap_err();
        assert!(error.is_cancelled());
    }
}