use futures_core::future::{Future, LocalFutureObj};
use futures_util::task::AtomicWaker;
use std::cell::{Cell, UnsafeCell};
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Context, Poll};
use std::thread::{self, ThreadId};
fumio_utils::local_dl_list! {
mod loc_pending_list {
link TaskPendingLink;
head TaskPendingHead;
member local_pending_link of Task;
}
}
fumio_utils::local_dl_list! {
mod loc_list {
link TaskLink;
head TaskHead;
member local_link of Task;
}
}
fumio_utils::mpsc! {
mod mpsc_list {
link GlobalTaskListLink;
head GlobalTaskListHead;
member global_pending_next of Task;
}
}
#[derive(Debug)]
struct TaskList {
local_all: TaskHead,
local_pending: TaskPendingHead,
global_pending: GlobalTaskListHead,
local_thread: ThreadId,
waker: AtomicWaker,
}
unsafe impl Send for TaskList {}
unsafe impl Sync for TaskList {}
impl TaskList {
fn new() -> Self {
Self {
local_all: TaskHead::new(),
local_pending: TaskPendingHead::new(),
global_pending: GlobalTaskListHead::new(),
local_thread: thread::current().id(),
waker: AtomicWaker::new(),
}
}
fn local_notify(&self, task: &Arc<Task>) {
if task.alive.get() && task.local_pending_link.is_unlinked() {
unsafe { self.local_pending.append(task); }
self.waker.wake();
}
}
fn global_notify(&self, task: &Arc<Task>) {
if !task.queued.swap(true, Ordering::Release) {
self.global_pending.push(task.clone());
self.waker.wake();
} }
fn fetch_global_notifies(&self) {
debug_assert_eq!(thread::current().id(), self.local_thread);
for task in unsafe { self.global_pending.start_pop() } {
task.queued.swap(false, Ordering::Acquire); if task.alive.get() && task.local_pending_link.is_unlinked() {
unsafe { self.local_pending.append(&task); }
}
}
}
fn poll(&self) -> Poll<()> {
struct PollList {
pending: TaskPendingHead,
}
impl Drop for PollList {
fn drop(&mut self) {
while let Some(task) = unsafe { self.pending.pop_back() } {
let task = unsafe { &*task };
if task.alive.get() && task.local_pending_link.is_unlinked() {
unsafe { task.task_list().local_pending.prepend(task); }
}
}
}
}
let mut poll_list = PollList {
pending: TaskPendingHead::new(),
};
unsafe {
poll_list.pending.take_from(&self.local_pending);
while let Some(task) = poll_list.pending.pop_front() {
{ &*task }.local_poll();
}
}
if self.local_all.is_empty() {
Poll::Ready(())
} else {
Poll::Pending
}
}
}
impl Drop for TaskList {
fn drop(&mut self) {
assert!(self.local_all.is_empty());
assert!(self.local_pending.is_empty());
}
}
#[derive(Debug)]
pub(super) struct LocalTaskList {
task_list: Arc<TaskList>,
_marker: PhantomData<*mut ()>, }
impl LocalTaskList {
pub fn new() -> Self {
Self {
task_list: Arc::new(TaskList::new()),
_marker: PhantomData,
}
}
pub fn poll(&self, cx: &mut Context<'_>) -> Poll<()> {
self.task_list.waker.register(cx.waker());
self.task_list.fetch_global_notifies();
self.task_list.poll()
}
pub fn add_task(&self, future: LocalFutureObj<'static, ()>) {
let task = Arc::new(Task::new(self.task_list.clone(), future));
unsafe { self.task_list.local_all.append(&task); }
let task = ManuallyDrop::new(task); self.task_list.local_notify(&task);
}
}
impl Drop for LocalTaskList {
fn drop(&mut self) {
while let Some(task) = unsafe { self.task_list.local_all.pop_front() } {
unsafe { &*task }.local_clear();
}
}
}
#[derive(Debug)]
pub(super) struct Task {
task_list: Option<Arc<TaskList>>, local_link: TaskLink,
local_pending_link: TaskPendingLink,
global_pending_next: GlobalTaskListLink, queued: AtomicBool, alive: Cell<bool>,
future: ManuallyDrop<UnsafeCell<Option<LocalFutureObj<'static, ()>>>>,
}
unsafe impl Send for Task {}
unsafe impl Sync for Task {}
impl Task {
fn new(task_list: Arc<TaskList>, future: LocalFutureObj<'static, ()>) -> Self {
Self {
task_list: Some(task_list),
local_link: TaskLink::new(),
local_pending_link: TaskPendingLink::new(),
global_pending_next: GlobalTaskListLink::new(),
queued: AtomicBool::new(false),
alive: Cell::new(true),
future: ManuallyDrop::new(UnsafeCell::new(Some(future))),
}
}
fn task_list(&self) -> &TaskList {
self.task_list.as_ref().expect("not stub task")
}
#[allow(clippy::mut_from_ref)] unsafe fn local_future(&self) -> &mut Option<LocalFutureObj<'static, ()>> {
debug_assert_eq!(thread::current().id(), self.task_list().local_thread);
&mut *self.future.get()
}
fn local_poll(&self) {
struct ClearOnPanic<'a> {
task: Option<&'a Task>,
}
impl Drop for ClearOnPanic<'_> {
fn drop(&mut self) {
if let Some(t) = self.task.take() {
if t.alive.get() {
t.local_clear();
}
}
}
}
debug_assert!(self.alive.get());
let arc_self = ManuallyDrop::new(unsafe { Arc::from_raw(self) }); let waker = futures_util::task::waker_ref(&arc_self);
let mut cx = Context::from_waker(&waker);
let fut = unsafe { self.local_future() }.as_mut().expect("pending futures must be alive");
let fut = unsafe { Pin::new_unchecked(fut) };
let mut cop = ClearOnPanic { task: Some(self) };
if let Poll::Ready(()) = fut.poll(&mut cx) {
self.local_clear();
}
cop.task.take(); }
fn local_clear(&self) {
let this = unsafe { Arc::from_raw(self) };
this.queued.store(true, Ordering::Relaxed);
this.alive.set(false);
unsafe {
this.local_pending_link.unlink();
this.local_link.unlink();
*this.local_future() = None;
}
}
}
impl futures_util::task::ArcWake for Task {
fn wake_by_ref(arc_self: &Arc<Self>) {
let id = thread::current().id();
let task_list = arc_self.task_list();
if id == task_list.local_thread {
task_list.local_notify(arc_self);
} else {
task_list.global_notify(arc_self);
}
}
}
impl Drop for Task {
fn drop(&mut self) {
assert!(!self.alive.get());
debug_assert!(self.task_list.is_none() || unsafe { self.local_future() }.is_none());
}
}