use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{RawWaker, RawWakerVTable, Waker};
use super::scheduler::GlobalQueue;
use super::task::{TaskHeader, STATE_IDLE, STATE_SCHEDULED};
/// The single vtable behind every waker built in this module.
///
/// Every entry expects the data pointer to be the result of `Arc::into_raw` on an
/// `Arc<WakerData>`; that is the pointer `make_waker_with_notifier` and `clone_waker`
/// hand out.
static TASK_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    clone_waker, // clone
    wake,        // wake (consumes the waker)
    wake_by_ref, // wake_by_ref (borrows the waker)
    drop_waker,  // drop
);
pub(crate) fn make_waker(
header: Arc<TaskHeader>,
queue: Arc<GlobalQueue>,
) -> Waker {
make_waker_with_notifier(header, queue, None)
}
pub(crate) fn make_waker_with_notifier(
header: Arc<TaskHeader>,
queue: Arc<GlobalQueue>,
notifier: Option<Arc<WorkerNotifier>>,
) -> Waker {
let data = Arc::new(WakerData {
header,
queue,
notifier,
});
let ptr = Arc::into_raw(data) as *const ();
let raw = RawWaker::new(ptr, &TASK_WAKER_VTABLE);
unsafe { Waker::from_raw(raw) }
}
/// Wake-up signal for worker threads, backed by raw file descriptors.
///
/// On Unix, each `notify_one` writes one byte to the next registered fd, rotating
/// through the list.
pub(crate) struct WorkerNotifier {
    /// Rotating cursor; only its value modulo the fd count is used.
    next: AtomicUsize,
    /// Raw fds registered via `add_fd`, kept as plain integers (nothing in this
    /// file closes them).
    wake_fds: std::sync::Mutex<Vec<i32>>,
}
impl WorkerNotifier {
    /// Creates a notifier with no registered fds; `notify_one` does nothing until
    /// `add_fd` has been called at least once.
    pub(crate) fn new() -> Self {
        Self {
            wake_fds: std::sync::Mutex::new(Vec::new()),
            next: AtomicUsize::new(0),
        }
    }

    /// Registers `fd` as a target for future `notify_one` calls.
    ///
    /// The fd is stored as a raw integer and is not closed by this impl; the caller
    /// presumably keeps it open while the notifier may still write to it.
    pub(crate) fn add_fd(&self, fd: i32) {
        self.lock_fds().push(fd);
    }

    /// Writes a single byte to one registered fd, choosing targets round-robin.
    ///
    /// Returns immediately if no fd is registered. Best-effort: a write interrupted
    /// by a signal is retried, and any other write failure is ignored.
    #[cfg(unix)]
    pub(crate) fn notify_one(&self) {
        // Pick the target under the lock, then release it before the syscall so
        // concurrent wakers do not serialize on the write itself.
        let fd = {
            let fds = self.lock_fds();
            if fds.is_empty() {
                return;
            }
            // Relaxed suffices: the counter only spreads writes across fds and
            // guards no other memory. `fetch_add` wraps on overflow, which is
            // harmless under `%`.
            let idx = self.next.fetch_add(1, Ordering::Relaxed) % fds.len();
            fds[idx]
        };

        let byte: u8 = 1;
        loop {
            // SAFETY: `byte` is an initialized local that outlives the call, and the
            // kernel is asked to read exactly one byte from it.
            let written = unsafe { libc::write(fd, (&byte as *const u8).cast(), 1) };
            if written >= 0 {
                break;
            }
            // The result used to be discarded, so an EINTR silently dropped this
            // notification. Retry that case; treat any other error as best-effort.
            if std::io::Error::last_os_error().kind() != std::io::ErrorKind::Interrupted {
                break;
            }
        }
    }

    /// No-op on non-Unix targets: registered fds are never written, so workers are
    /// not signalled through this type.
    #[cfg(not(unix))]
    pub(crate) fn notify_one(&self) {}

    /// Locks the fd list, recovering the guard if the mutex was poisoned.
    ///
    /// Recovery is sound because the only operations on the list in this impl are
    /// `push` and indexed reads, which cannot leave it half-updated. Without it, one
    /// panic while the lock was held would make every later wake panic as well.
    fn lock_fds(&self) -> std::sync::MutexGuard<'_, Vec<i32>> {
        self.wake_fds
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// State shared by a waker and all of its clones.
///
/// Waker data pointers are `Arc::into_raw` of this type; each live waker owns one
/// strong count on it.
struct WakerData {
    /// The task whose state is flipped and whose header is re-enqueued on wake.
    header: Arc<TaskHeader>,
    /// Queue that `schedule_task` pushes the header onto.
    queue: Arc<GlobalQueue>,
    /// Optional worker wake-up channel, signalled after a successful enqueue.
    notifier: Option<Arc<WorkerNotifier>>,
}
/// Reconstructs a *non-owning* view of the `Arc<WakerData>` behind a waker pointer.
///
/// The `Arc` is wrapped in `ManuallyDrop`, so letting the result go out of scope does
/// not decrement the strong count that the waker still owns.
///
/// # Safety
/// `ptr` must come from `Arc::into_raw` on an `Arc<WakerData>` whose strong count is
/// still held by a live waker.
#[inline]
unsafe fn data_ref(ptr: *const ()) -> std::mem::ManuallyDrop<Arc<WakerData>> {
    let typed = ptr.cast::<WakerData>();
    let arc = Arc::from_raw(typed);
    std::mem::ManuallyDrop::new(arc)
}
unsafe fn clone_waker(ptr: *const ()) -> RawWaker {
let data = data_ref(ptr);
let cloned = Arc::clone(&*data);
let new_ptr = Arc::into_raw(cloned) as *const ();
RawWaker::new(new_ptr, &TASK_WAKER_VTABLE)
}
/// Vtable `wake`: schedules the task and consumes this waker's reference.
///
/// # Safety
/// `ptr` must come from `Arc::into_raw` on an `Arc<WakerData>`; the strong count it
/// represents is released by this call.
unsafe fn wake(ptr: *const ()) {
    let owned: Arc<WakerData> = Arc::from_raw(ptr.cast());
    schedule_task(&owned);
    // Releasing the waker's strong count is the "consume" half of `wake`.
    drop(owned);
}
unsafe fn wake_by_ref(ptr: *const ()) {
let data = data_ref(ptr);
schedule_task(&data);
}
/// Vtable `drop`: releases the strong count owned by this waker.
///
/// # Safety
/// `ptr` must come from `Arc::into_raw` on an `Arc<WakerData>` and must not be used
/// again after this call.
unsafe fn drop_waker(ptr: *const ()) {
    let owned: Arc<WakerData> = Arc::from_raw(ptr.cast());
    std::mem::drop(owned);
}
fn schedule_task(data: &WakerData) {
let header = &data.header;
let prev = header.state.compare_exchange(
STATE_IDLE,
STATE_SCHEDULED,
Ordering::AcqRel,
Ordering::Relaxed,
);
if prev.is_ok() {
data.queue.push_header(Arc::clone(header));
if let Some(ref notifier) = data.notifier {
notifier.notify_one();
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::task::{Task, STATE_IDLE, STATE_SCHEDULED};
    use std::sync::atomic::Ordering;

    /// Builds a notifier-less waker for `task` together with the fresh queue it targets.
    fn make_test_waker(task: &Task) -> (Waker, Arc<GlobalQueue>) {
        let q = Arc::new(GlobalQueue::new());
        let w = make_waker(Arc::clone(&task.header), Arc::clone(&q));
        (w, q)
    }

    #[test]
    fn waker_clone_increments_refcount() {
        let (task, _jh) = Task::new(async { 1u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        let q = Arc::new(GlobalQueue::new());
        let baseline = Arc::strong_count(&task.header);

        let w1 = make_waker(Arc::clone(&task.header), Arc::clone(&q));
        // The shared `WakerData` holds exactly one reference to the header...
        assert_eq!(Arc::strong_count(&task.header), baseline + 1);

        // ...and cloning the waker bumps the `WakerData` count, not the header's.
        let w2 = w1.clone();
        assert_eq!(Arc::strong_count(&task.header), baseline + 1);

        // The header reference is released only when the last clone is dropped.
        drop(w1);
        assert_eq!(Arc::strong_count(&task.header), baseline + 1);
        drop(w2);
        assert_eq!(Arc::strong_count(&task.header), baseline);
    }

    #[test]
    fn wake_by_ref_schedules_idle_task() {
        let (task, _jh) = Task::new(async { 2u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        let (waker, queue) = make_test_waker(&task);
        waker.wake_by_ref();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_SCHEDULED);
        assert!(queue.pop().is_some());
    }

    #[test]
    fn wake_consumes_and_schedules() {
        let (task, _jh) = Task::new(async { 3u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        let (waker, queue) = make_test_waker(&task);
        waker.wake();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_SCHEDULED);
        assert!(queue.pop().is_some());
    }

    #[test]
    fn wake_noop_when_already_scheduled() {
        let (task, _jh) = Task::new(async { 4u32 });
        task.header.state.store(STATE_SCHEDULED, Ordering::Release);
        let (waker, queue) = make_test_waker(&task);
        waker.wake_by_ref();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_SCHEDULED);
        assert!(queue.pop().is_none());
    }

    #[test]
    fn repeated_wakes_enqueue_once() {
        let (task, _jh) = Task::new(async { 5u32 });
        task.header.state.store(STATE_IDLE, Ordering::Release);
        let (waker, queue) = make_test_waker(&task);
        // Only the first wake wins the IDLE -> SCHEDULED transition.
        waker.wake_by_ref();
        waker.wake_by_ref();
        assert_eq!(task.header.state.load(Ordering::Acquire), STATE_SCHEDULED);
        assert!(queue.pop().is_some());
        assert!(queue.pop().is_none());
    }
}