use std::any::Any;
use std::cell::UnsafeCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
// Task lifecycle states stored in `TaskHeader::state`. Only their
// distinctness matters; the concrete numbers carry no meaning.

/// Not being polled; written after a poll of the future returns `Poll::Pending`.
pub(crate) const STATE_IDLE: u32 = 0;
/// Initial state of a newly created task; like `STATE_IDLE`, still abortable.
pub(crate) const STATE_SCHEDULED: u32 = 1;
/// The task's future is being polled right now.
pub(crate) const STATE_RUNNING: u32 = 2;
/// The future returned `Poll::Ready`; its output sits in `TaskHeader::output`.
pub(crate) const STATE_COMPLETED: u32 = 3;
/// The task was cancelled or aborted before producing output.
pub(crate) const STATE_CANCELLED: u32 = 4;
/// Reason a [`JoinHandle`] resolved without yielding the task's output.
#[derive(Debug)]
pub enum JoinError {
    /// The task was cancelled or aborted. Also reported when the output slot
    /// is empty or holds a value of an unexpected type.
    Cancelled,
    /// The task panicked; the box holds a payload of arbitrary type
    /// (presumably the value caught from the unwind — confirm at the call site).
    Panic(Box<dyn Any + Send + 'static>),
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Cancelled => "task was cancelled",
            Self::Panic(_) => "task panicked",
        };
        f.write_str(message)
    }
}

impl std::error::Error for JoinError {}
/// Monomorphised operations for one concrete future type.
///
/// A task keeps its future behind an untyped `*mut ()`; these function
/// pointers (built by `make_vtable`) know the real type behind it.
pub(crate) struct TaskVtable {
    /// Polls the body once. Returns `true` once the future has finished and its
    /// output has been written into the header's output slot.
    pub poll: unsafe fn(*mut (), &TaskHeader, &mut Context<'_>) -> bool,
    /// Reclaims the body allocation, dropping the future inside it.
    pub drop_body: unsafe fn(*mut ()),
}
/// Heap holder for a task's future, handled as `*mut ()` by the vtable functions.
struct TaskBody<F> {
    future: Pin<Box<F>>,
}

/// Polls the `TaskBody<F>` behind `body_ptr` exactly once.
///
/// On `Poll::Ready`, the value is boxed as `dyn Any + Send`, written into
/// `header.output` (dropping any previous contents), and `true` is returned.
/// On `Poll::Pending`, nothing is written and `false` is returned.
///
/// # Safety
/// `body_ptr` must point to a live `TaskBody<F>` that nobody else is accessing,
/// and no other reference into `header.output` may be live during the call.
unsafe fn body_poll<F, T>(body_ptr: *mut (), header: &TaskHeader, cx: &mut Context<'_>) -> bool
where
    F: Future<Output = T>,
    T: Send + 'static,
{
    // SAFETY: the caller guarantees a live, exclusively accessed `TaskBody<F>`.
    let body = unsafe { &mut *body_ptr.cast::<TaskBody<F>>() };
    if let Poll::Ready(value) = body.future.as_mut().poll(cx) {
        let erased: Box<dyn Any + Send> = Box::new(value);
        // SAFETY: the caller guarantees exclusive access to the output slot.
        unsafe { *header.output.get() = Some(erased) };
        true
    } else {
        false
    }
}

/// Rebuilds the `Box<TaskBody<F>>` behind `ptr` and drops it (future included).
///
/// # Safety
/// `ptr` must come from `Box::into_raw` on a `Box<TaskBody<F>>` and must not be
/// used again afterwards.
unsafe fn body_drop<F>(ptr: *mut ()) {
    // SAFETY: upheld by the caller per the contract above.
    let owned: Box<TaskBody<F>> = unsafe { Box::from_raw(ptr.cast()) };
    drop(owned);
}
/// Returns the `'static` vtable of type-erased operations for futures of type `F`
/// producing `T`.
///
/// The struct literal holds only function pointers, so it is promoted to a
/// `'static` constant per monomorphisation; no allocation occurs.
fn make_vtable<F, T>() -> &'static TaskVtable
where
    F: Future<Output = T>,
    T: Send + 'static,
{
    &TaskVtable {
        drop_body: body_drop::<F>,
        poll: body_poll::<F, T>,
    }
}
/// Shared, type-erased state of one task, reachable through an `Arc` from both
/// the executor-side `Task` and the user-side `JoinHandle`.
pub(crate) struct TaskHeader {
    /// Lifecycle state; one of the `STATE_*` constants.
    pub state: AtomicU32,
    /// Operations that know the concrete future type behind `body_ptr`.
    pub vtable: &'static TaskVtable,
    /// Waker registered by the party awaiting the `JoinHandle`.
    pub join_waker: Mutex<Option<Waker>>,
    /// The boxed future body, or null once it has been dropped.
    pub body_ptr: UnsafeCell<*mut ()>,
    /// Output written by the vtable's `poll` on completion; taken by the `JoinHandle`.
    pub output: UnsafeCell<Option<Box<dyn Any + Send>>>,
}

// SAFETY: `state` is atomic and `join_waker` is mutex-protected; the
// `UnsafeCell` fields are only touched by the task side while polling and by
// the `JoinHandle` after it has observed `STATE_COMPLETED` with `Acquire`.
// NOTE(review): the future type is not required to be `Send`, so these impls
// are only sound if a task whose future is `!Send` never crosses threads —
// confirm against the executor.
unsafe impl Send for TaskHeader {}
unsafe impl Sync for TaskHeader {}

impl Drop for TaskHeader {
    /// Frees the future if the task was dropped while its body was still alive
    /// (never completed and never cancelled); without this the body allocation
    /// would leak.
    fn drop(&mut self) {
        let body_ptr = std::mem::replace(self.body_ptr.get_mut(), std::ptr::null_mut());
        if !body_ptr.is_null() {
            // SAFETY: a non-null `body_ptr` always refers to the live body created
            // for this header's vtable (it is nulled whenever the body is freed),
            // and `&mut self` rules out any concurrent access.
            unsafe { (self.vtable.drop_body)(body_ptr) };
        }
    }
}
pub(crate) struct Task {
pub(crate) header: Arc<TaskHeader>,
}
impl Task {
pub(crate) fn new<F, T>(future: F) -> (Task, JoinHandle<T>)
where
F: Future<Output = T> + 'static,
T: Send + 'static,
{
let body: Box<TaskBody<F>> = Box::new(TaskBody {
future: Box::pin(future),
});
let body_ptr = Box::into_raw(body) as *mut ();
let header = Arc::new(TaskHeader {
state: AtomicU32::new(STATE_SCHEDULED),
vtable: make_vtable::<F, T>(),
join_waker: Mutex::new(None),
body_ptr: UnsafeCell::new(body_ptr),
output: UnsafeCell::new(None),
});
let join_arc = Arc::clone(&header);
let task = Task { header };
let jh = JoinHandle {
header: join_arc,
_marker: std::marker::PhantomData,
};
(task, jh)
}
pub(crate) fn poll_task(&self, cx: &mut Context<'_>) -> bool {
let h = &self.header;
h.state.store(STATE_RUNNING, Ordering::Release);
let body_ptr = unsafe { *h.body_ptr.get() };
debug_assert!(!body_ptr.is_null(), "poll_task called on freed body");
let completed = unsafe { (h.vtable.poll)(body_ptr, h, cx) };
if completed {
unsafe {
(h.vtable.drop_body)(body_ptr);
*h.body_ptr.get() = std::ptr::null_mut();
}
h.state.store(STATE_COMPLETED, Ordering::Release);
let waker = h.join_waker.lock().unwrap().take();
if let Some(w) = waker {
w.wake();
}
} else {
h.state.store(STATE_IDLE, Ordering::Release);
}
completed
}
pub(crate) fn cancel(self) {
let h = &self.header;
let body_ptr = unsafe { *h.body_ptr.get() };
if !body_ptr.is_null() {
unsafe {
(h.vtable.drop_body)(body_ptr);
*h.body_ptr.get() = std::ptr::null_mut();
}
}
h.state.store(STATE_CANCELLED, Ordering::Release);
let waker = h.join_waker.lock().unwrap().take();
if let Some(w) = waker {
w.wake();
}
}
}
pub struct JoinHandle<T> {
pub(crate) header: Arc<TaskHeader>,
_marker: std::marker::PhantomData<T>,
}
impl<T: Send + 'static> JoinHandle<T> {
pub fn abort(&self) {
let _ = self.header.state.compare_exchange(
STATE_IDLE,
STATE_CANCELLED,
Ordering::AcqRel,
Ordering::Relaxed,
);
let _ = self.header.state.compare_exchange(
STATE_SCHEDULED,
STATE_CANCELLED,
Ordering::AcqRel,
Ordering::Relaxed,
);
}
}
impl<T: Send + 'static> Future for JoinHandle<T> {
    type Output = Result<T, JoinError>;

    /// Resolves once the task reaches `STATE_COMPLETED` (yielding its output)
    /// or `STATE_CANCELLED` (yielding `JoinError::Cancelled`); otherwise stores
    /// the caller's waker in `join_waker` and returns `Poll::Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let finished = |s: u32| s == STATE_COMPLETED || s == STATE_CANCELLED;

        // Fast path without the lock; if not finished, re-check while holding
        // the waker lock so a concurrent completion cannot slip between the
        // check and the waker registration.
        let mut observed = self.header.state.load(Ordering::Acquire);
        if !finished(observed) {
            let mut slot = self.header.join_waker.lock().unwrap();
            observed = self.header.state.load(Ordering::Acquire);
            if !finished(observed) {
                *slot = Some(cx.waker().clone());
                return Poll::Pending;
            }
            // `slot` is released here, before the output is touched.
        }

        if observed == STATE_COMPLETED {
            self.take_output()
        } else {
            Poll::Ready(Err(JoinError::Cancelled))
        }
    }
}
impl<T: Send + 'static> JoinHandle<T> {
    /// Moves the finished task's output out of the shared slot.
    ///
    /// Yields `Err(JoinError::Cancelled)` when the slot is empty (for instance
    /// because the output was already taken) or holds a value that is not a `T`.
    fn take_output(self: Pin<&mut Self>) -> Poll<Result<T, JoinError>> {
        // SAFETY: called only after `STATE_COMPLETED` was observed with Acquire,
        // so the task side has finished writing the slot.
        let taken = unsafe { (*self.header.output.get()).take() };
        let outcome = taken
            .and_then(|erased| erased.downcast::<T>().ok())
            .map(|boxed| *boxed)
            .ok_or(JoinError::Cancelled);
        Poll::Ready(outcome)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize};
    use std::task::Wake;

    /// Waker that counts how many times it has been woken.
    #[derive(Default)]
    struct CountingWaker(AtomicUsize);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Future that is pending on its first poll and ready with `5` afterwards.
    struct YieldOnce(bool);

    impl Future for YieldOnce {
        type Output = u32;
        fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<u32> {
            if self.0 {
                Poll::Ready(5)
            } else {
                self.0 = true;
                Poll::Pending
            }
        }
    }

    #[test]
    fn task_new_initial_state() {
        let (task, _jh) = Task::new(async { 42u32 });
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_SCHEDULED);
    }

    #[test]
    fn join_error_display() {
        assert_eq!(JoinError::Cancelled.to_string(), "task was cancelled");
        assert!(JoinError::Panic(Box::new("x"))
            .to_string()
            .contains("panicked"));
    }

    #[test]
    fn abort_from_idle_sets_cancelled() {
        let (task, jh) = Task::new(async { 1u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        jh.abort();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_CANCELLED);
    }

    #[test]
    fn cancel_drops_future() {
        let dropped = Arc::new(AtomicBool::new(false));
        let d = dropped.clone();
        struct Bomb(Arc<AtomicBool>);
        impl Drop for Bomb {
            fn drop(&mut self) {
                self.0.store(true, Ordering::SeqCst);
            }
        }
        impl Future for Bomb {
            type Output = ();
            fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<()> {
                Poll::Pending
            }
        }
        let (task, _jh) = Task::new(Bomb(d));
        task.cancel();
        assert!(
            dropped.load(Ordering::SeqCst),
            "future must be dropped on cancel"
        );
    }

    #[test]
    fn join_error_panic_display() {
        let err = JoinError::Panic(Box::new("boom"));
        let s = err.to_string();
        assert!(s.contains("panic"));
    }

    #[test]
    fn join_error_cancelled_display() {
        let err = JoinError::Cancelled;
        let s = err.to_string();
        assert!(s.contains("cancel") || s.contains("Cancel"));
    }

    #[test]
    fn abort_from_scheduled_sets_cancelled() {
        let (_task, jh) = Task::new(async { 1u32 });
        jh.abort();
        assert_eq!(jh.header.state.load(Ordering::Acquire), STATE_CANCELLED);
    }

    #[test]
    fn task_header_initial_state_is_scheduled() {
        let (task, _jh) = Task::new(async { 0u8 });
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_SCHEDULED);
    }

    #[test]
    fn cancel_sets_state_to_cancelled() {
        let (task, jh) = Task::new(async { 0u8 });
        task.cancel();
        assert_eq!(jh.header.state.load(Ordering::Acquire), STATE_CANCELLED);
    }

    #[test]
    fn abort_completed_task_has_no_effect() {
        let (task, jh) = Task::new(async { 99u32 });
        task.header.state.store(STATE_COMPLETED, Ordering::Release);
        jh.abort();
        assert_eq!(jh.header.state.load(Ordering::Acquire), STATE_COMPLETED);
    }

    #[test]
    fn state_constants_distinct() {
        let states = [
            STATE_IDLE,
            STATE_SCHEDULED,
            STATE_RUNNING,
            STATE_COMPLETED,
            STATE_CANCELLED,
        ];
        let unique: std::collections::HashSet<u32> = states.iter().cloned().collect();
        assert_eq!(unique.len(), states.len());
    }

    #[test]
    fn join_error_debug_format() {
        let err = JoinError::Cancelled;
        let s = format!("{err:?}");
        assert!(!s.is_empty());
    }

    #[test]
    fn task_new_creates_join_handle_with_same_header() {
        let (task, jh) = Task::new(async { 0u32 });
        assert!(Arc::ptr_eq(&task.header, &jh.header));
    }

    #[test]
    fn abort_from_idle_state_succeeds() {
        let (task, jh) = Task::new(async { 0u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        jh.abort();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_CANCELLED);
    }

    #[test]
    fn multiple_aborts_are_idempotent() {
        let (_task, jh) = Task::new(async { 0u32 });
        jh.abort();
        jh.abort();
        jh.abort();
        assert_eq!(jh.header.state.load(Ordering::Acquire), STATE_CANCELLED);
    }

    #[test]
    fn completed_task_output_reaches_join_handle() {
        let counter = Arc::new(CountingWaker::default());
        let waker = Waker::from(Arc::clone(&counter));
        let mut cx = Context::from_waker(&waker);
        let (task, mut jh) = Task::new(YieldOnce(false));

        assert!(!task.poll_task(&mut cx));
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_IDLE);
        assert!(Pin::new(&mut jh).poll(&mut cx).is_pending());

        assert!(task.poll_task(&mut cx));
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_COMPLETED);
        assert_eq!(
            counter.0.load(Ordering::SeqCst),
            1,
            "join waker must be woken on completion"
        );
        assert!(matches!(Pin::new(&mut jh).poll(&mut cx), Poll::Ready(Ok(5))));
    }

    #[test]
    fn cancel_wakes_and_resolves_join_handle() {
        let counter = Arc::new(CountingWaker::default());
        let waker = Waker::from(Arc::clone(&counter));
        let mut cx = Context::from_waker(&waker);
        let (task, mut jh) = Task::new(std::future::pending::<u32>());

        assert!(Pin::new(&mut jh).poll(&mut cx).is_pending());
        task.cancel();
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);
        assert!(matches!(
            Pin::new(&mut jh).poll(&mut cx),
            Poll::Ready(Err(JoinError::Cancelled))
        ));
    }

    #[test]
    fn abort_resolves_join_handle_as_cancelled() {
        let (_task, mut jh) = Task::new(std::future::pending::<u32>());
        jh.abort();
        let waker = Waker::from(Arc::new(CountingWaker::default()));
        let mut cx = Context::from_waker(&waker);
        assert!(matches!(
            Pin::new(&mut jh).poll(&mut cx),
            Poll::Ready(Err(JoinError::Cancelled))
        ));
    }
}