use std::mem::{self, ManuallyDrop};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
/// Handle for aborting a future started with [`JoinHandle::run`] and for
/// waiting until that future has been dropped.
///
/// Awaiting the handle resolves only once the paired future returned by
/// `run` has been dropped, not merely when it has been polled to completion.
pub struct JoinHandle {
    /// State shared with the paired future returned by [`JoinHandle::run`].
    state: Arc<Mutex<JoinState>>,
}
/// State shared between a [`JoinHandle`] and the future returned by
/// `JoinHandle::run`.
///
/// In this file, `abort` moves `Running` to `AbortRequested`, and dropping the
/// paired future moves any state to `Complete`; nothing leaves `Complete`.
enum JoinState {
    /// No abort has been requested and the paired future has not been dropped.
    Running {
        /// Waker of the task polling the paired future; `abort` wakes it so the
        /// future can observe the request.
        waiting_for_abort_signal: Option<Waker>,
        /// Waker of the task awaiting the `JoinHandle`; woken when the paired
        /// future is dropped.
        waiting_for_abort_to_complete: Option<Waker>,
    },
    /// `abort` was called. The paired future returns `None` when polled, but has
    /// not been dropped yet.
    AbortRequested {
        /// Waker of the task awaiting the `JoinHandle`; woken when the paired
        /// future is dropped.
        waiting_for_abort_to_complete: Option<Waker>,
    },
    /// The paired future has been dropped.
    Complete,
}
impl JoinHandle {
    /// Requests that the future returned by [`JoinHandle::run`] stop.
    ///
    /// If that future is still running, the state moves to `AbortRequested` and
    /// the task most recently registered as polling it (if any) is woken, so it
    /// observes the request and resolves to `None`. Calling this after an abort
    /// was already requested, or after the future has been dropped, does nothing.
    ///
    /// # Panics
    ///
    /// Panics if the shared state's mutex is poisoned.
    pub fn abort(&self) {
        let mut state = self.state.lock().unwrap();
        let abort_signal = match &mut *state {
            JoinState::Running {
                waiting_for_abort_signal,
                waiting_for_abort_to_complete,
            } => {
                let abort_signal = waiting_for_abort_signal.take();
                let waiting_for_abort_to_complete = waiting_for_abort_to_complete.take();
                *state = JoinState::AbortRequested {
                    waiting_for_abort_to_complete,
                };
                abort_signal
            }
            // Already aborting or already dropped: nothing to signal.
            JoinState::AbortRequested { .. } | JoinState::Complete => return,
        };
        // Release the lock before waking: `Waker::wake` runs arbitrary executor
        // code, which must neither re-enter this non-reentrant mutex nor poison it
        // by panicking while the lock is held.
        drop(state);
        if let Some(task) = abort_signal {
            task.wake();
        }
    }

    /// Wraps `future` so that it can be aborted through the returned [`JoinHandle`].
    ///
    /// The returned future resolves to `Some(output)` if `future` completes, or to
    /// `None` when polled after [`JoinHandle::abort`] has been called. Awaiting the
    /// handle resolves only once the returned future has been dropped, i.e. after
    /// `future`'s destructor has run.
    pub(crate) fn run<F>(future: F) -> (JoinHandle, impl Future<Output = Option<F::Output>>)
    where
        F: Future,
    {
        let handle = JoinHandle {
            state: Arc::new(Mutex::new(JoinState::Running {
                waiting_for_abort_signal: None,
                waiting_for_abort_to_complete: None,
            })),
        };
        let future = JoinHandleFuture {
            future: ManuallyDrop::new(future),
            state: Arc::clone(&handle.state),
        };
        (handle, future)
    }
}
impl Future for JoinHandle {
    type Output = ();

    /// Resolves once the future returned by [`JoinHandle::run`] has been dropped.
    ///
    /// Until then, the current task's waker is stored (replacing any earlier one)
    /// so that dropping the paired future can wake it.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut guard = self.state.lock().unwrap();
        let completion_waker = match &mut *guard {
            JoinState::Complete => return Poll::Ready(()),
            JoinState::Running {
                waiting_for_abort_to_complete: slot,
                ..
            }
            | JoinState::AbortRequested {
                waiting_for_abort_to_complete: slot,
            } => slot,
        };
        *completion_waker = Some(cx.waker().clone());
        Poll::Pending
    }
}
/// Future returned by [`JoinHandle::run`]: polls the wrapped future, resolving
/// to `Some(output)` when it completes or to `None` once an abort is requested.
struct JoinHandleFuture<F> {
    /// The wrapped future. It is held in `ManuallyDrop` so that `Drop` can destroy
    /// it *before* marking the shared state `Complete`; with a plain field it would
    /// be dropped only after `Drop::drop` returns. It is treated as structurally
    /// pinned (`poll` re-pins it in place), so it must never be moved out.
    future: ManuallyDrop<F>,
    /// State shared with the `JoinHandle`.
    state: Arc<Mutex<JoinState>>,
}
impl<F> Future for JoinHandleFuture<F>
where
    F: Future,
{
    type Output = Option<F::Output>;

    /// Polls the wrapped future unless an abort has been requested.
    ///
    /// While the state is `Running`, the current task's waker is registered so
    /// that `JoinHandle::abort` can wake it, and the wrapped future is polled
    /// without the state lock held. Once an abort has been requested (or the state
    /// is `Complete`), this returns `Poll::Ready(None)` without polling the
    /// wrapped future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing below moves out of `this`; the wrapped future is only
        // re-pinned in place, and `Drop` destroys it in place via
        // `ManuallyDrop::drop`, so the pinning guarantee is upheld.
        let this = unsafe { self.get_unchecked_mut() };

        let abort_requested = match &mut *this.state.lock().unwrap() {
            JoinState::Running {
                waiting_for_abort_signal,
                ..
            } => {
                *waiting_for_abort_signal = Some(cx.waker().clone());
                false
            }
            JoinState::AbortRequested { .. } | JoinState::Complete => true,
        };
        // The lock guard was a temporary of the statement above, so the lock is
        // already released here.
        if abort_requested {
            return Poll::Ready(None);
        }

        // SAFETY: `this.future` lives inside a pinned `JoinHandleFuture` and is
        // never moved, so it may be pinned structurally.
        let inner = unsafe { Pin::new_unchecked(&mut *this.future) };
        inner.poll(cx).map(Some)
    }
}
impl<F> Drop for JoinHandleFuture<F> {
fn drop(&mut self) {
unsafe {
ManuallyDrop::drop(&mut self.future);
}
let prev = mem::replace(&mut *self.state.lock().unwrap(), JoinState::Complete);
let task = match prev {
JoinState::Running {
waiting_for_abort_to_complete,
..
}
| JoinState::AbortRequested {
waiting_for_abort_to_complete,
} => waiting_for_abort_to_complete,
JoinState::Complete => None,
};
if let Some(task) = task {
task.wake();
}
}
}
#[cfg(test)]
mod tests {
    use super::JoinHandle;
    use std::pin::{Pin, pin};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::task::{Context, Poll, Wake, Waker};
    use tokio::sync::oneshot;

    /// Polls `future` once with a no-op waker and reports whether it is ready.
    fn is_ready<F>(future: Pin<&mut F>) -> bool
    where
        F: Future,
    {
        match future.poll(&mut Context::from_waker(Waker::noop())) {
            Poll::Ready(_) => true,
            Poll::Pending => false,
        }
    }

    /// Waker that counts how many times it has been woken, so tests can assert
    /// that wake-ups actually happen (a no-op waker cannot detect a missed wake).
    #[derive(Default)]
    struct CountingWaker {
        wakes: AtomicUsize,
    }

    impl CountingWaker {
        fn wakes(&self) -> usize {
            self.wakes.load(Ordering::SeqCst)
        }
    }

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.wakes.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn abort_in_progress() {
        let (tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let mut handle = Pin::new(&mut handle);
        {
            let mut future = pin!(future);
            assert!(!is_ready(future.as_mut()));
            assert!(!is_ready(handle.as_mut()));
            handle.abort();
            assert!(is_ready(future.as_mut()));
            assert!(!is_ready(handle.as_mut()));
            assert!(!tx.is_closed());
        }
        assert!(is_ready(handle.as_mut()));
        assert!(tx.is_closed());
    }

    #[tokio::test]
    async fn abort_complete() {
        let (tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let mut handle = Pin::new(&mut handle);
        tx.send(()).unwrap();
        assert!(!is_ready(handle.as_mut()));
        {
            let mut future = pin!(future);
            assert!(is_ready(future.as_mut()));
            assert!(!is_ready(handle.as_mut()));
        }
        assert!(is_ready(handle.as_mut()));
        handle.abort();
        assert!(is_ready(handle.as_mut()));
    }

    #[tokio::test]
    async fn abort_dropped() {
        let (tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let mut handle = Pin::new(&mut handle);
        drop(future);
        assert!(is_ready(handle.as_mut()));
        handle.abort();
        assert!(is_ready(handle.as_mut()));
        assert!(tx.is_closed());
    }

    #[tokio::test]
    async fn await_completion() {
        let (tx, rx) = oneshot::channel::<()>();
        tx.send(()).unwrap();
        let (handle, future) = JoinHandle::run(rx);
        let task = tokio::task::spawn(future);
        handle.await;
        task.await.unwrap();
    }

    #[tokio::test]
    async fn await_abort() {
        let (tx, rx) = oneshot::channel::<()>();
        tx.send(()).unwrap();
        let (handle, future) = JoinHandle::run(rx);
        handle.abort();
        let task = tokio::task::spawn(future);
        handle.await;
        task.await.unwrap();
    }

    #[tokio::test]
    async fn abort_wakes_pending_future() {
        // Keep the sender alive so the wrapped receiver stays pending.
        let (_tx, rx) = oneshot::channel::<()>();
        let (handle, future) = JoinHandle::run(rx);
        let mut future = pin!(future);
        let counter = Arc::new(CountingWaker::default());
        let waker = Waker::from(Arc::clone(&counter));
        assert!(
            future
                .as_mut()
                .poll(&mut Context::from_waker(&waker))
                .is_pending()
        );
        assert_eq!(counter.wakes(), 0);
        handle.abort();
        assert_eq!(counter.wakes(), 1);
        // A second abort is a no-op and must not wake the future again.
        handle.abort();
        assert_eq!(counter.wakes(), 1);
        assert!(is_ready(future.as_mut()));
    }

    #[tokio::test]
    async fn drop_wakes_waiting_handle() {
        let (_tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let counter = Arc::new(CountingWaker::default());
        let waker = Waker::from(Arc::clone(&counter));
        assert!(
            Pin::new(&mut handle)
                .poll(&mut Context::from_waker(&waker))
                .is_pending()
        );
        drop(future);
        assert_eq!(counter.wakes(), 1);
        assert!(is_ready(Pin::new(&mut handle)));
    }
}