#![doc(hidden)]
use std::collections::VecDeque;
use std::sync::Arc;
use core::cell::RefCell;
use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use tracing::{debug_span, Span};
use core::sync::atomic::{AtomicU64, Ordering};
use std::thread::ThreadId;
use std::sync::Mutex;
use std::collections::HashMap;
use lazy_static::lazy_static;
/// Optional callback run when a task owned by another thread is woken, so the
/// owning thread can be nudged to drain its injection queue.
static WAKEUP_HANDLER: Mutex<Option<fn(ThreadId)>> = Mutex::new(None);

/// Installs `handler` as the cross-thread wakeup callback, replacing any
/// previously installed one.
///
/// The handler is called on the *waking* thread with the `ThreadId` of the
/// thread that owns the woken task. A poisoned lock is recovered rather than
/// propagated.
#[inline]
pub fn set_wakeup_handler(handler: fn(ThreadId)) {
    *WAKEUP_HANDLER.lock().unwrap_or_else(|e| e.into_inner()) = Some(handler);
}

/// Invokes the installed wakeup handler (if any) for the owner thread `id`.
///
/// The handler pointer is copied out and the lock released *before* the
/// handler runs. Holding the lock across the call would deadlock (or panic —
/// `std::sync::Mutex` is not reentrant) if the handler calls
/// `set_wakeup_handler` or wakes another cross-thread task. It would also
/// serialize every concurrent cross-thread wake behind the handler.
#[inline]
fn signal_wakeup(id: ThreadId) {
    // The guard is a temporary of this `let` statement and is dropped here.
    let handler = *WAKEUP_HANDLER.lock().unwrap_or_else(|e| e.into_inner());
    if let Some(handler) = handler {
        handler(id);
    }
}
// Process-wide counter handing out task ids. In this file the id is only used
// as the `id` field of each task's tracing span (see `spawn`).
static NEXT_TASK_ID: AtomicU64 = AtomicU64::new(1);
/// A spawned future together with what is needed to poll and reschedule it.
struct Task {
// Thread that called `spawn`. Wakers compare it with the waking thread to
// choose between the local `RUN_QUEUE` and the owner's injection queue.
owner_thread: ThreadId,
// The task's future; `spawn` does not require it to be `Send`. The `RefCell`
// gives `Task::poll` mutable access through a shared `Arc<Task>`.
future: RefCell<Pin<Box<dyn Future<Output = ()>>>>,
// Span created by `spawn` and entered around every poll.
span: Span,
}
// SAFETY / NOTE(review): these impls let `Arc<Task>` be carried to other
// threads inside `Waker`s even though `future` need not be `Send`.
// Within this file, `future` is only borrowed in `Task::poll`.
// `Executor::run_until_idle` calls that only on tasks popped from the calling
// thread's `RUN_QUEUE`, and only tasks owned by that thread ever reach it.
// Other threads merely move the `Arc` into the owner's injection queue.
// However, nothing here prevents the last `Arc<Task>` from being dropped on a
// foreign thread (e.g. a waker dropped there). That would drop a possibly
// `!Send` future on that thread — confirm this is acceptable for the futures
// spawned here.
unsafe impl Send for Task {}
unsafe impl Sync for Task {}
thread_local! {
// Tasks ready to be polled on this thread. Filled by `spawn`, by wakes issued
// on this thread, and by draining the injection queue; consumed by
// `Executor::run_until_idle`.
static RUN_QUEUE: RefCell<VecDeque<Arc<Task>>> = RefCell::new(VecDeque::new());
// Set by `stop()` and read by `is_stopped()`; this file never resets it.
static STOP: RefCell<bool> = RefCell::new(false);
// Cached id of the current thread, compared against `Task::owner_thread`.
static THREAD_ID: ThreadId = std::thread::current().id();
}
/// One thread's injection channel. It holds the queue other threads push woken
/// tasks into, plus a flag raised when something may be waiting in that queue.
type InjectionPair = (Arc<Mutex<VecDeque<Arc<Task>>>>, Arc<core::sync::atomic::AtomicBool>);

lazy_static! {
    // Global registry of injection channels, keyed by the thread that owns them.
    static ref INJECTION_QUEUES: Mutex<HashMap<ThreadId, InjectionPair>> = Mutex::new(HashMap::new());
}

/// Locks the global injection registry.
///
/// Poisoning is ignored: if an earlier holder panicked, the map is recovered
/// and returned as-is.
fn get_injection_lock() -> std::sync::MutexGuard<'static, HashMap<ThreadId, InjectionPair>> {
    match INJECTION_QUEUES.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
thread_local! {
// Per-thread cache of this thread's entry in `INJECTION_QUEUES`. It is filled
// lazily by `get_my_injection`, so the global lock is only taken on the first
// call on each thread.
static MY_INJECTION: RefCell<Option<(Arc<Mutex<VecDeque<Arc<Task>>>>, Arc<core::sync::atomic::AtomicBool>)>> = RefCell::new(None);
}
/// Returns this thread's injection queue and its "work injected" flag.
///
/// The first call on a thread finds (or creates) the pair in the global
/// `INJECTION_QUEUES` map and caches it in `MY_INJECTION`. Later calls only
/// clone the two cached `Arc`s.
fn get_my_injection() -> (Arc<Mutex<VecDeque<Arc<Task>>>>, Arc<core::sync::atomic::AtomicBool>) {
    MY_INJECTION.with(|slot| {
        let mut slot = slot.borrow_mut();
        let pair = slot.get_or_insert_with(|| {
            let this_thread = THREAD_ID.with(|tid| *tid);
            get_injection_lock()
                .entry(this_thread)
                .or_insert_with(|| {
                    (
                        Arc::new(Mutex::new(VecDeque::new())),
                        Arc::new(core::sync::atomic::AtomicBool::new(false)),
                    )
                })
                .clone()
        });
        pair.clone()
    })
}
/// Marks the current thread as stopped.
///
/// The flag is thread-local, and this module never clears it again.
pub fn stop() {
    STOP.with(|flag| {
        let mut stopped = flag.borrow_mut();
        *stopped = true;
    });
}

/// Reports whether [`stop`] has been called on the current thread.
pub fn is_stopped() -> bool {
    STOP.with(|flag| {
        let stopped: bool = *flag.borrow();
        stopped
    })
}
impl Task {
    /// Polls the task's future once, with the task's tracing span entered.
    ///
    /// A `Waker` referring to this task is handed to the future. Once the
    /// future returns `Ready`, it is dropped and replaced with a never-ready
    /// placeholder. A task can still be woken (and re-queued) after it has
    /// completed, e.g. by a waker clone that fires later. Without the
    /// placeholder, that later poll would resume a finished future, which
    /// panics for `async` blocks and is unspecified in general.
    fn poll(self: Arc<Self>) {
        let _enter = self.span.enter();
        // SAFETY: `raw_waker` hands out a pointer from `Arc::into_raw`, and
        // `VTABLE`'s functions each treat it as one owned strong reference.
        let waker = unsafe { Waker::from_raw(self.clone().raw_waker()) };
        let mut cx = Context::from_waker(&waker);
        let mut future = self.future.borrow_mut();
        if future.as_mut().poll(&mut cx).is_ready() {
            // Drops the completed future here, on the owning thread.
            // `Pending` is zero-sized, so this does not allocate.
            *future = Box::pin(core::future::pending::<()>());
        }
    }

    /// Converts this `Arc<Task>` into a `RawWaker` that owns one strong
    /// reference, using [`VTABLE`].
    fn raw_waker(self: Arc<Self>) -> RawWaker {
        let ptr = Arc::into_raw(self) as *const ();
        RawWaker::new(ptr, &VTABLE)
    }
}
// Vtable backing every `Waker` created for a `Task`. The data pointer is an
// `Arc<Task>` converted with `Arc::into_raw` (see `Task::raw_waker`):
// - `clone_waker` adds a strong reference;
// - `wake` and `drop_waker` release the one the waker owned;
// - `wake_by_ref` leaves it in place.
static VTABLE: RawWakerVTable = RawWakerVTable::new(
clone_waker,
wake,
wake_by_ref,
drop_waker,
);
unsafe fn clone_waker(ptr: *const ()) -> RawWaker {
let arc = Arc::from_raw(ptr as *const Task);
let cloned = arc.clone();
let _ = Arc::into_raw(arc);
let new_ptr = Arc::into_raw(cloned) as *const ();
RawWaker::new(new_ptr, &VTABLE)
}
/// Routes an owned task reference back to the thread that spawned it.
///
/// On the owning thread, the task is appended to the local `RUN_QUEUE`.
/// From any other thread, these steps run in order:
/// 1. The task is appended to the owner's injection queue.
/// 2. The owner's flag is raised.
/// 3. The wakeup handler (if any) is notified.
fn schedule_task(task: Arc<Task>) {
    let owner = task.owner_thread;
    if owner == THREAD_ID.with(|id| *id) {
        RUN_QUEUE.with(|q| q.borrow_mut().push_back(task));
        return;
    }
    let (queue, flag) = {
        // The global registry lock is released at the end of this block,
        // before the per-thread queue lock is taken.
        let mut queues = get_injection_lock();
        queues
            .entry(owner)
            .or_insert_with(|| {
                (
                    Arc::new(Mutex::new(VecDeque::new())),
                    Arc::new(core::sync::atomic::AtomicBool::new(false)),
                )
            })
            .clone()
    };
    queue.lock().unwrap_or_else(|e| e.into_inner()).push_back(task);
    // Raised only after the push, so a reader that sees the flag will find the task.
    flag.store(true, Ordering::Release);
    signal_wakeup(owner);
}

/// `RawWakerVTable::wake` for task wakers.
///
/// Consumes the waker's strong reference and reschedules the task.
unsafe fn wake(ptr: *const ()) {
    // SAFETY: `ptr` came from `Arc::into_raw` on an `Arc<Task>`; this call
    // takes over the strong reference the waker owned.
    let task = unsafe { Arc::from_raw(ptr as *const Task) };
    schedule_task(task);
}

/// `RawWakerVTable::wake_by_ref` for task wakers.
///
/// Reschedules the task while leaving the waker's own reference intact.
///
/// A *new* strong reference is taken before anything that can panic runs.
/// The alternative — temporarily rebuilding the waker's own `Arc` and handing
/// it back with `Arc::into_raw` afterwards — is not panic-safe. An unwind out
/// of scheduling (e.g. a thread-local accessed during thread teardown, or a
/// `RefCell` borrow panic) would then free the reference the waker still uses.
unsafe fn wake_by_ref(ptr: *const ()) {
    let task_ptr = ptr as *const Task;
    // SAFETY: the waker in use still owns a strong reference, so the
    // allocation is alive. Incrementing first yields an independent `Arc`
    // that we may consume.
    let task = unsafe {
        Arc::increment_strong_count(task_ptr);
        Arc::from_raw(task_ptr)
    };
    schedule_task(task);
}
/// `RawWakerVTable::drop` for task wakers.
///
/// Releases the strong reference owned by the waker.
unsafe fn drop_waker(ptr: *const ()) {
    // SAFETY: `ptr` came from `Arc::into_raw` on an `Arc<Task>` and still
    // carries the strong reference owned by the waker being dropped.
    unsafe { Arc::decrement_strong_count(ptr.cast::<Task>()) };
}
/// Queues `future` as a new task on the current thread.
///
/// The future does not have to be `Send`. It is recorded as owned by the
/// calling thread and pushed onto that thread's run queue, where
/// [`Executor::run_until_idle`] on this thread polls it. Each task gets a
/// `debug_span!("task", id = …)` whose id comes from a process-wide counter.
pub fn spawn<F>(future: F)
where
    F: Future<Output = ()> + 'static,
{
    // Register this thread's injection queue up front. The returned handles
    // are not needed here; without this call, a waker fired on another thread
    // would create the entry itself.
    let _ = get_my_injection();

    let task_id = NEXT_TASK_ID.fetch_add(1, Ordering::Relaxed);
    let task = Arc::new(Task {
        span: debug_span!("task", id = task_id),
        owner_thread: THREAD_ID.with(|tid| *tid),
        future: RefCell::new(Box::pin(future)),
    });
    RUN_QUEUE.with(|queue| queue.borrow_mut().push_back(task));
}
/// Reports whether the current thread has work waiting.
///
/// Returns `true` in either of these cases:
/// - the local run queue is non-empty;
/// - the injection flag is raised and the injection queue actually holds tasks.
///
/// Injected tasks are not moved here. A stale flag (queue empty) is cleared
/// while the queue lock is held, so later calls can skip the lock.
pub fn has_pending_tasks() -> bool {
    let has_local = RUN_QUEUE.with(|queue| !queue.borrow().is_empty());
    if has_local {
        return true;
    }
    let (injected, maybe_injected) = get_my_injection();
    if !maybe_injected.load(Ordering::Acquire) {
        return false;
    }
    let guard = injected.lock().unwrap_or_else(|e| e.into_inner());
    if guard.is_empty() {
        maybe_injected.store(false, Ordering::Release);
        false
    } else {
        true
    }
}
/// Handle for driving the calling thread's tasks.
///
/// The executor holds no state. Its queues live in thread-locals (locally
/// scheduled tasks) and in the global injection map (tasks woken from other
/// threads), so every `Executor` on a given thread drives the same tasks.
pub struct Executor;

impl Default for Executor {
    fn default() -> Self {
        Self::new()
    }
}

impl Executor {
    /// Creates an executor handle for the current thread.
    ///
    /// Also registers the thread's injection queue.
    pub fn new() -> Self {
        let _ = get_my_injection();
        Self
    }

    /// Polls queued tasks on the current thread.
    ///
    /// It stops when the run queue is empty or after 128 polls. Tasks injected
    /// from other threads are moved into the local queue once, at the start.
    /// Anything injected during the run waits for the next call, as do tasks
    /// left over when the budget runs out. The budget bounds a single call
    /// even when tasks keep re-waking themselves (e.g. via [`yield_now`]).
    pub fn run_until_idle(&self) {
        const POLL_BUDGET: usize = 128;
        self.drain_injection_queue();
        for _ in 0..POLL_BUDGET {
            // The queue borrow ends before `poll`, so tasks may (re)schedule freely.
            match RUN_QUEUE.with(|q| q.borrow_mut().pop_front()) {
                Some(task) => task.poll(),
                None => break,
            }
        }
    }

    /// Moves every task injected from other threads into this thread's run
    /// queue, then clears the "work injected" flag.
    fn drain_injection_queue(&self) {
        let (queue, flag) = get_my_injection();
        let mut remote = queue.lock().unwrap_or_else(|e| e.into_inner());
        if !remote.is_empty() {
            RUN_QUEUE.with(|q| q.borrow_mut().extend(remote.drain(..)));
        }
        // The flag is cleared while the queue lock is still held. Any task
        // pushed after this point is followed by a fresh `store(true)` from
        // its waker, so the flag cannot be left false while tasks are queued.
        flag.store(false, Ordering::Release);
    }
}
/// Returns a future that lets other queued tasks run before the current one continues.
///
/// The first poll wakes the current task's waker and returns `Pending`. With
/// this module's wakers, that re-queues the task at the back of its run queue.
/// The next poll completes.
pub fn yield_now() -> YieldNow {
    YieldNow { yielded: false }
}

/// Future returned by [`yield_now`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct YieldNow {
    // Set by the first poll; the poll after that completes.
    yielded: bool,
}

impl Future for YieldNow {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.yielded {
            return Poll::Ready(());
        }
        self.yielded = true;
        // Ask to be polled again; queued tasks get their turn first.
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}