use std::mem::size_of;
use std::sync::atomic::{AtomicU8, Ordering};
use std::task;
use crossbeam_channel::Sender;
use log::{error, trace};
use crate::thread_waker::ThreadWaker;
use crate::{ptr_as_usize, ProcessId};
/// Upper bound on the number of worker threads that can register a waker
/// through [`init`].
pub const MAX_THREADS: usize = 128;

/// Identifies a worker thread's waker slot; handed out by [`init`].
///
/// The wrapped value is an index into the per-thread waker table. It is stored
/// as a `u8` and packed into the top `THREAD_BITS` bits of `WakerData`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[repr(transparent)]
pub(crate) struct WakerId(u8);
/// Registers waker state for the calling worker thread and returns the
/// [`WakerId`] that identifies it.
///
/// `waker` becomes the thread's [`ThreadWaker`]. Every process woken through a
/// [`task::Waker`] created by [`new`] with the returned id has its
/// [`ProcessId`] sent on `notifications`.
///
/// # Panics
///
/// Panics if called more than [`MAX_THREADS`] times during the lifetime of the
/// process.
pub(crate) fn init(waker: mio::Waker, notifications: Sender<ProcessId>) -> WakerId {
    static THREAD_IDS: AtomicU8 = AtomicU8::new(0);
    // Only advance the counter while ids are still available. A plain
    // `fetch_add` keeps incrementing past the limit (even though the call then
    // panics). Because `AtomicU8` wraps on overflow, the 257th call would get
    // id 0 again and overwrite a slot that is already in use.
    let thread_id = match THREAD_IDS.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |id| {
        if usize::from(id) < MAX_THREADS {
            // `id < MAX_THREADS <= 255`, so this can't overflow.
            Some(id + 1)
        } else {
            None
        }
    }) {
        Ok(id) => id,
        Err(_) => panic!("Created too many Heph worker threads"),
    };
    // SAFETY: the bounded counter above hands out each index below
    // `MAX_THREADS` at most once, so no other call writes this slot. Readers
    // (`get`) index the table with a `WakerId`, and this function only returns
    // one after the slot is written (`WakerId`'s field is private to this
    // module).
    unsafe {
        THREAD_WAKERS[usize::from(thread_id)] = Some(Waker {
            notifications,
            thread_waker: ThreadWaker::new(waker),
        });
    }
    WakerId(thread_id)
}
/// Creates a [`task::Waker`] that, when woken, notifies the worker thread
/// identified by `waker_id` that process `pid` is ready to run.
///
/// No allocation is involved: both ids are packed into the waker's data
/// pointer.
pub(crate) fn new(waker_id: WakerId, pid: ProcessId) -> task::Waker {
    let raw = task::RawWaker::new(
        WakerData::new(waker_id, pid).into_raw_data(),
        &WAKER_VTABLE,
    );
    // SAFETY: the `WAKER_VTABLE` functions only decode the data pointer's value
    // (via `WakerData::from_raw_data`); none of them treats it as a pointer to
    // memory.
    unsafe { task::Waker::from_raw(raw) }
}
pub(crate) fn mark_polling(waker_id: WakerId, polling: bool) {
get(waker_id).thread_waker.mark_polling(polling);
}
/// Initial value of every `THREAD_WAKERS` slot. A named `const` is required
/// because `Option<Waker>` is not `Copy` and so can't be repeated directly in
/// an array expression.
const NO_WAKER: Option<Waker> = None;
/// Per-thread waker state, indexed by `WakerId`. Within this file, slots are
/// only written by [`init`] and read through [`get`].
static mut THREAD_WAKERS: [Option<Waker>; MAX_THREADS] = [NO_WAKER; MAX_THREADS];
/// Returns the waker state registered for `waker_id`.
///
/// # Panics
///
/// Panics if no waker has been registered in that slot.
fn get(waker_id: WakerId) -> &'static Waker {
    let index = usize::from(waker_id.0);
    // SAFETY: `init` writes a slot before returning the `WakerId` for it, and
    // this file never writes a slot again afterwards.
    let slot = unsafe { THREAD_WAKERS[index].as_ref() };
    slot.expect("tried to get a waker for a thread that isn't initialised")
}
/// Returns the [`ThreadWaker`] of the thread identified by `waker_id`.
///
/// # Panics
///
/// Panics if no waker has been registered for `waker_id`.
pub(crate) fn get_thread_waker(waker_id: WakerId) -> &'static ThreadWaker {
    let waker = get(waker_id);
    &waker.thread_waker
}
#[derive(Debug)]
struct Waker {
notifications: Sender<ProcessId>,
thread_waker: ThreadWaker,
}
impl Waker {
fn wake(&self, pid: ProcessId) {
trace!(pid = pid.0; "waking process");
if let Err(err) = self.notifications.try_send(pid) {
error!("unable to send wake up notification: {}", err);
return;
}
self.wake_thread()
}
fn wake_thread(&self) {
if let Err(err) = self.thread_waker.wake() {
error!("unable to wake up worker thread: {}", err);
}
}
}
/// A thread id and a process id packed into one pointer-sized value.
///
/// The value is stored directly as the data pointer of a [`task::RawWaker`].
/// The top `THREAD_BITS` bits hold the `WakerId`; the remaining bits hold the
/// `ProcessId`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[repr(transparent)]
struct WakerData(usize);
/// Number of high bits of `WakerData` reserved for the thread id.
const THREAD_BITS: usize = 8;
/// Bit position where the thread id starts; everything below it is the pid.
const THREAD_SHIFT: usize = 8 * size_of::<*const ()>() - THREAD_BITS;
/// Mask selecting the thread-id bits, i.e. the top `THREAD_BITS` bits.
const THREAD_MASK: usize = usize::MAX << THREAD_SHIFT;
impl WakerData {
    /// Packs `thread_id` into the top `THREAD_BITS` bits and `pid` into the
    /// bits below them.
    ///
    /// `pid` must be smaller than `1 << THREAD_SHIFT`. This is only checked in
    /// debug builds.
    fn new(thread_id: WakerId, pid: ProcessId) -> WakerData {
        debug_assert!(pid.0 < (1 << THREAD_SHIFT), "pid too large");
        let thread_bits = usize::from(thread_id.0) << THREAD_SHIFT;
        WakerData(thread_bits | pid.0)
    }

    /// Recovers the [`WakerId`] from the top bits.
    // Only `THREAD_BITS` (8) bits remain after the shift, so the cast to `u8`
    // doesn't lose information.
    #[allow(clippy::cast_possible_truncation)]
    const fn waker_id(self) -> WakerId {
        WakerId((self.0 >> THREAD_SHIFT) as u8)
    }

    /// Recovers the [`ProcessId`] by clearing the thread-id bits.
    const fn pid(self) -> ProcessId {
        ProcessId(self.0 & !THREAD_MASK)
    }

    /// Reinterprets a waker's data pointer as `WakerData`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract depends on `ptr_as_usize` (defined
    /// elsewhere). Within this file, `data` always comes from
    /// [`WakerData::into_raw_data`].
    const unsafe fn from_raw_data(data: *const ()) -> WakerData {
        WakerData(ptr_as_usize(data))
    }

    /// Converts the packed value into a pointer suitable for
    /// [`task::RawWaker::new`]. The pointer is never dereferenced.
    const fn into_raw_data(self) -> *const () {
        self.0 as *const ()
    }
}
/// Function table backing every [`task::Waker`] created by [`new`].
static WAKER_VTABLE: task::RawWakerVTable = task::RawWakerVTable::new(
    clone_wake_data,
    wake,
    wake_by_ref,
    drop_wake_data,
);
/// Compile-time check that `T` implements [`Copy`]. The body is empty, so
/// calling it has no runtime effect.
fn assert_copy<T>()
where
    T: Copy,
{
}
/// `clone` entry of [`WAKER_VTABLE`].
///
/// All waker state lives in the pointer value, so a clone simply reuses `data`.
unsafe fn clone_wake_data(data: *const ()) -> task::RawWaker {
    // Reusing `data` is only a valid clone because `WakerData` is `Copy`.
    assert_copy::<WakerData>();
    let vtable: &'static task::RawWakerVTable = &WAKER_VTABLE;
    task::RawWaker::new(data, vtable)
}
/// `wake` entry of [`WAKER_VTABLE`]: decodes the thread and process ids and
/// wakes the process on that thread.
///
/// # Panics
///
/// Panics if the encoded thread has no registered waker.
unsafe fn wake(data: *const ()) {
    // SAFETY: `data` is the value stored by `new` (or copied by
    // `clone_wake_data`), i.e. produced by `WakerData::into_raw_data`.
    let waker_data = unsafe { WakerData::from_raw_data(data) };
    let (waker_id, pid) = (waker_data.waker_id(), waker_data.pid());
    get(waker_id).wake(pid);
}
/// `wake_by_ref` entry of [`WAKER_VTABLE`].
///
/// The data is `Copy`, so waking by reference is the same as waking by value.
unsafe fn wake_by_ref(data: *const ()) {
    assert_copy::<WakerData>();
    // SAFETY: forwards the caller's guarantee about `data` unchanged.
    unsafe { wake(data) }
}
/// `drop` entry of [`WAKER_VTABLE`].
///
/// All waker state lives in the pointer value (`WakerData` is `Copy`), so there
/// is nothing to release.
unsafe fn drop_wake_data(data: *const ()) {
    assert_copy::<WakerData>();
    // Explicitly discard `data`. `drop(data)` on a `Copy` value triggers rustc's
    // `dropping_copy_types` lint, which the old `#[allow(clippy::drop_copy)]`
    // (a renamed clippy lint) does not suppress under plain rustc.
    let _ = data;
}
#[cfg(test)]
mod tests {
    use std::mem::size_of;
    use std::thread;
    use std::time::Duration;

    use crossbeam_channel::Receiver;
    use mio::{Events, Poll, Token, Waker};

    use crate::local::waker::{
        self as waker, WakerData, WakerId, MAX_THREADS, THREAD_BITS, THREAD_MASK,
    };
    use crate::ProcessId;

    const WAKER: Token = Token(0);
    const PID1: ProcessId = ProcessId(0);
    const PID2: ProcessId = ProcessId(1);

    /// Shared fixture: a `Poll` with a `mio::Waker` registered under `WAKER`.
    /// The mio waker is passed to `waker::init` together with an unbounded
    /// notification channel. Each call registers a new worker-thread slot.
    fn setup() -> (Poll, Events, WakerId, Receiver<ProcessId>) {
        let poll = Poll::new().unwrap();
        let events = Events::with_capacity(8);
        let mio_waker = Waker::new(poll.registry(), WAKER).unwrap();
        let (sender, receiver) = crossbeam_channel::unbounded();
        let waker_id = waker::init(mio_waker, sender);
        (poll, events, waker_id, receiver)
    }

    #[test]
    fn assert_waker_data_size() {
        assert_eq!(size_of::<WakerData>(), size_of::<*const ()>());
    }

    #[test]
    fn thread_bits_large_enough() {
        assert!(
            2_usize.pow(THREAD_BITS as u32) >= MAX_THREADS,
            "Not enough bits for MAX_THREADS"
        );
    }

    #[test]
    fn thread_mask() {
        let pid_bits = usize::MAX & !THREAD_MASK;
        assert!(
            pid_bits.leading_zeros() as usize == THREAD_BITS,
            "Incorrect THREAD_MASK"
        );
    }

    #[test]
    fn waker() {
        let (mut poll, mut events, waker_id, wake_receiver) = setup();

        // Smallest pid.
        let task_waker = waker::new(waker_id, PID1);
        waker::mark_polling(waker_id, true);
        task_waker.wake();
        poll.poll(&mut events, Some(Duration::from_secs(1))).unwrap();
        expect_one_waker_event(&mut events);
        assert_eq!(wake_receiver.try_recv(), Ok(PID1));
        waker::mark_polling(waker_id, false);

        // Largest pid that doesn't overlap the thread bits.
        let pid2 = ProcessId(usize::MAX & !THREAD_MASK);
        let task_waker = waker::new(waker_id, pid2);
        waker::mark_polling(waker_id, true);
        task_waker.wake();
        poll.poll(&mut events, Some(Duration::from_secs(1))).unwrap();
        expect_one_waker_event(&mut events);
        assert_eq!(wake_receiver.try_recv(), Ok(pid2));
    }

    #[test]
    fn waker_not_polling() {
        let (mut poll, mut events, waker_id, wake_receiver) = setup();

        let task_waker = waker::new(waker_id, PID1);
        waker::mark_polling(waker_id, false);
        task_waker.wake();
        poll.poll(&mut events, Some(Duration::from_millis(100))).unwrap();
        assert!(events.is_empty());
        assert_eq!(wake_receiver.try_recv(), Ok(PID1));
    }

    #[test]
    fn waker_single_mio_waker_call() {
        let (mut poll, mut events, waker_id, wake_receiver) = setup();

        let first = waker::new(waker_id, PID1);
        let second = waker::new(waker_id, PID2);
        waker::mark_polling(waker_id, true);
        first.wake_by_ref();
        first.wake();
        second.wake_by_ref();
        second.wake();
        poll.poll(&mut events, Some(Duration::from_secs(1))).unwrap();
        expect_one_waker_event(&mut events);
        assert_eq!(wake_receiver.try_recv(), Ok(PID1));
        assert_eq!(wake_receiver.try_recv(), Ok(PID1));
        assert_eq!(wake_receiver.try_recv(), Ok(PID2));
        assert_eq!(wake_receiver.try_recv(), Ok(PID2));
    }

    #[test]
    fn waker_different_thread() {
        let (mut poll, mut events, waker_id, wake_receiver) = setup();

        waker::mark_polling(waker_id, true);
        let task_waker = waker::new(waker_id, PID1);
        thread::spawn(move || task_waker.wake()).join().unwrap();
        poll.poll(&mut events, Some(Duration::from_secs(1))).unwrap();
        expect_one_waker_event(&mut events);
        assert_eq!(wake_receiver.try_recv(), Ok(PID1));
    }

    /// Asserts that `events` holds exactly one readable event for `WAKER`.
    fn expect_one_waker_event(events: &mut Events) {
        assert!(!events.is_empty());
        let mut remaining = events.iter();
        let event = remaining.next().unwrap();
        assert_eq!(event.token(), WAKER);
        assert!(event.is_readable());
        assert!(remaining.next().is_none(), "unexpected event");
    }
}