use futures::channel::oneshot;
use futures::channel::oneshot::{Receiver, Sender};
use std::boxed::Box;
use std::cell::{Cell, RefCell};
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::rc::Rc;
use std::task::Poll;
/// Wraps `future` into a type-erased [`Task`] plus a [`JoinHandle`] for its output.
///
/// The returned `Task` is itself inert: it must be handed to an executor and polled.
/// When the wrapped future completes, its output is sent over a oneshot channel and
/// becomes available through the `JoinHandle`. Both halves share an abort flag, so
/// `JoinHandle::abort` makes the task finish (without polling `future` again) the next
/// time it is polled.
pub fn create_task<F>(future: F) -> (Task, JoinHandle<F::Output>)
where
    F: Future + 'static,
{
    let (sender, receiver) = oneshot::channel::<F::Output>();
    let shared_abort = Rc::new(Cell::new(false));

    let task = Task::from(GenericTask {
        future: Box::pin(future),
        output_tx: Some(sender),
        abort: Rc::clone(&shared_abort),
    });
    let handle = JoinHandle(RefCell::new(JoinHandleInner::Pending {
        output_rx: Box::pin(receiver),
        abort: shared_abort,
    }));

    (task, handle)
}
/// A type-erased, heap-allocated unit of work that produces no output.
///
/// The boxed future carries no `Send` bound, so a `Task` must be driven by a
/// single-threaded (local) executor. It dereferences to the underlying pinned
/// boxed future so an executor can poll it directly.
pub struct Task(Pin<Box<dyn Future<Output = ()>>>);

impl Deref for Task {
    type Target = Pin<Box<dyn Future<Output = ()>>>;

    fn deref(&self) -> &Self::Target {
        let Task(inner) = self;
        inner
    }
}

impl DerefMut for Task {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Task(inner) = self;
        inner
    }
}

/// Test-only: lets a `Task` be passed straight to APIs that expect a `Future`
/// (e.g. `tokio::task::spawn_local`).
#[cfg(test)]
impl Future for Task {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        // `Task` only holds a `Pin<Box<_>>`, so it is `Unpin` and the pin can be unwrapped.
        self.get_mut().0.as_mut().poll(cx)
    }
}
impl<F> From<GenericTask<F>> for Task
where
F: Future + 'static,
{
fn from(generic_task: GenericTask<F>) -> Self {
Self(Box::pin(generic_task))
}
}
/// Adapter that drives a user future and forwards its output to the paired
/// `JoinHandle` over a oneshot channel, while honoring a shared abort flag.
///
/// Field order is kept as-is because it determines drop order.
struct GenericTask<F>
where
    F: Future + 'static,
{
    /// The wrapped user future. Boxing it keeps `GenericTask` `Unpin` regardless of
    /// `F`, which `poll` relies on to mutate fields through `Pin<&mut Self>`.
    future: Pin<Box<F>>,
    /// Sender half of the result channel; `Some` until the output has been sent.
    output_tx: Option<Sender<F::Output>>,
    /// Flag shared with the `JoinHandle`; once `true`, the next poll completes the
    /// task without polling `future` again.
    abort: Rc<Cell<bool>>,
}
impl<F> Future for GenericTask<F>
where
    F: Future + 'static,
{
    type Output = ();

    /// Completes immediately if the abort flag is set. Otherwise polls the wrapped
    /// future and, once it yields a value, sends it to the `JoinHandle` and completes.
    ///
    /// NOTE(review): the abort flag is only consulted when this task is polled; nothing
    /// here wakes the task when the flag changes.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        if this.abort.get() {
            return Poll::Ready(());
        }

        let output = match this.future.as_mut().poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(output) => output,
        };

        // The sender is present until the first completion; executors do not poll a
        // completed future again.
        let sender = this.output_tx.take().unwrap();
        // A send error only means the JoinHandle was dropped; the task still completes.
        let _ = sender.send(output);
        Poll::Ready(())
    }
}
/// Handle to the output of a task created by `create_task`.
///
/// Awaiting it yields `Some(output)` if the task completed, or `None` if it was
/// aborted or dropped before sending its output. The state sits in a `RefCell` so
/// that `abort`, `is_finished` and `is_aborted` can update it through `&self`.
pub struct JoinHandle<T>(RefCell<JoinHandleInner<T>>);
/// State machine behind a `JoinHandle`.
enum JoinHandleInner<T> {
    /// The task has not yet been observed to finish or to be cancelled.
    Pending {
        /// Receiver for the task's output.
        output_rx: Pin<Box<Receiver<T>>>,
        /// Flag shared with the task; setting it asks the task to stop at its next poll.
        abort: Rc<Cell<bool>>,
    },
    /// The task's output has been received. Holds `Some(value)` until the value is
    /// handed out by awaiting the handle, and `None` afterwards.
    Finished(Option<T>),
    /// The task was aborted through the handle, or was dropped without sending.
    Aborted,
}
impl<T> JoinHandle<T> {
    /// Non-blocking check of the result channel that advances the handle's state.
    ///
    /// While the handle is `Pending`, this does a `try_recv` on the output channel
    /// (no waker is registered):
    /// - a delivered value moves the handle to `Finished(Some(value))`;
    /// - a closed channel (the task was dropped without sending) moves it to `Aborted`;
    /// - otherwise the handle stays `Pending`.
    ///
    /// A handle that is already `Finished` or `Aborted` is left untouched.
    fn poll(&self) {
        let mut inner = self.0.borrow_mut();
        if let JoinHandleInner::Pending { output_rx, .. } = &mut *inner {
            match output_rx.try_recv() {
                Ok(Some(value)) => *inner = JoinHandleInner::Finished(Some(value)),
                // Nothing sent yet and the sender is still alive: the task is running.
                Ok(None) => {}
                // Sender dropped without sending: the task was cancelled.
                Err(_) => *inner = JoinHandleInner::Aborted,
            }
        }
    }

    /// Requests cancellation of the task.
    ///
    /// If the task has already sent its output, this is a no-op. The output is kept,
    /// `is_finished()` stays `true`, and awaiting the handle still yields `Some(value)`.
    /// Otherwise the shared abort flag is raised and the handle becomes `Aborted`, so
    /// awaiting it yields `None`.
    ///
    /// NOTE(review): raising the flag does not wake the task. The underlying future
    /// is only dropped once the executor polls the task again.
    pub fn abort(&self) {
        // Drain an output that may already be waiting in the channel. Without this, a
        // task that finished before anyone called `is_finished()`/`is_aborted()` would
        // have its output discarded and be misreported as aborted.
        self.poll();

        let mut inner = self.0.borrow_mut();
        if let JoinHandleInner::Pending { abort, .. } = &*inner {
            abort.set(true);
            *inner = JoinHandleInner::Aborted;
        }
    }

    /// Returns `true` once the task's output has been received.
    ///
    /// This stays `true` even after the value has been taken by awaiting the handle.
    pub fn is_finished(&self) -> bool {
        self.poll();
        matches!(&*self.0.borrow(), JoinHandleInner::Finished(_))
    }

    /// Returns `true` if the task was aborted through this handle, or was dropped
    /// without sending its output.
    pub fn is_aborted(&self) -> bool {
        self.poll();
        matches!(&*self.0.borrow(), JoinHandleInner::Aborted)
    }
}
impl<T> Future for JoinHandle<T> {
    type Output = Option<T>;

    /// Resolves to `Some(output)` when the task delivers its output, or `None` if the
    /// task was aborted or dropped without sending.
    ///
    /// Polling again after the output has been handed out resolves to `None`.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let mut state = self.0.borrow_mut();

        let received = match &mut *state {
            JoinHandleInner::Finished(slot) => return Poll::Ready(slot.take()),
            JoinHandleInner::Aborted => return Poll::Ready(None),
            JoinHandleInner::Pending { output_rx, .. } => {
                match Future::poll(output_rx.as_mut(), cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(result) => result,
                }
            }
        };

        // The value goes straight to the caller, so `Finished` keeps nothing behind.
        let (next_state, output) = match received {
            Ok(value) => (JoinHandleInner::Finished(None), Some(value)),
            Err(_) => (JoinHandleInner::Aborted, None),
        };
        *state = next_state;
        Poll::Ready(output)
    }
}
// End-to-end scenarios for `create_task` / `JoinHandle`. `Task` is not `Send`, so the
// tasks are spawned with `spawn_local` inside a tokio `LocalSet`.
// NOTE(review): the scenarios depend on wall-clock timing (50 ms tasks vs. 100 ms
// waits) and assume the runtime is not delayed by that much.
#[cfg(test)]
#[tokio::test]
async fn test() {
    use std::time::Duration;
    use tokio::task::LocalSet;
    use tokio::time;
    let local_set = LocalSet::new();
    local_set
        .run_until(async {
            // 1. A task that runs to completion: the handle is neither finished nor
            //    aborted right after spawning, and awaiting it yields the output.
            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(50)).await;
                "test"
            });
            tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            assert_eq!(join_handle.await, Some("test"));
            // 2. After the task has finished and `is_finished()` has observed its
            //    output, `abort()` changes nothing and the output can still be awaited.
            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(50)).await;
                "test"
            });
            tokio::task::spawn_local(task);
            time::sleep(Duration::from_millis(100)).await;
            assert!(join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            join_handle.abort();
            assert!(join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            assert_eq!(join_handle.await, Some("test"));
            // 3. Aborting through the handle before the task finishes: the handle
            //    reports aborted and awaiting yields `None`.
            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(50)).await;
                "test"
            });
            tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            join_handle.abort();
            assert!(!join_handle.is_finished());
            assert!(join_handle.is_aborted());
            assert_eq!(join_handle.await, None);
            // 4. Aborting the tokio task from the outside: once tokio drops the `Task`,
            //    its sender is dropped without sending and the handle reports aborted.
            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(500)).await;
                "test"
            });
            let tokio_join_handle = tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            tokio_join_handle.abort();
            time::sleep(Duration::from_millis(100)).await;
            assert!(!join_handle.is_finished());
            assert!(join_handle.is_aborted());
            assert_eq!(join_handle.await, None);
            // 5. Dropping the handle does not cancel the task: it keeps running and
            //    its side effect (setting `value` to 1) still happens.
            let value = Rc::new(Cell::new(0i32));
            let (task, join_handle) = create_task({
                let value = Rc::clone(&value);
                async move {
                    time::sleep(Duration::from_millis(50)).await;
                    value.set(1);
                    "test"
                }
            });
            tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            drop(join_handle);
            assert_eq!(value.get(), 0);
            time::sleep(Duration::from_millis(100)).await;
            assert_eq!(value.get(), 1);
        })
        .await;
    // Drive any tasks still left on the local set to completion.
    local_set.await;
}