use std::cell::{Cell, UnsafeCell};
use std::sync::Arc;
use std::sync::atomic::{AtomicPtr, Ordering};
use crate::task;
thread_local! {
    /// Cross-thread wake context installed on this thread, if any.
    ///
    /// The slot owns a strong reference, so the installed context stays alive
    /// for as long as it is installed, regardless of what the installer later
    /// does with its own `Arc` handle.
    static CTX_CROSS_WAKE: Cell<Option<Arc<CrossWakeContext>>> = const { Cell::new(None) };
}

/// Installs `ctx` as this thread's cross-thread wake context.
///
/// Returns a guard that puts the previously installed context (if any) back
/// when dropped. The thread-local keeps its own clone of `ctx`, so the caller's
/// `Arc` may be moved or dropped while the guard is alive without leaving a
/// dangling reference behind.
///
/// # Panics
///
/// May panic if called while this thread's thread-locals are being destroyed
/// (see `std::thread::LocalKey::with`).
#[must_use = "dropping the guard immediately uninstalls the context"]
pub(crate) fn install_cross_wake(ctx: &Arc<CrossWakeContext>) -> CrossWakeGuard {
    let prev = CTX_CROSS_WAKE.with(|slot| slot.replace(Some(Arc::clone(ctx))));
    CrossWakeGuard {
        prev,
        _not_send: std::marker::PhantomData,
    }
}

/// Restores the previously installed cross-wake context when dropped.
pub(crate) struct CrossWakeGuard {
    /// The context that was installed before this guard's; put back on drop.
    prev: Option<Arc<CrossWakeContext>>,
    /// Keeps the guard `!Send` and `!Sync`: it must be dropped on the thread
    /// whose thread-local it modified.
    _not_send: std::marker::PhantomData<*const ()>,
}

impl Drop for CrossWakeGuard {
    fn drop(&mut self) {
        let prev = self.prev.take();
        // `try_with` rather than `with`: during thread teardown the slot may
        // already be gone, in which case there is nothing to restore, and a
        // panic inside `drop` must be avoided.
        let replaced = CTX_CROSS_WAKE.try_with(|slot| slot.replace(prev));
        // Release the outgoing context only after the slot access has
        // finished, so its destructor never runs inside `try_with`.
        drop(replaced);
    }
}

/// Returns a new handle to this thread's installed cross-wake context.
///
/// Returns `None` if no context is installed, or if the thread-local has
/// already been destroyed during thread teardown.
pub(crate) fn cross_wake_context() -> Option<Arc<CrossWakeContext>> {
    CTX_CROSS_WAKE
        .try_with(|slot| {
            // `Cell` only hands out owned values: take, clone, put back.
            let current = slot.take();
            let handle = current.clone();
            slot.set(current);
            handle
        })
        .ok()
        .flatten()
}
/// Intrusive multi-producer, single-consumer (MPSC) queue of task pointers.
///
/// The algorithm matches Dmitry Vyukov's intrusive MPSC queue:
/// - Producers atomically exchange `tail`, then link the previous node to the
///   new one.
/// - The single consumer walks forward from `head`.
///
/// Queued nodes are the tasks themselves; their link is reached through
/// `task::cross_next`, so `push` never allocates. The only allocation is the
/// `stub` sentinel created in `new`.
pub(crate) struct CrossWakeQueue {
/// Consumer position: the node `pop` examines next. Only `pop` reads or
/// writes it.
head: UnsafeCell<*mut u8>,
/// Most recently pushed node; exchanged by every `push`.
tail: AtomicPtr<u8>,
/// Heap-allocated sentinel node (just a link). `pop` re-pushes it when the
/// head node is the last one; it is freed in `Drop`.
stub: *mut AtomicPtr<u8>,
}
// SAFETY: the queue exclusively owns `stub` (allocated in `new`, freed in
// `Drop`). The queue itself only dereferences task pointers to reach their
// link via `next_of`.
// NOTE(review): whether the tasks behind those pointers may be handed to
// another thread is the `task` module's responsibility — confirm.
unsafe impl Send for CrossWakeQueue {}
// SAFETY: `push` only touches atomics (the `tail` swap and the link stores),
// so any number of threads may push concurrently. `pop` mutates `head`
// through the `UnsafeCell` without synchronization, so at most one thread may
// pop at a time.
// NOTE(review): `pop` is a safe fn, so the single-consumer rule is not
// enforced by the type system. Concurrent `pop` calls would be a data race.
unsafe impl Sync for CrossWakeQueue {}
impl CrossWakeQueue {
    /// Creates an empty queue.
    ///
    /// `head` and `tail` both start at a freshly allocated stub node whose
    /// link is null.
    pub(crate) fn new() -> Self {
        let stub_link: *mut AtomicPtr<u8> =
            Box::into_raw(Box::new(AtomicPtr::new(std::ptr::null_mut())));
        let start: *mut u8 = stub_link.cast();
        Self {
            stub: stub_link,
            tail: AtomicPtr::new(start),
            head: UnsafeCell::new(start),
        }
    }

    /// The stub node viewed as an untyped queue-node pointer.
    #[inline]
    fn stub_ptr(&self) -> *mut u8 {
        self.stub.cast()
    }

    /// Returns the "next" link of `node`.
    ///
    /// The stub keeps its link in the box owned by this queue. Every other
    /// node is a task, and its link is obtained from `task::cross_next`.
    ///
    /// # Safety
    ///
    /// `node` must be either this queue's stub or a live task pointer that is
    /// valid for `task::cross_next`. The returned reference must not be used
    /// after that node is freed.
    #[inline]
    unsafe fn next_of(&self, node: *mut u8) -> &AtomicPtr<u8> {
        if node != self.stub_ptr() {
            // SAFETY: the caller guarantees `node` is a live task pointer.
            return unsafe { &*task::cross_next(node) };
        }
        // SAFETY: `self.stub` was allocated in `new` and is only freed in `drop`.
        unsafe { &*self.stub }
    }
}
impl Drop for CrossWakeQueue {
    /// Frees the stub node allocated by `new`.
    ///
    /// Any tasks still linked in the queue are not freed here.
    fn drop(&mut self) {
        // SAFETY: `stub` came from `Box::into_raw` in `new` and is released
        // only here, exactly once.
        let stub_box = unsafe { Box::from_raw(self.stub) };
        drop(stub_box);
    }
}
impl CrossWakeQueue {
/// Appends `task_ptr` to the back of the queue.
///
/// May be called from any thread, concurrently with other `push` calls and
/// with `pop`.
///
/// # Safety
///
/// `task_ptr` must satisfy all of the following:
/// - It is this queue's stub or a task pointer valid for `task::cross_next`.
/// - It stays alive until it has been popped.
/// - It is not already linked into the queue. Its link is reset to null
///   below, which would cut off any nodes queued behind it.
pub(crate) unsafe fn push(&self, task_ptr: *mut u8) {
// The new node will be last, so its link must be null before it is published.
unsafe { self.next_of(task_ptr) }.store(std::ptr::null_mut(), Ordering::Relaxed);
// Claim the tail; `prev` is the node that was last before this push.
let prev = self.tail.swap(task_ptr, Ordering::AcqRel);
// Link `prev` to the new node. Until this store lands, `tail` already names
// the new node but it is not reachable from `head`. In that window `pop` sees
// a null link on a node that is not `tail` and returns `None`.
unsafe { self.next_of(prev) }.store(task_ptr, Ordering::Release);
}
/// Removes and returns the task at the front of the queue.
///
/// Returns `None` when the queue is empty. It also returns `None` while a
/// producer is between its `tail` swap and its link store in `push`; that item
/// becomes visible on a later call.
///
/// Single consumer only: this mutates `head` through an `UnsafeCell` with no
/// synchronization, so at most one thread may call it at a time.
/// NOTE(review): this is a safe fn on a `Sync` type, so nothing enforces that
/// rule. Consider making it `unsafe` or otherwise restricting access.
pub(crate) fn pop(&self) -> Option<*mut u8> {
// SAFETY (single consumer): no other thread touches `head` concurrently.
let head_ref = unsafe { &mut *self.head.get() };
let mut head = *head_ref;
let mut next = unsafe { self.next_of(head) }.load(Ordering::Acquire);
let stub = self.stub_ptr();
// The stub is never handed out: step past it first.
if head == stub {
// Only the stub is present, with nothing linked after it: empty.
if next.is_null() {
return None; }
*head_ref = next;
head = next;
next = unsafe { self.next_of(head) }.load(Ordering::Acquire);
}
// Fast path: `head` has a successor, so it can be detached and returned.
if !next.is_null() {
*head_ref = next;
return Some(head);
}
// `head` has no successor. If it is not the tail, a producer has swapped
// `tail` but not yet linked `head` to its node; report empty for now.
let tail = self.tail.load(Ordering::Acquire);
if head != tail {
return None;
}
// `head` is the last node. Re-insert the stub behind it so that `head` gets
// a successor and can be detached without emptying the list of nodes.
// The stub is not linked at this point (head is a task and the only node).
unsafe { self.push(stub) };
// The successor is now either the stub or a node a producer linked in
// between; in both cases `head` can be returned.
next = unsafe { self.next_of(head) }.load(Ordering::Acquire);
if !next.is_null() {
*head_ref = next;
return Some(head);
}
// A producer swapped `tail` before our stub push and has not linked yet;
// the item will be observable on a later call.
None
}
}
/// Shared state through which other threads can wake tasks for this context.
///
/// Wakers push task pointers onto `queue` (see `wake_task_cross_thread`). When
/// `parked` is set, they also signal `mio_waker`.
pub(crate) struct CrossWakeContext {
    /// Tasks woken from other threads, waiting to be popped by the consumer.
    pub(crate) queue: CrossWakeQueue,
    /// Signalled by `wake_task_cross_thread` when `parked` is `true`.
    pub(crate) mio_waker: Arc<mio::Waker>,
    /// When `true`, wakers signal `mio_waker` after enqueuing.
    /// NOTE(review): presumably set by the owning thread around its blocking
    /// poll — confirm against the owner side.
    pub(crate) parked: std::sync::atomic::AtomicBool,
}

// `Send`/`Sync` are derived automatically from the fields:
// - `CrossWakeQueue` carries its own `unsafe impl`s.
// - `Arc<mio::Waker>` and `AtomicBool` are thread-safe.
// So no `unsafe impl` is needed here. This assertion turns a future
// non-thread-safe field into a compile error instead of silently unsound
// sharing.
const _: fn() = || {
    fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<CrossWakeContext>();
};
/// Wakes `task_ptr` from a thread other than the one draining `ctx.queue`.
///
/// Does nothing in either of these cases:
/// - `task::is_completed` reports the task as completed.
/// - `task::try_set_queued` fails (another waker already queued it).
///
/// Otherwise pushes the task onto `ctx.queue`. If `ctx.parked` is set, it
/// also signals `ctx.mio_waker`.
///
/// # Safety
///
/// `task_ptr` must point to a live task accepted by the `task` module helpers.
/// It must remain valid until the consumer has popped it from `ctx.queue`.
pub(crate) unsafe fn wake_task_cross_thread(task_ptr: *mut u8, ctx: &CrossWakeContext) {
    // SAFETY: the caller guarantees `task_ptr` is a live task.
    if unsafe { task::is_completed(task_ptr) } {
        return;
    }
    // Only the waker that sets the "queued" flag enqueues the task.
    // SAFETY: as above.
    if !unsafe { task::try_set_queued(task_ptr) } {
        return;
    }
    // SAFETY: `try_set_queued` succeeded, which (presumably) guarantees no
    // other waker has linked this task into the queue.
    unsafe { ctx.queue.push(task_ptr) };
    // Publishing the task (a store) and then reading `parked` (a load) is a
    // store-buffering pattern. With only acquire/release ordering the load
    // can be satisfied before the push is visible. Then the consumer could see
    // an empty queue and park while we see `parked == false` and skip the
    // wakeup. The SeqCst fence forbids that outcome.
    // NOTE(review): the owner side must pair this with a SeqCst fence (or
    // SeqCst operations) between setting `parked` and re-checking the queue.
    std::sync::atomic::fence(Ordering::SeqCst);
    if ctx.parked.load(Ordering::Acquire) {
        // Best effort: a failed wake is deliberately ignored.
        let _ = ctx.mio_waker.wake();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::Task;
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Future that completes immediately; only used to build real tasks.
    struct Noop;
    impl Future for Noop {
        type Output = ();
        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Ready(())
        }
    }

    /// Allocates a task and returns it as an untyped queue-node pointer.
    fn make_task() -> *mut u8 {
        let task = Box::new(Task::new_boxed(Noop, 0));
        Box::into_raw(task) as *mut u8
    }

    /// Frees a task created by `make_task`.
    unsafe fn free(ptr: *mut u8) {
        unsafe { task::free_task(ptr) };
    }

    #[test]
    fn queue_push_pop_single() {
        let q = CrossWakeQueue::new();
        let t1 = make_task();
        unsafe { q.push(t1) };
        assert_eq!(q.pop(), Some(t1));
        assert_eq!(q.pop(), None);
        unsafe { free(t1) };
    }

    #[test]
    fn queue_push_pop_multiple() {
        let q = CrossWakeQueue::new();
        let t1 = make_task();
        let t2 = make_task();
        let t3 = make_task();
        unsafe { q.push(t1) };
        unsafe { q.push(t2) };
        unsafe { q.push(t3) };
        assert_eq!(q.pop(), Some(t1));
        assert_eq!(q.pop(), Some(t2));
        assert_eq!(q.pop(), Some(t3));
        assert_eq!(q.pop(), None);
        unsafe { free(t1) };
        unsafe { free(t2) };
        unsafe { free(t3) };
    }

    #[test]
    fn queue_interleaved_push_pop() {
        let q = CrossWakeQueue::new();
        let t1 = make_task();
        let t2 = make_task();
        unsafe { q.push(t1) };
        assert_eq!(q.pop(), Some(t1));
        unsafe { q.push(t2) };
        assert_eq!(q.pop(), Some(t2));
        assert_eq!(q.pop(), None);
        unsafe { free(t1) };
        unsafe { free(t2) };
    }

    #[test]
    fn queue_empty() {
        let q = CrossWakeQueue::new();
        assert_eq!(q.pop(), None);
        assert_eq!(q.pop(), None);
    }

    #[test]
    fn queue_reuse_after_drain() {
        let q = CrossWakeQueue::new();
        let t1 = make_task();
        for _ in 0..100 {
            unsafe { q.push(t1) };
            assert_eq!(q.pop(), Some(t1));
        }
        assert_eq!(q.pop(), None);
        unsafe { free(t1) };
    }

    /// Several producers push concurrently while the single consumer pops.
    ///
    /// Every task must be delivered exactly once, and each producer's tasks
    /// must come out in the order that producer pushed them.
    #[test]
    fn queue_concurrent_producers() {
        const PRODUCERS: usize = 4;
        const PER_PRODUCER: usize = 32;
        const TOTAL: usize = PRODUCERS * PER_PRODUCER;

        let q = CrossWakeQueue::new();
        // Raw pointers are `!Send`, so hand each producer its tasks as addresses.
        let batches: Vec<Vec<usize>> = (0..PRODUCERS)
            .map(|_| (0..PER_PRODUCER).map(|_| make_task() as usize).collect())
            .collect();

        let mut popped: Vec<usize> = Vec::with_capacity(TOTAL);
        std::thread::scope(|s| {
            for batch in &batches {
                let q = &q;
                s.spawn(move || {
                    for &addr in batch {
                        unsafe { q.push(addr as *mut u8) };
                    }
                });
            }
            // Single consumer on this thread. `pop` may return `None`
            // transiently while a push is mid-flight, so keep polling.
            while popped.len() < TOTAL {
                match q.pop() {
                    Some(p) => popped.push(p as usize),
                    None => std::thread::yield_now(),
                }
            }
        });
        assert_eq!(q.pop(), None);

        // Exactly once: no duplicates among the TOTAL popped items.
        let mut unique = popped.clone();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(unique.len(), TOTAL);

        // Per-producer FIFO: each batch appears in push order.
        for batch in &batches {
            let positions: Vec<usize> = batch
                .iter()
                .map(|addr| {
                    popped
                        .iter()
                        .position(|p| p == addr)
                        .expect("pushed task was never popped")
                })
                .collect();
            assert!(positions.windows(2).all(|w| w[0] < w[1]));
        }

        for addr in batches.into_iter().flatten() {
            unsafe { free(addr as *mut u8) };
        }
    }
}