#![allow(unsafe_op_in_unsafe_fn)]
use std::any::Any;
use std::future::Future;
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use std::task::{Context, Poll, Wake, Waker};
use atomic_waker::AtomicWaker;
use parking_lot::Mutex;
use polars_error::signals::try_raise_keyboard_interrupt;
/// Lock-free scheduling state of a task, packed into a single byte.
///
/// Transitions:
/// - `wake`: `IDLE -> SCHEDULED` (the caller must then schedule the task),
///   `RUNNING -> NOTIFIED_WHILE_RUNNING`, otherwise unchanged.
/// - `start_running`: `SCHEDULED -> RUNNING`.
/// - `reschedule_after_running`: `RUNNING -> IDLE` or
///   `NOTIFIED_WHILE_RUNNING -> SCHEDULED`.
#[derive(Default)]
struct TaskState {
    state: AtomicU8,
}

impl TaskState {
    /// Neither scheduled nor running; this is the `Default` (zero) state.
    const IDLE: u8 = 0;
    /// Handed to the scheduler and waiting to be run.
    const SCHEDULED: u8 = 1;
    /// Currently being polled.
    const RUNNING: u8 = 2;
    /// Woken while being polled; must be rescheduled once the poll ends.
    const NOTIFIED_WHILE_RUNNING: u8 = 3;

    /// Records a wakeup.
    ///
    /// Returns `true` only when this call moved the task from idle to
    /// scheduled, in which case the caller is responsible for scheduling it.
    fn wake(&self) -> bool {
        let transition = self
            .state
            .fetch_update(Ordering::Release, Ordering::Relaxed, |current| {
                match current {
                    Self::IDLE => Some(Self::SCHEDULED),
                    Self::RUNNING => Some(Self::NOTIFIED_WHILE_RUNNING),
                    // A run is already pending and will observe this wakeup.
                    Self::SCHEDULED | Self::NOTIFIED_WHILE_RUNNING => None,
                    _ => unreachable!("invalid TaskState"),
                }
            });
        matches!(transition, Ok(Self::IDLE))
    }

    /// Marks a scheduled task as running.
    ///
    /// A plain load + store suffices: while the state is `SCHEDULED`, `wake`
    /// never writes to it, so only the runner can change it here.
    ///
    /// # Panics
    /// Panics if the task is not currently `SCHEDULED`.
    fn start_running(&self) {
        let observed = self.state.load(Ordering::Acquire);
        assert_eq!(observed, Self::SCHEDULED);
        self.state.store(Self::RUNNING, Ordering::Relaxed);
    }

    /// Leaves the running state after a poll returned `Pending`.
    ///
    /// Returns `true` if the task was woken during the poll and is now
    /// `SCHEDULED` again, meaning the caller must schedule it.
    ///
    /// # Panics
    /// Panics if the task was not running.
    fn reschedule_after_running(&self) -> bool {
        let transition = self
            .state
            .fetch_update(Ordering::Release, Ordering::Relaxed, |current| {
                match current {
                    Self::RUNNING => Some(Self::IDLE),
                    Self::NOTIFIED_WHILE_RUNNING => Some(Self::SCHEDULED),
                    _ => panic!("TaskState::reschedule_after_running() called on invalid state"),
                }
            });
        matches!(transition, Ok(Self::NOTIFIED_WHILE_RUNNING))
    }
}
/// Lifecycle of a task's future and its outcome, guarded by `Task::data`.
enum TaskData<F: Future> {
    /// Placeholder used only while `Task::spawn` is building the task.
    Empty,
    /// Not finished yet: the future plus the `Waker` used to poll it. That
    /// waker holds a strong reference back to the task itself.
    Polling(F, Waker),
    /// Completed; the output waits to be taken by `poll_join`.
    Ready(F::Output),
    /// The future panicked; the caught payload is re-raised by `poll_join`.
    Panic(Box<dyn Any + Send + 'static>),
    /// Cancelled via `Cancellable::cancel` before the outcome was joined.
    Cancelled,
    /// The outcome has been consumed by `poll_join`.
    Joined,
}
/// A spawned future together with its scheduling state and join machinery.
///
/// Constructed through `Task::spawn`, which places it in an `Arc`; the stored
/// future is polled in place inside that allocation.
// NOTE: fields drop in declaration order; keep that in mind before reordering.
struct Task<F: Future, S, M> {
    /// Lock-free scheduling state (idle / scheduled / running / notified).
    state: TaskState,
    /// The future or its outcome. `run` holds this lock for the entire poll.
    data: Mutex<TaskData<F>>,
    /// Waker of whoever awaits the result through `poll_join`.
    join_waker: AtomicWaker,
    /// Invoked with the type-erased task whenever it must be (re)run.
    schedule: S,
    /// Caller-provided data, exposed through `Runnable::metadata`.
    metadata: M,
}
impl<'a, F, S, M> Task<F, S, M>
where
    F: Future + Send + 'a,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    /// Allocates a new task around `future`.
    ///
    /// The task starts in the `IDLE` state; this function does not schedule it.
    ///
    /// The waker stored next to the future holds a strong reference to the
    /// task itself. That reference cycle lasts until the `Polling` data is
    /// overwritten (on completion, panic, or cancellation).
    ///
    /// # Safety
    /// `future` may borrow data for `'a`, but `into_dyn` erases that lifetime.
    /// The caller must ensure the task is polled to completion or cancelled
    /// before `'a` ends. Both paths overwrite the `Polling` data and so drop
    /// the future.
    unsafe fn spawn(future: F, schedule: S, metadata: M) -> Arc<Self> {
        let task = Arc::new(Self {
            state: TaskState::default(),
            data: Mutex::new(TaskData::Empty),
            join_waker: AtomicWaker::new(),
            schedule,
            metadata,
        });
        // SAFETY: `raw_waker` takes ownership of this `Arc` clone and builds a
        // vtable for exactly this `Task` type.
        let waker = unsafe { Waker::from_raw(std_shim::raw_waker(task.clone())) };
        // The waker has not been used yet, so nobody else can hold the lock.
        *task.data.try_lock().unwrap() = TaskData::Polling(future, waker);
        task
    }

    /// Converts the task into a type-erased `'static` handle.
    fn into_dyn(self: Arc<Self>) -> Arc<dyn DynTask<F::Output, M>> {
        let arc: Arc<dyn DynTask<F::Output, M> + 'a> = self;
        // SAFETY: only the trait object's lifetime bound changes (`'a` ->
        // `'static`), not its layout; soundness relies on the contract of
        // `Task::spawn`.
        let arc: Arc<dyn DynTask<F::Output, M>> = unsafe { std::mem::transmute(arc) };
        arc
    }
}
impl<F, S, M> Wake for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    /// Entry point of the task's `Waker`.
    ///
    /// Hands the task to its scheduler only if this wakeup moved it out of
    /// the idle state.
    fn wake(self: Arc<Self>) {
        if !self.state.wake() {
            // Already scheduled or running: that pending run absorbs this wakeup.
            return;
        }
        // Copy the scheduler out first so `self` can be moved into it.
        let scheduler = self.schedule;
        scheduler(self.into_dyn());
    }

    /// Same as `wake`, on a fresh strong reference.
    fn wake_by_ref(self: &Arc<Self>) {
        <Self as Wake>::wake(Arc::clone(self));
    }
}
/// Type-erased handle to a spawned task: it can be run, joined, and cancelled.
pub trait DynTask<T, M>: Send + Sync + Runnable<M> + Joinable<T> + Cancellable {}

// Every `Task` is a `DynTask`; the supertraits are implemented individually
// in this module.
impl<F, S, M> DynTask<F::Output, M> for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
}
/// The scheduler-facing side of a task.
pub trait Runnable<M>: Send + Sync {
    /// Returns the metadata attached when the task was spawned.
    fn metadata(&self) -> &M;
    /// Polls the task once. Returns `true` if the task is finished and should
    /// not be run again, `false` if it is still pending.
    fn run(self: Arc<Self>) -> bool;
    /// Wakes the task. For `Task`, it is handed to its scheduler only if it
    /// was idle.
    fn schedule(self: Arc<Self>);
}
impl<F, S, M> Runnable<M> for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    fn metadata(&self) -> &M {
        &self.metadata
    }

    /// Polls the future once.
    ///
    /// Returns `true` when the task is finished (it completed, panicked, or
    /// was cancelled) and `false` when the future is still pending. A pending
    /// task that was woken during the poll is handed back to its scheduler
    /// right away.
    ///
    /// # Panics
    /// Panics if the data is `Polling` but the state is not `SCHEDULED`, or if
    /// the data is in a state that should never be run (e.g. already
    /// finished). Panics raised by the future itself are caught and stored
    /// for the joiner instead.
    fn run(self: Arc<Self>) -> bool {
        let mut data = self.data.lock();
        let poll_result = match &mut *data {
            TaskData::Polling(future, waker) => {
                self.state.start_running();
                // SAFETY: the task lives in an `Arc` allocation and the
                // `Polling` future is never moved out; it is only dropped in
                // place when `data` is overwritten.
                let fut = unsafe { Pin::new_unchecked(future) };
                let mut ctx = Context::from_waker(waker);
                catch_unwind(AssertUnwindSafe(|| {
                    // NOTE(review): presumably panics if an interrupt is
                    // pending, which is then caught like any other panic.
                    try_raise_keyboard_interrupt();
                    fut.poll(&mut ctx)
                }))
            },
            // Cancelled while it was waiting to run: nothing left to poll.
            TaskData::Cancelled => return true,
            _ => unreachable!("invalid TaskData when polling"),
        };

        *data = match poll_result {
            Err(error) => TaskData::Panic(error),
            Ok(Poll::Ready(output)) => TaskData::Ready(output),
            Ok(Poll::Pending) => {
                // Release the lock before possibly rescheduling, so the next
                // run of this task can acquire it.
                drop(data);
                if self.state.reschedule_after_running() {
                    let schedule = self.schedule;
                    (schedule)(self.into_dyn());
                }
                return false;
            },
        };

        // Finished: the state intentionally stays RUNNING/NOTIFIED_WHILE_RUNNING,
        // so later wakeups never schedule this task again. Wake the joiner
        // only after the lock is released, because `poll_join` uses `try_lock`.
        drop(data);
        self.join_waker.wake();
        true
    }

    /// Requests a run of this task. The scheduler is invoked only when the
    /// task was idle; otherwise a pending run absorbs the wakeup.
    fn schedule(self: Arc<Self>) {
        if self.state.wake() {
            // Copy the `Copy` scheduler out so `self` can be moved into it,
            // instead of paying for an extra `Arc` clone.
            let schedule = self.schedule;
            (schedule)(self.into_dyn());
        }
    }
}
/// The joiner-facing side of a task: retrieves its output.
pub trait Joinable<T>: Send + Sync + Cancellable {
    /// Polls for the task's output, registering `ctx`'s waker to be notified
    /// when it becomes available.
    fn poll_join(&self, ctx: &mut Context<'_>) -> Poll<T>;
}
impl<F, S, M> Joinable<F::Output> for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    /// Polls for the task's outcome without blocking.
    ///
    /// The caller's waker is registered before anything is inspected.
    /// `Pending` is returned if the data lock is contended or the future has
    /// not finished yet. On a finished task the outcome is moved out and
    /// replaced by `Joined`.
    ///
    /// # Panics
    /// Resumes the task's panic if it panicked. Also panics if the task was
    /// cancelled, or if it is polled again after the outcome was taken.
    fn poll_join(&self, cx: &mut Context<'_>) -> Poll<F::Output> {
        self.join_waker.register(cx.waker());

        let Some(mut guard) = self.data.try_lock() else {
            return Poll::Pending;
        };
        if let TaskData::Empty | TaskData::Polling(..) = &*guard {
            return Poll::Pending;
        }

        let outcome = core::mem::replace(&mut *guard, TaskData::Joined);
        match outcome {
            TaskData::Ready(output) => Poll::Ready(output),
            TaskData::Panic(error) => resume_unwind(error),
            TaskData::Cancelled => panic!("joined on cancelled task"),
            _ => unreachable!("invalid TaskData when joining"),
        }
    }
}
/// Ability to abort a task that has not yet delivered its outcome.
pub trait Cancellable: Send + Sync {
    /// Cancels the task; see the `Task` implementation for which states are
    /// left untouched.
    fn cancel(&self);
}
impl<F, S, M> Cancellable for Task<F, S, M>
where
    F: Future + Send,
    F::Output: Send + 'static,
    S: Send + Sync + 'static,
    M: Send + Sync + 'static,
{
    /// Cancels the task unless it already panicked or was joined.
    ///
    /// Any in-progress future (or not-yet-joined output) is dropped while the
    /// data lock is held. So by the time anyone else can observe
    /// `TaskData::Cancelled`, the future is gone.
    ///
    /// The join waker is woken only after the lock is released. `poll_join`
    /// uses `try_lock` and reports `Pending` on contention, so waking while
    /// still holding the lock could let the joiner re-poll, lose the race,
    /// and never be woken again.
    fn cancel(&self) {
        let mut data = self.data.lock();
        if matches!(*data, TaskData::Panic(_) | TaskData::Joined) {
            // Already done: the outcome must not be overwritten.
            return;
        }
        // Drops the previous data (future, waker, or output) under the lock.
        *data = TaskData::Cancelled;
        drop(data);
        self.join_waker.wake();
    }
}
/// Creates a task around a `'static` future and returns a type-erased handle.
///
/// The task starts idle; it is not handed to `schedule` by this function.
pub fn spawn<F, S, M>(future: F, schedule: S, metadata: M) -> Arc<dyn DynTask<F::Output, M>>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    // SAFETY: `F: 'static`, so erasing the future's lifetime cannot let the
    // task outlive anything the future borrows.
    let task = unsafe { Task::spawn(future, schedule, metadata) };
    task.into_dyn()
}
/// Creates a task around a future that may borrow data for `'a` and returns a
/// type-erased `'static` handle.
///
/// The task starts idle; it is not handed to `schedule` by this function.
///
/// # Safety
/// The returned handle erases `'a`. The caller must ensure the task is polled
/// to completion or cancelled (either of which drops the future) before `'a`
/// ends.
pub unsafe fn spawn_with_lifetime<'a, F, S, M>(
    future: F,
    schedule: S,
    metadata: M,
) -> Arc<dyn DynTask<F::Output, M>>
where
    F: Future + Send + 'a,
    F::Output: Send + 'static,
    S: Fn(Arc<dyn Runnable<M>>) + Send + Sync + Copy + 'static,
    M: Send + Sync + 'static,
{
    // SAFETY: forwarded to our caller, see the `# Safety` section above.
    let task = unsafe { Task::spawn(future, schedule, metadata) };
    task.into_dyn()
}
/// Hand-rolled counterpart of std's `Arc<W>` -> `Waker` conversion.
///
/// NOTE(review): std's version requires `W: 'static`. This copy accepts any
/// `'a`, presumably so tasks wrapping non-`'static` futures can build a
/// `Waker`. Confirm against the std source when updating.
mod std_shim {
    use std::mem::ManuallyDrop;
    use std::sync::Arc;
    use std::task::{RawWaker, RawWakerVTable, Wake};

    /// Builds a `RawWaker` that takes ownership of one strong reference of
    /// `waker`.
    ///
    /// # Safety
    /// The resulting waker does not carry `'a`. The caller must not let it
    /// (or its clones) be woken after data borrowed for `'a` is gone.
    #[inline(always)]
    pub unsafe fn raw_waker<'a, W: Wake + Send + Sync + 'a>(waker: Arc<W>) -> RawWaker {
        // Clone: add a strong reference and reuse the same data pointer.
        // NOTE(review): `#[inline(always)]` presumably mirrors std, where it
        // keeps the promoted vtable identical to the one built below.
        #[inline(always)]
        unsafe fn clone_waker<W: Wake + Send + Sync>(waker: *const ()) -> RawWaker {
            unsafe { Arc::increment_strong_count(waker as *const W) };
            RawWaker::new(
                waker,
                &RawWakerVTable::new(
                    clone_waker::<W>,
                    wake::<W>,
                    wake_by_ref::<W>,
                    drop_waker::<W>,
                ),
            )
        }

        // Wake by value: reclaim the owned `Arc` and hand it to `Wake::wake`.
        unsafe fn wake<W: Wake + Send + Sync>(waker: *const ()) {
            let waker = unsafe { Arc::from_raw(waker as *const W) };
            <W as Wake>::wake(waker);
        }

        // Wake by reference: borrow the `Arc` without releasing its count.
        unsafe fn wake_by_ref<W: Wake + Send + Sync>(waker: *const ()) {
            let waker = unsafe { ManuallyDrop::new(Arc::from_raw(waker as *const W)) };
            <W as Wake>::wake_by_ref(&waker);
        }

        // Drop: release the strong reference owned by this waker.
        unsafe fn drop_waker<W: Wake + Send + Sync>(waker: *const ()) {
            unsafe { Arc::decrement_strong_count(waker as *const W) };
        }

        // The data pointer is the `Arc` turned into a raw pointer; ownership of
        // its strong count moves into the returned `RawWaker`.
        RawWaker::new(
            Arc::into_raw(waker) as *const (),
            &RawWakerVTable::new(
                clone_waker::<W>,
                wake::<W>,
                wake_by_ref::<W>,
                drop_waker::<W>,
            ),
        )
    }
}