use std::sync::Arc;
use std::task::RawWaker;
use super::{RawTask, TaskId, gen_task_id, queue::LocalQueue};
/// A non-blocking, file-descriptor based wake-up signal.
///
/// Backed by an `eventfd` on Linux (both fields then hold the same descriptor)
/// and by a `pipe` on other targets. The descriptors are owned by this value and
/// closed when it is dropped.
pub(crate) struct WakeChannel {
    /// Descriptor that notifications are written to.
    write_fd: std::os::fd::RawFd,
    /// Descriptor that is read (or polled) for pending notifications.
    read_fd: std::os::fd::RawFd,
}
impl WakeChannel {
    /// Creates a new, empty wake channel.
    ///
    /// On Linux a single `eventfd` (non-blocking, close-on-exec) serves as both the
    /// read and the write end. On other targets a `pipe` is created and both of its
    /// ends are switched to non-blocking, close-on-exec mode.
    ///
    /// # Errors
    ///
    /// Returns the OS error if the descriptor(s) cannot be created or configured.
    /// On the pipe path, a configuration failure closes both ends before returning.
    pub(crate) fn new() -> std::io::Result<Self> {
        #[cfg(target_os = "linux")]
        {
            // SAFETY: `eventfd` takes no pointer arguments; the result is checked below.
            let fd = unsafe { libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK) };
            if fd < 0 {
                return Err(std::io::Error::last_os_error());
            }
            Ok(Self {
                read_fd: fd,
                write_fd: fd,
            })
        }
        #[cfg(not(target_os = "linux"))]
        {
            let mut fds: [libc::c_int; 2] = [-1; 2];
            // SAFETY: `fds` is a writable array of two `c_int`s, as `pipe(2)` requires.
            if unsafe { libc::pipe(fds.as_mut_ptr()) } < 0 {
                return Err(std::io::Error::last_os_error());
            }
            // Take ownership right away so `Drop` closes both ends if any of the
            // `fcntl` calls below fails.
            let channel = Self {
                read_fd: fds[0],
                write_fd: fds[1],
            };
            for fd in fds {
                // SAFETY: `fd` is an open descriptor owned by `channel`; these
                // `fcntl` commands take no pointer arguments.
                let configured = unsafe {
                    libc::fcntl(fd, libc::F_SETFD, libc::FD_CLOEXEC) >= 0 && {
                        // Preserve any existing status flags while adding O_NONBLOCK.
                        let flags = libc::fcntl(fd, libc::F_GETFL);
                        flags >= 0 && libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) >= 0
                    }
                };
                if !configured {
                    // The error is captured before `channel` is dropped (closing the fds).
                    return Err(std::io::Error::last_os_error());
                }
            }
            Ok(channel)
        }
    }

    /// Signals the channel, making [`raw_fd`](Self::raw_fd) readable and waking any
    /// caller blocked in [`recv_timeout`](Self::recv_timeout).
    ///
    /// Best-effort: a failed write (e.g. `EAGAIN` on a full pipe or a saturated
    /// eventfd counter) is ignored, because in those cases a wake-up is already pending.
    pub(crate) fn notify(&self) {
        #[cfg(target_os = "linux")]
        // SAFETY: `val` is a live 8-byte value; eventfd writes must be exactly 8 bytes.
        unsafe {
            let val: u64 = 1;
            libc::write(self.write_fd, &val as *const _ as *const _, 8);
        }
        #[cfg(not(target_os = "linux"))]
        // SAFETY: `val` is a live 1-byte value and we write exactly one byte.
        unsafe {
            let val: u8 = 1;
            libc::write(self.write_fd, &val as *const _ as *const _, 1);
        }
    }

    /// Consumes every pending notification without blocking.
    ///
    /// An eventfd read returns (and resets) its whole 8-byte counter, while a pipe
    /// may hold many 1-byte tokens. Reading into a 64-byte buffer until the
    /// descriptor reports empty handles both without one syscall per token.
    pub(crate) fn drain(&self) {
        let mut buf = [0u8; 64];
        loop {
            // SAFETY: `buf` is valid for writes of `buf.len()` bytes.
            let n = unsafe { libc::read(self.read_fd, buf.as_mut_ptr().cast(), buf.len()) };
            if n > 0 {
                continue;
            }
            // Retry if a signal interrupted the read; stop on EOF, `EAGAIN`
            // (nothing left) or any other error.
            if n < 0 && std::io::Error::last_os_error().kind() == std::io::ErrorKind::Interrupted {
                continue;
            }
            break;
        }
    }

    /// Waits up to `timeout` for a notification.
    ///
    /// Returns `true` (after draining the channel) if one arrived, or `false` on
    /// timeout or error. Uses `poll(2)` instead of `select(2)`, because an `fd_set`
    /// cannot represent descriptors `>= FD_SETSIZE` (typically 1024) and `FD_SET`
    /// on such a descriptor writes out of bounds. If a signal interrupts the wait
    /// (`EINTR`), the wait restarts with the remaining time instead of reporting a
    /// timeout.
    pub(crate) fn recv_timeout(&self, timeout: std::time::Duration) -> bool {
        let deadline = std::time::Instant::now().checked_add(timeout);
        let mut remaining = timeout;
        loop {
            let mut pfd = libc::pollfd {
                fd: self.read_fd,
                events: libc::POLLIN,
                revents: 0,
            };
            // SAFETY: `pfd` is a valid, writable `pollfd` and the count passed is exactly one.
            let n = unsafe { libc::poll(&mut pfd, 1, Self::poll_timeout_ms(remaining)) };
            if n > 0 {
                self.drain();
                return true;
            }
            if n == 0
                || std::io::Error::last_os_error().kind() != std::io::ErrorKind::Interrupted
            {
                return false;
            }
            // Interrupted by a signal: retry with whatever time is left.
            if let Some(deadline) = deadline {
                remaining = deadline.saturating_duration_since(std::time::Instant::now());
            }
        }
    }

    /// Converts `timeout` to a `poll(2)` timeout in milliseconds. It rounds up so
    /// the wait never ends before `timeout` and saturates at `c_int::MAX`.
    fn poll_timeout_ms(timeout: std::time::Duration) -> libc::c_int {
        let millis = (timeout.as_nanos() + 999_999) / 1_000_000;
        millis.min(libc::c_int::MAX as u128) as libc::c_int
    }

    /// Returns the descriptor to register with an external poller; it becomes
    /// readable after [`notify`](Self::notify).
    #[must_use]
    pub(crate) fn raw_fd(&self) -> std::os::fd::RawFd {
        self.read_fd
    }
}
impl Drop for WakeChannel {
    /// Releases the OS descriptor(s) backing the channel.
    ///
    /// On Linux both ends share one eventfd, so only `read_fd` is closed. On other
    /// targets the pipe's read end and then its write end are closed. Negative
    /// descriptors are skipped. `close` errors are ignored because `drop` cannot
    /// report them.
    fn drop(&mut self) {
        let release = |fd: std::os::fd::RawFd| {
            if fd < 0 {
                return;
            }
            // SAFETY: `fd` is owned by this channel and is closed exactly once here.
            unsafe {
                libc::close(fd);
            }
        };
        release(self.read_fd);
        #[cfg(not(target_os = "linux"))]
        {
            release(self.write_fd);
        }
    }
}
/// Cloneable handle used to submit tasks to a scheduler and to wake it.
///
/// Clones share the same queue and wake channel through `Arc`.
#[derive(Clone)]
pub struct SchedulerHandle {
    /// Shared run queue that `submit` pushes tasks onto.
    queue: Arc<LocalQueue>,
    /// Channel notified after each successful `submit` and by wakers from `waker()`.
    wake: Arc<WakeChannel>,
}
impl SchedulerHandle {
    /// Builds a handle from an existing queue and wake channel.
    pub(crate) fn new(queue: Arc<LocalQueue>, wake: Arc<WakeChannel>) -> Self {
        Self { queue, wake }
    }

    /// Builds a handle with a fresh `LocalQueue::new(256)` and a new wake channel.
    ///
    /// NOTE(review): 256 is presumably the queue capacity; confirm against `LocalQueue`.
    ///
    /// # Panics
    ///
    /// Panics if the OS refuses to create the wake channel's descriptor(s), for
    /// example because the process has exhausted its file-descriptor limit.
    pub fn new_default() -> Self {
        Self {
            queue: Arc::new(LocalQueue::new(256)),
            wake: Arc::new(
                WakeChannel::new()
                    .expect("SchedulerHandle::new_default: failed to create wake channel"),
            ),
        }
    }

    /// Pushes `task` onto the queue and, on success, notifies the wake channel.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::WouldBlock` if the queue rejects the task (it is full).
    /// In that case no notification is sent.
    pub fn submit(&self, task: RawTask) -> std::io::Result<()> {
        if self.queue.push(task) {
            self.wake.notify();
            Ok(())
        } else {
            Err(std::io::Error::new(std::io::ErrorKind::WouldBlock, "Scheduler queue is full"))
        }
    }

    /// Same as [`submit`](Self::submit); the task id is currently ignored.
    pub fn submit_with_id(&self, _task_id: TaskId, task: RawTask) -> std::io::Result<()> {
        self.submit(task)
    }

    /// Returns the wake channel's readable descriptor, for registration with a poller.
    #[must_use]
    pub fn wake_fd(&self) -> std::os::fd::RawFd {
        self.wake.raw_fd()
    }

    /// Clears pending wake notifications after the wake fd reported readable.
    pub fn handle_wake(&self) {
        self.wake.drain();
    }

    /// Allocates a new task id via `gen_task_id`.
    #[must_use]
    pub fn new_task_id(&self) -> TaskId {
        gen_task_id()
    }

    /// Returns a `Waker` that notifies this scheduler's wake channel when woken.
    pub fn waker(&self) -> std::task::Waker {
        let data = Arc::into_raw(Arc::new(self.clone())) as *const ();
        // SAFETY: `data` owns one strong reference to an `Arc<SchedulerHandle>`,
        // which is exactly what every `VTABLE` entry expects and eventually releases.
        unsafe { std::task::Waker::from_raw(RawWaker::new(data, &VTABLE)) }
    }

    /// Per-task wakers are not tracked; always returns `None`.
    pub fn get_task_waker(&self, _id: u64) -> Option<std::task::Waker> {
        None
    }
}
/// Vtable shared by every waker built in `SchedulerHandle::waker`.
///
/// The data pointer is always an `Arc<SchedulerHandle>` obtained from `Arc::into_raw`.
static VTABLE: std::task::RawWakerVTable = std::task::RawWakerVTable::new(
    clone_waker,
    wake,
    wake_by_ref,
    drop_waker,
);
unsafe fn clone_waker(data: *const ()) -> RawWaker {
let handle = Arc::from_raw(data as *const SchedulerHandle);
let ptr = Arc::into_raw(handle.clone()) as *const ();
RawWaker::new(ptr, &VTABLE)
}
unsafe fn wake(data: *const ()) {
let handle = Arc::from_raw(data as *const SchedulerHandle);
handle.wake.notify();
}
unsafe fn wake_by_ref(data: *const ()) {
let handle = &*(data as *const SchedulerHandle);
handle.wake.notify();
}
/// `RawWakerVTable::drop` entry: releases the waker's reference to the handle.
///
/// # Safety
///
/// `data` must come from `Arc::into_raw` for an `Arc<SchedulerHandle>` whose
/// reference is owned by the waker being dropped.
unsafe fn drop_waker(data: *const ()) {
    // SAFETY: reclaims exactly the one reference owned by the dropped waker.
    drop(unsafe { Arc::from_raw(data.cast::<SchedulerHandle>()) });
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Submitting through the handle places the task in the shared queue.
    #[test]
    fn test_handle_submit() {
        let shared = Arc::new(LocalQueue::new(16));
        let channel = Arc::new(WakeChannel::new().unwrap());
        let handle = SchedulerHandle::new(Arc::clone(&shared), channel);
        let task = 0x1000 as RawTask;
        assert!(handle.submit(task).is_ok());
        assert_eq!(shared.pop(), Some(task));
    }

    /// An idle channel times out, and a drained notification is not seen again.
    #[test]
    fn test_wake_channel_notify_and_drain() {
        let channel = WakeChannel::new().unwrap();
        assert!(channel.raw_fd() >= 0);

        let t0 = Instant::now();
        let idle = channel.recv_timeout(Duration::from_millis(5));
        assert!(!idle, "empty channel should not receive");
        assert!(t0.elapsed() >= Duration::from_millis(3));

        channel.notify();
        channel.drain();
        let after_drain = channel.recv_timeout(Duration::from_millis(5));
        assert!(!after_drain, "drained notification should not be received again");
    }

    /// Receiving succeeds after several back-to-back notifications.
    #[test]
    fn test_wake_channel_multiple_notify() {
        let channel = WakeChannel::new().unwrap();
        for _ in 0..3 {
            channel.notify();
        }
        assert!(
            channel.recv_timeout(Duration::from_millis(10)),
            "should receive after notify"
        );
        channel.drain();
    }

    /// Waiting on an idle channel returns false, and only after a delay.
    #[test]
    fn test_recv_timeout_no_notification() {
        let channel = WakeChannel::new().unwrap();
        let t0 = Instant::now();
        assert!(!channel.recv_timeout(Duration::from_millis(10)));
        assert!(t0.elapsed() >= Duration::from_millis(5));
    }

    /// A pending notification makes `recv_timeout` return true.
    #[test]
    fn test_recv_timeout_with_notification() {
        let channel = WakeChannel::new().unwrap();
        channel.notify();
        assert!(channel.recv_timeout(Duration::from_secs(1)));
    }
}