use crate::{IoBuf, IoBufMut, IoBufs};
use commonware_utils::channel::{
mpsc::{self, error::TryRecvError},
oneshot,
};
use io_uring::{
cqueue::Entry as CqueueEntry,
opcode::{LinkTimeout, PollAdd},
squeue::{Entry as SqueueEntry, SubmissionQueue},
types::{Fd, SubmitArgs, Timespec},
IoUring,
};
use prometheus_client::{metrics::gauge::Gauge, registry::Registry};
use std::{
fs::File,
mem::size_of,
os::fd::{AsRawFd, FromRawFd, OwnedFd},
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::{Duration, Instant},
};
use tracing::warn;
/// Sentinel `user_data` marking a CQE produced by a linked operation timeout
/// (see `fill_submission_queue`); such CQEs carry no waiter slot index.
const TIMEOUT_USER_DATA: u64 = u64::MAX;
/// Sentinel `user_data` marking a CQE from the multishot poll on the wake
/// eventfd (see `Waker::reinstall`).
const WAKE_USER_DATA: u64 = u64::MAX - 1;
/// Low bit of `WakerInner::state`: set while the event loop intends to sleep.
const SLEEP_INTENT_BIT: u64 = 1;
/// Amount added to `WakerInner::state` per published submission; adding 2
/// advances the sequence counter in the upper bits without touching the
/// sleep-intent bit.
const SUBMISSION_INCREMENT: u64 = 2;
/// Mask extracting the 63-bit submission sequence after shifting out the
/// sleep-intent bit.
const SUBMISSION_SEQ_MASK: u64 = u64::MAX >> 1;
/// Buffer whose ownership is transferred into the event loop for the lifetime
/// of an operation and returned to the caller with the operation's result.
#[derive(Debug)]
pub enum OpBuffer {
    /// Destination buffer for a read-style operation.
    Read(IoBufMut),
    /// Source buffer for a write-style operation.
    Write(IoBuf),
    /// Source buffers for a vectored write.
    WriteVectored(IoBufs),
}
impl From<IoBufMut> for OpBuffer {
    /// Wraps a mutable buffer as a read destination.
    fn from(buf: IoBufMut) -> Self {
        Self::Read(buf)
    }
}
impl From<IoBuf> for OpBuffer {
    /// Wraps an immutable buffer as a write source.
    fn from(buf: IoBuf) -> Self {
        Self::Write(buf)
    }
}
impl From<IoBufs> for OpBuffer {
    /// Wraps a buffer list as a vectored-write source.
    fn from(bufs: IoBufs) -> Self {
        Self::WriteVectored(bufs)
    }
}
/// Keeps the file descriptor backing an in-flight operation alive until its
/// CQE has been processed.
pub enum OpFd {
    /// A raw owned descriptor (used with the `iouring-network` feature).
    #[cfg_attr(not(feature = "iouring-network"), allow(dead_code))]
    Fd(#[allow(dead_code)] Arc<OwnedFd>),
    /// A file handle (used with the `iouring-storage` feature).
    #[cfg_attr(not(feature = "iouring-storage"), allow(dead_code))]
    File(#[allow(dead_code)] Arc<File>),
}
/// Owns the `iovec` array referenced by a vectored operation so the pointers
/// handed to the kernel stay valid while the SQE is in flight.
pub struct OpIovecs(#[allow(dead_code)] Box<[libc::iovec]>);
impl OpIovecs {
    /// Takes ownership of a boxed `iovec` array.
    pub const fn new(iovecs: Box<[libc::iovec]>) -> Self {
        Self(iovecs)
    }

    /// Pointer to the first `iovec`; valid for as long as `self` is alive.
    pub fn as_ptr(&self) -> *const libc::iovec {
        self.0.as_ptr()
    }
}
// SAFETY: `libc::iovec` contains raw pointers, making it `!Send` by default.
// The array is moved to the io_uring thread together with the `Op` that also
// carries the buffers (`OpBuffer`) those pointers refer to, so the pointees
// travel with the pointers. NOTE(review): this presumes every constructed
// iovec points into buffers owned by the same `Op` — confirm at call sites.
unsafe impl Send for OpIovecs {}
/// Prometheus metrics exported by the io_uring event loop.
#[derive(Debug)]
pub struct Metrics {
    // Number of operations submitted to the ring whose CQEs haven't yet been
    // processed (mirrors the waiter slab occupancy).
    pending_operations: Gauge,
}
impl Metrics {
    /// Creates the metrics and registers them under `registry`.
    pub fn new(registry: &mut Registry) -> Self {
        let pending_operations = Gauge::default();
        registry.register(
            "pending_operations",
            "Number of operations submitted to the io_uring whose CQEs haven't yet been processed",
            pending_operations.clone(),
        );
        Self { pending_operations }
    }
}
/// Configuration for an io_uring event loop.
#[derive(Clone, Debug)]
pub struct Config {
    /// Requested ring size; rounded up to the next power of two by
    /// `IoUringLoop::new`.
    pub size: u32,
    /// Enable kernel-side I/O polling (passed to the builder's
    /// `setup_iopoll`).
    pub io_poll: bool,
    /// Promise the kernel a single submitting task (builder's
    /// `setup_single_issuer` plus `setup_defer_taskrun`).
    pub single_issuer: bool,
    /// If set, each operation is linked to a timeout of this duration and
    /// completes with `-ETIMEDOUT` when it expires.
    pub op_timeout: Option<Duration>,
    /// If set, bounds how long shutdown waits for in-flight operations;
    /// `None` waits for all of them indefinitely.
    pub shutdown_timeout: Option<Duration>,
}
impl Default for Config {
    /// Defaults: 128-entry ring, no kernel polling, no single-issuer mode,
    /// and no operation or shutdown timeouts.
    fn default() -> Self {
        Self {
            size: 128,
            io_poll: false,
            single_issuer: false,
            op_timeout: None,
            shutdown_timeout: None,
        }
    }
}
/// A single operation submitted to the io_uring event loop.
pub struct Op {
    /// The prepared SQE; its `user_data` is overwritten with the waiter slot
    /// index by the event loop.
    pub work: SqueueEntry,
    /// Receives `(cqe_result, buffer)` once the operation completes.
    pub sender: oneshot::Sender<(i32, Option<OpBuffer>)>,
    /// Buffer referenced by `work`; held by the loop until completion, then
    /// returned through `sender`.
    pub buffer: Option<OpBuffer>,
    /// Descriptor referenced by `work`; kept alive until completion.
    pub fd: Option<OpFd>,
    /// iovec array referenced by `work` (vectored ops); kept alive until
    /// completion.
    pub iovecs: Option<OpIovecs>,
}
/// Shared state behind `Waker`.
struct WakerInner {
    // eventfd the event loop polls; written to wake a sleeping loop.
    wake_fd: OwnedFd,
    // Packed state: bit 0 is the sleep-intent flag (`SLEEP_INTENT_BIT`); the
    // upper 63 bits count published submissions (see `SUBMISSION_INCREMENT`).
    state: AtomicU64,
}
/// Wakes the io_uring event loop when new work is published while the loop is
/// (about to be) sleeping in the kernel.
#[derive(Clone)]
struct Waker {
    inner: Arc<WakerInner>,
}
impl Waker {
    /// Creates a waker backed by a fresh non-blocking, close-on-exec eventfd.
    ///
    /// # Errors
    /// Returns the OS error if `eventfd(2)` fails.
    fn new() -> Result<Self, std::io::Error> {
        let fd = unsafe { libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK) };
        if fd < 0 {
            return Err(std::io::Error::last_os_error());
        }
        // SAFETY: `fd` is a freshly created, valid eventfd not owned elsewhere.
        let wake_fd = unsafe { OwnedFd::from_raw_fd(fd) };
        Ok(Self {
            inner: Arc::new(WakerInner {
                wake_fd,
                state: AtomicU64::new(0),
            }),
        })
    }

    /// Writes to the eventfd so the loop's poll on it produces a completion.
    fn ring(&self) {
        let value: u64 = 1;
        loop {
            // SAFETY: the fd stays valid while `inner` is alive, and the
            // pointer/length describe the live `value` u64 on this stack.
            let ret = unsafe {
                libc::write(
                    self.inner.wake_fd.as_raw_fd(),
                    &value as *const u64 as *const libc::c_void,
                    size_of::<u64>(),
                )
            };
            if ret == size_of::<u64>() as isize {
                return;
            }
            if ret == -1 {
                match std::io::Error::last_os_error().raw_os_error() {
                    // Interrupted before writing anything: retry.
                    Some(libc::EINTR) => continue,
                    // Counter saturated: a wake-up is already pending.
                    Some(libc::EAGAIN) => return,
                    _ => {
                        warn!("eventfd write failed");
                        return;
                    }
                }
            }
            // eventfd writes are all-or-nothing; any other return is done.
            return;
        }
    }

    /// Records one published submission and rings the eventfd if the loop has
    /// declared intent to sleep (`arm`).
    fn publish(&self) {
        let prev = self
            .inner
            .state
            .fetch_add(SUBMISSION_INCREMENT, Ordering::Release);
        if (prev & SLEEP_INTENT_BIT) != 0 {
            self.ring();
        }
    }

    /// Current submission sequence number (the upper 63 bits of the state).
    fn submitted(&self) -> u64 {
        (self.inner.state.load(Ordering::Acquire) >> 1) & SUBMISSION_SEQ_MASK
    }

    /// Sets the sleep-intent bit and returns the sequence number observed at
    /// that instant; the loop may only sleep if it matches its processed
    /// count, otherwise a publish raced in.
    fn arm(&self) -> u64 {
        let prev = self
            .inner
            .state
            .fetch_or(SLEEP_INTENT_BIT, Ordering::Acquire);
        (prev >> 1) & SUBMISSION_SEQ_MASK
    }

    /// Clears the sleep-intent bit after the loop wakes (or declines to sleep).
    fn disarm(&self) {
        self.inner
            .state
            .fetch_and(!SLEEP_INTENT_BIT, Ordering::Release);
    }

    /// Drains the eventfd counter so a later `ring` produces a fresh poll
    /// completion.
    fn acknowledge(&self) {
        let mut value: u64 = 0;
        loop {
            // SAFETY: the fd stays valid while `inner` is alive, and the
            // pointer/length describe the live `value` u64 on this stack.
            let ret = unsafe {
                libc::read(
                    self.inner.wake_fd.as_raw_fd(),
                    &mut value as *mut u64 as *mut libc::c_void,
                    size_of::<u64>(),
                )
            };
            if ret == size_of::<u64>() as isize {
                return;
            }
            if ret == -1 {
                match std::io::Error::last_os_error().raw_os_error() {
                    // Interrupted before reading anything: retry.
                    Some(libc::EINTR) => continue,
                    // Counter already zero: nothing to drain.
                    Some(libc::EAGAIN) => return,
                    _ => {
                        // Use the imported macro for consistency with `ring`.
                        warn!("eventfd read failed");
                        return;
                    }
                }
            }
            return;
        }
    }

    /// Pushes a multishot `POLL_ADD` on the eventfd so the ring yields a CQE
    /// (tagged `WAKE_USER_DATA`) every time the waker is rung.
    fn reinstall(&self, submission_queue: &mut SubmissionQueue<'_>) {
        let wake_poll = PollAdd::new(Fd(self.inner.wake_fd.as_raw_fd()), libc::POLLIN as u32)
            .multi(true)
            .build()
            .user_data(WAKE_USER_DATA);
        // SAFETY: the polled fd is kept alive by the `Arc` in `inner`, which
        // the event loop holds for its entire lifetime.
        unsafe {
            submission_queue
                .push(&wake_poll)
                .expect("wake poll SQE should always fit in the ring");
        }
    }
}
/// Shared state behind `Submitter`; dropped when the last clone goes away,
/// which signals the event loop to shut down.
struct SubmitterInner {
    // Channel to the event loop; `None` only after `drop` has taken it.
    sender: Option<mpsc::Sender<Op>>,
    // Waker used to notify the loop of new work (and of shutdown).
    waker: Waker,
}
impl Drop for SubmitterInner {
    fn drop(&mut self) {
        // Close the channel first so the loop observes `Disconnected`, then
        // ring the waker in case the loop is sleeping in the kernel.
        drop(self.sender.take());
        self.waker.ring();
    }
}
/// Cloneable handle for submitting operations to the io_uring event loop.
#[derive(Clone)]
pub struct Submitter {
    inner: Arc<SubmitterInner>,
}
impl Submitter {
    /// Sends `op` to the event loop and wakes it if it is sleeping.
    ///
    /// # Errors
    /// Returns `SendError` if the event loop has shut down and dropped its
    /// receiver.
    pub async fn send(&self, op: Op) -> Result<(), mpsc::error::SendError<Op>> {
        self.inner
            .sender
            .as_ref()
            .expect("submitter sender is only taken on drop")
            .send(op)
            .await?;
        // Publish only after the op is enqueued, so the loop's wake-up never
        // arrives before the message it announces.
        self.inner.waker.publish();
        Ok(())
    }
}
/// Per-operation state retained by the event loop while the SQE is in flight.
struct Waiter {
    // Receives the CQE result (and the buffer back) on completion.
    sender: oneshot::Sender<(i32, Option<OpBuffer>)>,
    // Buffer referenced by the in-flight SQE; must outlive it.
    buffer: Option<OpBuffer>,
    // Descriptor referenced by the in-flight SQE; kept alive here.
    #[allow(dead_code)]
    fd: Option<OpFd>,
    // iovec array referenced by the in-flight SQE; kept alive here.
    #[allow(dead_code)]
    iovecs: Option<OpIovecs>,
    // Timespec referenced by the linked `LinkTimeout` SQE (when `op_timeout`
    // is set); stored here so the pointer stays valid until completion.
    timespec: Option<Timespec>,
}
/// Fixed-capacity slab of in-flight operations, indexed by SQE `user_data`.
struct Waiters {
    // Slot storage; `None` marks a vacant slot.
    entries: Vec<Option<Waiter>>,
    // Stack of vacant slot indices (built reversed so slot 0 pops first).
    free: Vec<usize>,
    // Number of occupied slots.
    len: usize,
}
impl Waiters {
    /// Creates a slab with exactly `capacity` slots, all initially vacant.
    fn new(capacity: usize) -> Self {
        let entries = (0..capacity).map(|_| None).collect();
        // Reversed so `pop` hands out the lowest index first.
        let free = (0..capacity).rev().collect();
        Self {
            entries,
            free,
            len: 0,
        }
    }

    /// Number of occupied slots.
    const fn len(&self) -> usize {
        self.len
    }

    /// True when no slot is occupied.
    const fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Borrows the waiter stored at `slot_index`.
    ///
    /// # Panics
    /// Panics if the slot is out of range or vacant.
    fn get(&self, slot_index: u64) -> &Waiter {
        let index = usize::try_from(slot_index).expect("slot index should fit in usize");
        self.entries
            .get(index)
            .expect("missing waiter")
            .as_ref()
            .expect("missing waiter")
    }

    /// Stores `waiter` in a vacant slot and returns its index.
    ///
    /// # Panics
    /// Panics if the slab is already full.
    fn insert(&mut self, waiter: Waiter) -> u64 {
        let index = self
            .free
            .pop()
            .expect("waiters should not exceed configured capacity");
        assert!(
            self.entries[index].replace(waiter).is_none(),
            "free slot should not contain waiter"
        );
        self.len += 1;
        index as u64
    }

    /// Takes the waiter out of `slot_index` and marks the slot vacant.
    ///
    /// # Panics
    /// Panics if the slot is out of range or vacant.
    fn remove(&mut self, slot_index: u64) -> Waiter {
        let index = usize::try_from(slot_index).expect("slot index should fit in usize");
        let waiter = self
            .entries
            .get_mut(index)
            .expect("missing waiter")
            .take()
            .expect("missing waiter");
        self.free.push(index);
        self.len -= 1;
        waiter
    }
}
/// Event loop that owns an io_uring instance and drives submitted operations
/// to completion; intended to run on a dedicated thread via `run`.
pub(crate) struct IoUringLoop {
    // Ring configuration (size already rounded up to a power of two).
    cfg: Config,
    metrics: Arc<Metrics>,
    // Incoming operations from `Submitter` clones.
    receiver: mpsc::Receiver<Op>,
    // In-flight operations, keyed by SQE `user_data`.
    waiters: Waiters,
    // Shared with submitters; signals new work while the loop sleeps.
    waker: Waker,
    // True when the multishot wake poll must be (re)pushed onto the ring.
    wake_rearm_needed: bool,
    // Count of ops drained from `receiver` (63-bit wrapping), compared with
    // `Waker::submitted` to decide whether the loop may sleep.
    processed_seq: u64,
}
impl IoUringLoop {
    /// Creates the event loop and its paired `Submitter`.
    ///
    /// Rounds `cfg.size` up to the next power of two and registers metrics in
    /// `registry`.
    ///
    /// # Panics
    /// Panics if the rounded size exceeds `u32::MAX` or the wake eventfd
    /// cannot be created.
    pub(crate) fn new(mut cfg: Config, registry: &mut Registry) -> (Submitter, Self) {
        cfg.size = cfg
            .size
            .checked_next_power_of_two()
            .expect("ring size exceeds u32::MAX");
        let size = cfg.size as usize;
        let metrics = Arc::new(Metrics::new(registry));
        // Channel capacity matches the waiter slab, so senders back-pressure
        // once the ring is saturated.
        let (sender, receiver) = mpsc::channel(size);
        let waker = Waker::new().expect("unable to create wake eventfd");
        let submitter = Submitter {
            inner: Arc::new(SubmitterInner {
                sender: Some(sender),
                waker: waker.clone(),
            }),
        };
        (
            submitter,
            Self {
                cfg,
                metrics,
                receiver,
                waiters: Waiters::new(size),
                waker,
                // Force the first loop iteration to install the wake poll.
                wake_rearm_needed: true,
                processed_seq: 0,
            },
        )
    }

    /// Runs the event loop until every submitter is dropped, then drains
    /// in-flight operations and returns. Blocks in the kernel while idle.
    pub(crate) fn run(mut self) {
        let mut ring = new_ring(&self.cfg).expect("unable to create io_uring instance");
        loop {
            // Process completions first to free waiter slots for new work.
            for cqe in ring.completion() {
                self.handle_cqe(cqe);
            }
            // `None` means all submitters disconnected: drain and exit.
            let Some(at_capacity) = self.fill_submission_queue(&mut ring) else {
                self.drain(&mut ring);
                return;
            };
            self.metrics.pending_operations.set(self.waiters.len() as _);
            if self.waker.submitted() != self.processed_seq {
                // More work has already been published. Only block if we
                // cannot accept it until a completion frees capacity.
                if at_capacity {
                    self.submit_and_wait(&mut ring, 1, None)
                        .expect("unable to submit to ring");
                }
                continue;
            }
            // Declare intent to sleep, then sleep only if no publish raced in
            // between the check above and arming; once the intent bit is
            // visible, publishers ring the eventfd, which completes the wake
            // poll and wakes us.
            if self.waker.arm() == self.processed_seq {
                self.submit_and_wait(&mut ring, 1, None)
                    .expect("unable to submit to ring");
            }
            self.waker.disarm();
        }
    }

    /// Moves queued operations from `receiver` into the submission queue.
    ///
    /// Returns `Some(true)` if the waiter slab or the submission queue is
    /// full, `Some(false)` otherwise, and `None` once all submitters have
    /// disconnected.
    fn fill_submission_queue(&mut self, ring: &mut IoUring) -> Option<bool> {
        let mut drained = 0u64;
        let mut submission_queue = ring.submission();
        let mut at_sq_capacity = false;
        if std::mem::take(&mut self.wake_rearm_needed) {
            self.waker.reinstall(&mut submission_queue);
        }
        while self.waiters.len() < self.cfg.size as usize {
            let available = submission_queue.capacity() - submission_queue.len();
            // With `op_timeout`, each op consumes two SQEs: the op itself
            // plus its linked timeout.
            let needed = if self.cfg.op_timeout.is_some() { 2 } else { 1 };
            if available < needed {
                at_sq_capacity = true;
                break;
            }
            let op = match self.receiver.try_recv() {
                Ok(work) => work,
                Err(TryRecvError::Disconnected) => return None,
                Err(TryRecvError::Empty) => break,
            };
            drained += 1;
            let Op {
                mut work,
                sender,
                buffer,
                fd,
                iovecs,
            } = op;
            let timespec = self.cfg.op_timeout.map(|timeout| {
                Timespec::new()
                    .sec(timeout.as_secs())
                    .nsec(timeout.subsec_nanos())
            });
            // The slot index doubles as the CQE's user_data.
            let slot_index = self.waiters.insert(Waiter {
                sender,
                buffer,
                fd,
                iovecs,
                timespec,
            });
            work = work.user_data(slot_index);
            if self.cfg.op_timeout.is_some() {
                // Link the op to the LinkTimeout SQE pushed right after it.
                work = work.flags(io_uring::squeue::Flags::IO_LINK);
            }
            // SAFETY: buffers/fds/iovecs referenced by `work` are stored in
            // the waiter slab and live until the CQE is handled.
            unsafe {
                submission_queue
                    .push(&work)
                    .expect("unable to push to queue");
            }
            if self.cfg.op_timeout.is_some() {
                // Borrow the timespec from the waiter slab so the pointer
                // handed to the kernel stays valid while the SQE is in flight.
                let timeout = LinkTimeout::new(
                    self.waiters
                        .get(slot_index)
                        .timespec
                        .as_ref()
                        .expect("missing timespec"),
                )
                .build()
                .user_data(TIMEOUT_USER_DATA);
                // SAFETY: the timespec outlives the SQE (see above).
                unsafe {
                    submission_queue
                        .push(&timeout)
                        .expect("unable to push timeout to queue");
                }
            }
        }
        // Advance the processed count within the 63-bit sequence space used
        // by `Waker`.
        self.processed_seq = self.processed_seq.wrapping_add(drained) & SUBMISSION_SEQ_MASK;
        let at_waiter_capacity = self.waiters.len() == self.cfg.size as usize;
        Some(at_waiter_capacity || at_sq_capacity)
    }

    /// Processes one completion: wake-poll bookkeeping, a linked-timeout CQE,
    /// or delivering a result to the waiting submitter.
    fn handle_cqe(&mut self, cqe: CqueueEntry) {
        let user_data = cqe.user_data();
        match user_data {
            WAKE_USER_DATA => {
                assert!(
                    cqe.result() >= 0,
                    "wake poll CQE failed: requires multishot poll (Linux 5.13+)"
                );
                self.waker.acknowledge();
                // Without the MORE flag the multishot poll was terminated by
                // the kernel and must be re-armed on the next iteration.
                if !io_uring::cqueue::more(cqe.flags()) {
                    self.wake_rearm_needed = true;
                }
            }
            TIMEOUT_USER_DATA => {
                // Linked-timeout CQEs carry no waiter slot; the linked op's
                // own CQE delivers the result, so this is just sanity-checked
                // and dropped.
                assert!(
                    self.cfg.op_timeout.is_some(),
                    "received TIMEOUT_USER_DATA with op_timeout disabled"
                );
            }
            _ => {
                let result = cqe.result();
                // An op canceled by its linked timeout surfaces as ETIMEDOUT
                // to the caller.
                let result = if result == -libc::ECANCELED && self.cfg.op_timeout.is_some() {
                    -libc::ETIMEDOUT
                } else {
                    result
                };
                let Waiter {
                    sender: result_sender,
                    buffer,
                    ..
                } = self.waiters.remove(user_data);
                // The receiver may have been dropped; ignore send failure.
                let _ = result_sender.send((result, buffer));
            }
        }
    }

    /// Waits for in-flight operations to complete during shutdown, bounded by
    /// `shutdown_timeout` when one is configured.
    fn drain(&mut self, ring: &mut IoUring) {
        let mut remaining = self.cfg.shutdown_timeout;
        while !self.waiters.is_empty() {
            if remaining.is_some_and(|t| t.is_zero()) {
                break;
            }
            let start = Instant::now();
            let got_completion = self
                .submit_and_wait(ring, 1, remaining)
                .expect("unable to submit to ring");
            for cqe in ring.completion() {
                self.handle_cqe(cqe);
            }
            // A timed-out wait means nothing more will complete in budget.
            if !got_completion {
                break;
            }
            // Charge the time spent waiting against the shutdown budget.
            if let Some(remaining) = remaining.as_mut() {
                *remaining = remaining.saturating_sub(start.elapsed());
            }
        }
        self.metrics.pending_operations.set(self.waiters.len() as _);
    }

    /// Submits pending SQEs and waits for at least `want` completions,
    /// optionally bounded by `timeout`.
    ///
    /// Returns `Ok(false)` only when a timed wait expired (`ETIME`).
    /// `EINTR`/`EAGAIN`/`EBUSY` are treated as spurious wakeups and map to
    /// `Ok(true)`.
    ///
    /// # Errors
    /// Propagates any other submission error.
    fn submit_and_wait(
        &self,
        ring: &mut IoUring,
        want: usize,
        timeout: Option<Duration>,
    ) -> Result<bool, std::io::Error> {
        let result = timeout.map_or_else(
            || ring.submit_and_wait(want).map(|_| true),
            |timeout| {
                let ts = Timespec::new()
                    .sec(timeout.as_secs())
                    .nsec(timeout.subsec_nanos());
                let args = SubmitArgs::new().timespec(&ts);
                match ring.submitter().submit_with_args(want, &args) {
                    Ok(_) => Ok(true),
                    Err(err) if err.raw_os_error() == Some(libc::ETIME) => Ok(false),
                    Err(err) => Err(err),
                }
            },
        );
        match result {
            Ok(v) => Ok(v),
            Err(err) => match err.raw_os_error() {
                Some(libc::EINTR | libc::EAGAIN | libc::EBUSY) => Ok(true),
                _ => Err(err),
            },
        }
    }
}
/// Builds an `IoUring` instance from `cfg`.
///
/// When `op_timeout` is enabled every operation consumes two SQEs (the op and
/// its linked timeout), so the ring is created at twice `cfg.size`.
///
/// # Errors
/// Returns the OS error if ring creation fails.
fn new_ring(cfg: &Config) -> Result<IoUring, std::io::Error> {
    let mut builder = IoUring::builder();
    if cfg.io_poll {
        builder.setup_iopoll();
    }
    if cfg.single_issuer {
        builder.setup_single_issuer();
        builder.setup_defer_taskrun();
    }
    let entries = match cfg.op_timeout {
        Some(_) => cfg.size.checked_mul(2).expect("ring size overflow"),
        None => cfg.size,
    };
    builder.build(entries)
}
/// Returns true when a CQE result is a transient errno (`-EAGAIN`,
/// `-EWOULDBLOCK`, or `-EINTR`) and the operation can be resubmitted.
pub const fn should_retry(return_value: i32) -> bool {
    return_value == -libc::EAGAIN
        || return_value == -libc::EWOULDBLOCK
        || return_value == -libc::EINTR
}
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_utils::channel::oneshot::{self, error::RecvError};
    use io_uring::{
        opcode,
        types::{Fd, Timespec},
    };
    use prometheus_client::registry::Registry;
    use std::{
        os::{fd::AsRawFd, unix::net::UnixStream},
        time::Duration,
    };

    /// Non-power-of-two sizes are rounded up and the waiter slab matches;
    /// exact powers of two are unchanged.
    #[test]
    fn test_iouring_loop_rounds_ring_size_up_to_power_of_two() {
        let mut registry = Registry::default();
        let cfg = Config {
            size: 1_000,
            ..Default::default()
        };
        let (_, iouring) = IoUringLoop::new(cfg, &mut registry);
        assert_eq!(iouring.cfg.size, 1_024);
        assert_eq!(iouring.waiters.entries.len(), 1_024);
        let cfg = Config {
            size: 1_024,
            ..Default::default()
        };
        let (_, iouring) = IoUringLoop::new(cfg, &mut registry);
        assert_eq!(iouring.cfg.size, 1_024);
        assert_eq!(iouring.waiters.entries.len(), 1_024);
    }

    /// Exercises slab insert/get/remove and verifies that freed slots are
    /// reused (index1 comes back for the third insert).
    #[test]
    fn test_waiters() {
        let mut waiters = Waiters::new(3);
        assert_eq!(waiters.len(), 0);
        assert!(waiters.is_empty());
        let (tx0, _rx0) = oneshot::channel();
        let (tx1, _rx1) = oneshot::channel();
        let index0 = waiters.insert(Waiter {
            sender: tx0,
            buffer: Some(IoBuf::from(b"hello").into()),
            fd: None,
            iovecs: None,
            timespec: None,
        });
        let index1 = waiters.insert(Waiter {
            sender: tx1,
            buffer: Some(IoBuf::from(b"world").into()),
            fd: None,
            iovecs: None,
            timespec: None,
        });
        // Slots are handed out lowest-index first.
        assert_eq!((index0, index1), (0, 1));
        assert_eq!(waiters.len(), 2);
        assert!(!waiters.is_empty());
        match waiters.get(index0).buffer.as_ref() {
            Some(OpBuffer::Write(buf)) => assert_eq!(buf.as_ref(), b"hello"),
            _ => panic!("expected write buffer"),
        }
        let waiter = waiters.remove(index1);
        match waiter.buffer {
            Some(OpBuffer::Write(buf)) => assert_eq!(buf.as_ref(), b"world"),
            _ => panic!("expected write buffer"),
        }
        assert_eq!(waiters.len(), 1);
        let (tx2, _rx2) = oneshot::channel();
        let index2 = waiters.insert(Waiter {
            sender: tx2,
            buffer: None,
            fd: None,
            iovecs: None,
            timespec: None,
        });
        // The freed slot is reused.
        assert_eq!(index2, index1);
        waiters.remove(index0);
        waiters.remove(index2);
        assert!(waiters.is_empty());
    }

    /// Removing from a vacant (never-filled) slot panics.
    #[test]
    #[should_panic(expected = "missing waiter")]
    fn test_waiters_remove_missing_slot_panics() {
        let mut waiters = Waiters::new(1);
        let _ = waiters.remove(0u64);
    }

    /// Inserting beyond the configured capacity panics.
    #[test]
    #[should_panic(expected = "waiters should not exceed configured capacity")]
    fn test_waiters_insert_full_panics() {
        let mut waiters = Waiters::new(1);
        let (tx0, _rx0) = oneshot::channel();
        let (tx1, _rx1) = oneshot::channel();
        let _ = waiters.insert(Waiter {
            sender: tx0,
            buffer: None,
            fd: None,
            iovecs: None,
            timespec: None,
        });
        let _ = waiters.insert(Waiter {
            sender: tx1,
            buffer: None,
            fd: None,
            iovecs: None,
            timespec: None,
        });
    }

    /// Submits a recv (which blocks until data arrives) and then the write
    /// that unblocks it; when `should_succeed`, both results are checked.
    async fn recv_then_send(cfg: Config, should_succeed: bool) {
        let mut registry = Registry::default();
        let (submitter, iouring) = IoUringLoop::new(cfg, &mut registry);
        let handle = std::thread::spawn(move || iouring.run());
        let (left_pipe, right_pipe) = UnixStream::pair().unwrap();
        let msg = IoBuf::from(b"hello");
        let mut buf = IoBufMut::with_capacity(msg.len());
        let recv =
            opcode::Recv::new(Fd(left_pipe.as_raw_fd()), buf.as_mut_ptr(), msg.len() as _).build();
        let (recv_tx, recv_rx) = oneshot::channel();
        submitter
            .send(Op {
                work: recv,
                sender: recv_tx,
                buffer: Some(buf.into()),
                fd: None,
                iovecs: None,
            })
            .await
            .expect("failed to send work");
        // The write is submitted while the loop may already be sleeping on
        // the recv, exercising the wake path.
        let write =
            opcode::Write::new(Fd(right_pipe.as_raw_fd()), msg.as_ptr(), msg.len() as _).build();
        let (write_tx, write_rx) = oneshot::channel();
        submitter
            .send(Op {
                work: write,
                sender: write_tx,
                buffer: Some(msg.into()),
                fd: None,
                iovecs: None,
            })
            .await
            .expect("failed to send work");
        if should_succeed {
            let (result, _) = recv_rx.await.expect("failed to receive result");
            assert!(result > 0, "recv failed: {result}");
            let (result, _) = write_rx.await.expect("failed to receive result");
            assert!(result > 0, "write failed: {result}");
        } else {
            let _ = recv_rx.await;
            let _ = write_rx.await;
        }
        drop(submitter);
        handle.join().unwrap();
    }

    /// A sleeping loop must be woken by a newly published op; completing
    /// within 2s proves the eventfd wake path makes progress.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_wake_path_makes_progress() {
        let timeout = tokio::time::timeout(
            Duration::from_secs(2),
            recv_then_send(Default::default(), true),
        );
        assert!(
            timeout.await.is_ok(),
            "recv_then_send timed out unexpectedly"
        );
    }

    /// A recv with no peer data expires via its linked timeout and surfaces
    /// as -ETIMEDOUT.
    #[tokio::test]
    async fn test_timeout() {
        let cfg = Config {
            op_timeout: Some(std::time::Duration::from_secs(1)),
            ..Default::default()
        };
        let mut registry = Registry::default();
        let (submitter, iouring) = IoUringLoop::new(cfg, &mut registry);
        let handle = std::thread::spawn(move || iouring.run());
        let (pipe_left, _pipe_right) = UnixStream::pair().unwrap();
        let mut buf = IoBufMut::with_capacity(8);
        let work = opcode::Recv::new(
            Fd(pipe_left.as_raw_fd()),
            buf.as_mut_ptr(),
            buf.capacity() as _,
        )
        .build();
        let (tx, rx) = oneshot::channel();
        submitter
            .send(Op {
                work,
                sender: tx,
                buffer: Some(buf.into()),
                fd: None,
                iovecs: None,
            })
            .await
            .expect("failed to send work");
        let (result, _) = rx.await.expect("failed to receive result");
        assert_eq!(result, -libc::ETIMEDOUT);
        drop(submitter);
        handle.join().unwrap();
    }

    /// Without a shutdown timeout, drain waits for the in-flight 3s timeout
    /// op to complete (with -ETIME) before the loop exits.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_shutdown_no_timeout() {
        let cfg = Config {
            shutdown_timeout: None,
            ..Default::default()
        };
        let mut registry = Registry::default();
        let (submitter, iouring) = IoUringLoop::new(cfg, &mut registry);
        let handle = std::thread::spawn(move || iouring.run());
        let timeout = Timespec::new().sec(3);
        let timeout = opcode::Timeout::new(&timeout).build();
        let (tx, rx) = oneshot::channel();
        submitter
            .send(Op {
                work: timeout,
                sender: tx,
                buffer: None,
                fd: None,
                iovecs: None,
            })
            .await
            .unwrap();
        drop(submitter);
        let (result, _) = rx.await.unwrap();
        assert_eq!(result, -libc::ETIME);
        handle.join().unwrap();
    }

    /// With a 1s shutdown timeout, a 5000s op is abandoned during drain; its
    /// waiter (and result sender) is dropped, so the receiver sees RecvError.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_shutdown_timeout() {
        let cfg = Config {
            shutdown_timeout: Some(Duration::from_secs(1)),
            ..Default::default()
        };
        let mut registry = Registry::default();
        let (submitter, iouring) = IoUringLoop::new(cfg, &mut registry);
        let handle = std::thread::spawn(move || iouring.run());
        let timeout = Timespec::new().sec(5_000);
        let timeout = opcode::Timeout::new(&timeout).build();
        let (tx, rx) = oneshot::channel();
        submitter
            .send(Op {
                work: timeout,
                sender: tx,
                buffer: None,
                fd: None,
                iovecs: None,
            })
            .await
            .unwrap();
        // Give the loop a moment to submit the op before shutting down.
        tokio::time::sleep(Duration::from_millis(10)).await;
        drop(submitter);
        let err = rx.await.unwrap_err();
        assert!(matches!(err, RecvError { .. }));
        handle.join().unwrap();
    }

    /// Floods more NOPs than the ring size with linked timeouts enabled,
    /// exercising the two-SQEs-per-op capacity accounting.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_linked_timeout_ensure_enough_capacity() {
        let cfg = Config {
            size: 8,
            op_timeout: Some(Duration::from_millis(5)),
            ..Default::default()
        };
        let mut registry = Registry::default();
        let (submitter, iouring) = IoUringLoop::new(cfg, &mut registry);
        let handle = std::thread::spawn(move || iouring.run());
        let total = 64usize;
        let mut rxs = Vec::with_capacity(total);
        for _ in 0..total {
            let nop = opcode::Nop::new().build();
            let (tx, rx) = oneshot::channel();
            submitter
                .send(Op {
                    work: nop,
                    sender: tx,
                    buffer: None,
                    fd: None,
                    iovecs: None,
                })
                .await
                .unwrap();
            rxs.push(rx);
        }
        for rx in rxs {
            let (res, _) = rx.await.unwrap();
            assert_eq!(res, 0, "NOP op failed: {res}");
        }
        drop(submitter);
        handle.join().unwrap();
    }

    /// Smoke test with single-issuer mode enabled: a NOP still completes.
    #[tokio::test]
    async fn test_single_issuer() {
        let cfg = Config {
            single_issuer: true,
            ..Default::default()
        };
        let mut registry = Registry::default();
        let (sender, iouring) = IoUringLoop::new(cfg, &mut registry);
        let uring_thread = std::thread::spawn(move || iouring.run());
        let (tx, rx) = oneshot::channel();
        sender
            .send(Op {
                work: opcode::Nop::new().build(),
                sender: tx,
                buffer: None,
                fd: None,
                iovecs: None,
            })
            .await
            .unwrap();
        let (result, _) = rx.await.unwrap();
        assert_eq!(result, 0);
        drop(sender);
        uring_thread.join().unwrap();
    }
}