use super::UserData;
use io_uring::{opcode::PollAdd, squeue::SubmissionQueue, types::Fd};
use std::{
mem::size_of,
os::fd::{AsRawFd, FromRawFd, OwnedFd},
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use tracing::warn;
/// `user_data` tag attached to the eventfd poll SQE pushed by `Waker::reinstall`.
pub const WAKE_USER_DATA: UserData = UserData::MAX;
/// Bit 0 of the packed state word: set by `Waker::arm`, cleared by `Waker::disarm`.
const SLEEP_INTENT_BIT: u64 = 1 << 0;
/// Amount added to the packed state per published submission; the sequence sits above bit 0.
const SUBMISSION_INCREMENT: u64 = SLEEP_INTENT_BIT << 1;
/// Mask for a sequence number once the sleep-intent bit has been shifted out (63 usable bits).
pub const SUBMISSION_SEQ_MASK: u64 = !0u64 >> 1;
/// State shared by every clone of a [`Waker`].
struct WakerInner {
    /// Non-blocking eventfd: `ring` writes to it and `acknowledge` drains it.
    wake_fd: std::os::fd::OwnedFd,
    /// Packed word: bit 0 is the sleep-intent flag; bits 1..=63 hold the submission sequence.
    state: std::sync::atomic::AtomicU64,
}
#[derive(Clone)]
pub struct Waker {
inner: Arc<WakerInner>,
}
impl Waker {
pub fn new() -> Result<Self, std::io::Error> {
let fd = unsafe { libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK) };
if fd < 0 {
return Err(std::io::Error::last_os_error());
}
let wake_fd = unsafe { OwnedFd::from_raw_fd(fd) };
Ok(Self {
inner: Arc::new(WakerInner {
wake_fd,
state: AtomicU64::new(0),
}),
})
}
pub fn ring(&self) {
let value: u64 = 1;
loop {
let ret = unsafe {
libc::write(
self.inner.wake_fd.as_raw_fd(),
&value as *const u64 as *const libc::c_void,
size_of::<u64>(),
)
};
if ret == size_of::<u64>() as isize {
return;
}
assert_eq!(
ret, -1,
"eventfd write returned unexpected byte count: {ret}"
);
match std::io::Error::last_os_error().raw_os_error() {
Some(libc::EINTR) => continue,
Some(libc::EAGAIN) => return,
_ => {
warn!("eventfd write failed");
return;
}
}
}
}
pub fn publish(&self) {
let prev = self
.inner
.state
.fetch_add(SUBMISSION_INCREMENT, Ordering::Release);
if (prev & SLEEP_INTENT_BIT) != 0 {
self.ring();
}
}
pub fn submitted(&self) -> u64 {
(self.inner.state.load(Ordering::Acquire) >> 1) & SUBMISSION_SEQ_MASK
}
pub fn arm(&self) -> u64 {
let prev = self
.inner
.state
.fetch_or(SLEEP_INTENT_BIT, Ordering::Acquire);
(prev >> 1) & SUBMISSION_SEQ_MASK
}
pub fn disarm(&self) {
self.inner
.state
.fetch_and(!SLEEP_INTENT_BIT, Ordering::Release);
}
pub fn acknowledge(&self) {
let mut value: u64 = 0;
loop {
let ret = unsafe {
libc::read(
self.inner.wake_fd.as_raw_fd(),
&mut value as *mut u64 as *mut libc::c_void,
size_of::<u64>(),
)
};
if ret == size_of::<u64>() as isize {
return;
}
assert_eq!(
ret, -1,
"eventfd read returned unexpected byte count: {ret}"
);
match std::io::Error::last_os_error().raw_os_error() {
Some(libc::EINTR) => continue,
Some(libc::EAGAIN) => return,
_ => {
tracing::warn!("eventfd read failed");
return;
}
}
}
}
pub fn reinstall(&self, submission_queue: &mut SubmissionQueue<'_>) {
let wake_poll = PollAdd::new(Fd(self.inner.wake_fd.as_raw_fd()), libc::POLLIN as u32)
.multi(true)
.build()
.user_data(WAKE_USER_DATA);
unsafe {
submission_queue
.push(&wake_poll)
.expect("wake poll SQE should always fit in the ring");
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use io_uring::IoUring;
    use std::fs::File;

    /// Reads the eventfd counter directly, returning `None` when it is zero (`EAGAIN`).
    fn read_counter(waker: &Waker) -> Option<u64> {
        let mut value: u64 = 0;
        // SAFETY: `value` is a live, aligned, writable u64 and the length matches its size.
        let ret = unsafe {
            libc::read(
                waker.inner.wake_fd.as_raw_fd(),
                (&mut value as *mut u64).cast::<libc::c_void>(),
                size_of::<u64>(),
            )
        };
        (ret == size_of::<u64>() as isize).then_some(value)
    }

    #[test]
    fn test_publish_arm_disarm_and_submitted() {
        let waker = Waker::new().expect("eventfd creation should succeed");
        assert_eq!(waker.submitted(), 0);
        waker.publish();
        assert_eq!(waker.submitted(), 1);
        let snapshot = waker.arm();
        assert_eq!(snapshot, 1);
        waker.publish();
        assert_eq!(waker.submitted(), 2);
        waker.acknowledge();
        assert_eq!(waker.submitted(), 2);
        waker.disarm();
        assert_eq!(waker.submitted(), 2);
        assert_eq!(waker.arm(), 2);
        waker.disarm();
    }

    #[test]
    fn test_publish_rings_only_while_armed() {
        let waker = Waker::new().expect("eventfd creation should succeed");
        waker.publish();
        assert_eq!(read_counter(&waker), None);
        assert_eq!(waker.arm(), 1);
        waker.publish();
        waker.publish();
        assert_eq!(read_counter(&waker), Some(2));
        waker.disarm();
        waker.publish();
        assert_eq!(read_counter(&waker), None);
        assert_eq!(waker.submitted(), 4);
    }

    #[test]
    fn test_ring_and_acknowledge_empty_paths_keep_sequence_stable() {
        let waker = Waker::new().expect("eventfd creation should succeed");
        let before = waker.submitted();
        waker.ring();
        waker.acknowledge();
        waker.acknowledge();
        let after = waker.submitted();
        assert_eq!(after, before);
    }

    #[test]
    fn test_reinstall_pushes_wake_poll() {
        let waker = Waker::new().expect("eventfd creation should succeed");
        let mut ring = IoUring::new(8).expect("io_uring creation should succeed");
        let mut sq = ring.submission();
        let before = sq.len();
        waker.reinstall(&mut sq);
        assert_eq!(sq.len(), before + 1);
    }

    #[test]
    fn test_ring_and_acknowledge_error_branches() {
        let mut waker = Waker::new().expect("eventfd creation should succeed");
        let before = waker.submitted();

        // Saturate the counter (0xffff_ffff_ffff_fffe is the eventfd maximum) so that the
        // next ring hits EAGAIN, which must leave the counter untouched.
        let value = u64::MAX - 1;
        // SAFETY: `value` is a live, aligned u64 and the length matches its size.
        let wrote = unsafe {
            libc::write(
                waker.inner.wake_fd.as_raw_fd(),
                (&value as *const u64).cast::<libc::c_void>(),
                size_of::<u64>(),
            )
        };
        assert_eq!(wrote, size_of::<u64>() as isize);
        waker.ring();
        assert_eq!(read_counter(&waker), Some(u64::MAX - 1));
        // Counter is now zero: acknowledge takes the EAGAIN path.
        waker.acknowledge();

        // Swap in a read-only directory descriptor. write(2) then fails with EBADF and
        // read(2) with EISDIR, which exercises the logged-error branches. This avoids
        // closing a raw descriptor number that a concurrently running test could reuse.
        // Replacing the field drops (and properly closes) the original eventfd.
        let dir = OwnedFd::from(File::open("/").expect("opening / should succeed"));
        let inner = Arc::get_mut(&mut waker.inner).expect("unique waker in test");
        inner.wake_fd = dir;
        waker.ring();
        waker.acknowledge();

        assert_eq!(waker.submitted(), before);
    }
}