use commonware_utils::StableBuf;
use futures::{
channel::{mpsc, oneshot},
StreamExt as _,
};
use io_uring::{
cqueue::Entry as CqueueEntry,
opcode::LinkTimeout,
squeue::Entry as SqueueEntry,
types::{SubmitArgs, Timespec},
IoUring,
};
use prometheus_client::{metrics::gauge::Gauge, registry::Registry};
use std::{collections::HashMap, sync::Arc, time::Duration};
/// Sentinel `user_data` value that marks a linked-timeout SQE rather than a
/// caller-submitted operation; real work IDs wrap around before reaching it.
const TIMEOUT_WORK_ID: u64 = u64::MAX;
/// Prometheus metrics exposed by the io_uring event loop.
#[derive(Debug)]
pub struct Metrics {
    // Operations submitted to the ring whose CQEs haven't yet been processed.
    pending_operations: Gauge,
}
impl Metrics {
    /// Creates the metric set and registers it with `registry`.
    pub fn new(registry: &mut Registry) -> Self {
        let pending_operations = Gauge::default();
        registry.register(
            "pending_operations",
            "Number of operations submitted to the io_uring whose CQEs haven't yet been processed",
            pending_operations.clone(),
        );
        Self { pending_operations }
    }
}
/// Configuration for the io_uring event loop.
#[derive(Clone, Debug)]
pub struct Config {
    /// Number of in-flight operations allowed; the ring is built with this
    /// many entries (doubled internally when `op_timeout` is set, to leave
    /// room for linked-timeout SQEs).
    pub size: u32,
    /// Enable kernel-side I/O polling (`setup_iopoll`).
    pub io_poll: bool,
    /// Declare a single submitting thread (`setup_single_issuer` plus
    /// `setup_defer_taskrun`).
    pub single_issuer: bool,
    /// If set, wake the event loop after this interval instead of blocking
    /// indefinitely for completions.
    pub force_poll: Option<Duration>,
    /// If set, attach a linked timeout to every operation; ops that exceed it
    /// complete with `-ETIMEDOUT`.
    pub op_timeout: Option<Duration>,
    /// If set, bound how long shutdown waits for in-flight operations before
    /// dropping their result channels.
    pub shutdown_timeout: Option<Duration>,
}
impl Default for Config {
fn default() -> Self {
Self {
size: 128,
io_poll: false,
single_issuer: false,
force_poll: None,
op_timeout: None,
shutdown_timeout: None,
}
}
}
/// Builds an `IoUring` configured according to `cfg`.
///
/// When `cfg.op_timeout` is set, every submitted operation is accompanied by
/// a linked-timeout SQE, so the ring is sized at twice `cfg.size` to leave
/// room for those extra entries.
///
/// # Errors
/// Returns the underlying `std::io::Error` if the kernel rejects the ring
/// parameters (e.g. unsupported setup flags or an out-of-range size).
fn new_ring(cfg: &Config) -> Result<IoUring, std::io::Error> {
    // Own the builder directly instead of holding a `&mut` to a temporary;
    // the setup methods take `&mut self`, so chaining isn't required.
    let mut builder = IoUring::builder();
    if cfg.io_poll {
        // Busy-poll for completions instead of relying on interrupts.
        builder.setup_iopoll();
    }
    if cfg.single_issuer {
        // Promise that a single thread submits; allows deferred task running.
        builder.setup_single_issuer();
        builder.setup_defer_taskrun();
    }
    let ring_size = if cfg.op_timeout.is_some() {
        // Each operation needs a second SQE for its linked timeout. Use a
        // checked multiply so a pathological `size` fails loudly rather than
        // silently wrapping in release builds.
        cfg.size.checked_mul(2).expect("ring size overflow")
    } else {
        cfg.size
    };
    builder.build(ring_size)
}
/// A unit of work submitted to the io_uring event loop.
pub struct Op {
    /// The submission-queue entry to run; its `user_data` is overwritten by
    /// the event loop with an internally assigned work ID.
    pub work: SqueueEntry,
    /// Channel on which the operation's raw result code (and buffer, if any)
    /// is delivered once the CQE arrives.
    pub sender: oneshot::Sender<(i32, Option<StableBuf>)>,
    /// Buffer kept alive for the duration of the operation and handed back to
    /// the caller on completion.
    pub buffer: Option<StableBuf>,
}
#[allow(clippy::type_complexity)]
fn handle_cqe(
waiters: &mut HashMap<u64, (oneshot::Sender<(i32, Option<StableBuf>)>, Option<StableBuf>)>,
cqe: CqueueEntry,
cfg: &Config,
) {
let work_id = cqe.user_data();
match work_id {
TIMEOUT_WORK_ID => {
assert!(
cfg.op_timeout.is_some(),
"received TIMEOUT_WORK_ID with op_timeout disabled"
);
}
_ => {
let result = cqe.result();
let result = if result == -libc::ECANCELED && cfg.op_timeout.is_some() {
-libc::ETIMEDOUT
} else {
result
};
let (result_sender, buffer) = waiters.remove(&work_id).expect("missing sender");
let _ = result_sender.send((result, buffer));
}
}
}
/// Event loop: receives operations on `receiver`, submits them to an io_uring
/// instance, and dispatches completions back to each operation's sender.
/// Runs until `receiver` closes, then drains outstanding work and returns.
pub(crate) async fn run(cfg: Config, metrics: Arc<Metrics>, mut receiver: mpsc::Receiver<Op>) {
    let mut ring = new_ring(&cfg).expect("unable to create io_uring instance");
    // Monotonically increasing ID used as each SQE's user_data; wraps to 0
    // before ever reaching the TIMEOUT_WORK_ID sentinel.
    let mut next_work_id: u64 = 0;
    #[allow(clippy::type_complexity)]
    let mut waiters: std::collections::HashMap<
        _,
        (oneshot::Sender<(i32, Option<StableBuf>)>, Option<StableBuf>),
    > = std::collections::HashMap::with_capacity(cfg.size as usize);
    loop {
        // Dispatch any completions that arrived since the last iteration.
        while let Some(cqe) = ring.completion().next() {
            handle_cqe(&mut waiters, cqe, &cfg);
        }
        // Accept new work, keeping at most `cfg.size` operations in flight.
        while waiters.len() < cfg.size as usize {
            let op = if waiters.is_empty() {
                // Nothing in flight: block until work arrives or the channel
                // closes.
                match receiver.next().await {
                    Some(work) => work,
                    None => {
                        drain(&mut ring, &mut waiters, &cfg);
                        return;
                    }
                }
            } else {
                // Work already in flight: take more only if it's immediately
                // available; otherwise go wait on completions below.
                match receiver.try_next() {
                    Ok(Some(work_item)) => work_item,
                    Ok(None) => {
                        // Channel closed: finish what's outstanding and exit.
                        drain(&mut ring, &mut waiters, &cfg);
                        return;
                    }
                    Err(_) => break,
                }
            };
            let Op {
                mut work,
                sender,
                buffer,
            } = op;
            // Assign a unique work ID, skipping the timeout sentinel.
            let work_id = next_work_id;
            next_work_id += 1;
            if next_work_id == TIMEOUT_WORK_ID {
                next_work_id = 0;
            }
            work = work.user_data(work_id);
            // Register the waiter (and its buffer, kept alive until the CQE
            // arrives) before pushing the SQE.
            waiters.insert(work_id, (sender, buffer));
            if let Some(timeout) = &cfg.op_timeout {
                // Link a timeout SQE to this op: if the timeout fires first,
                // the op is canceled and reported to the waiter as -ETIMEDOUT.
                work = work.flags(io_uring::squeue::Flags::IO_LINK);
                let timeout = Timespec::new()
                    .sec(timeout.as_secs())
                    .nsec(timeout.subsec_nanos());
                let timeout = LinkTimeout::new(&timeout)
                    .build()
                    .user_data(TIMEOUT_WORK_ID);
                unsafe {
                    let mut sq = ring.submission();
                    sq.push(&work).expect("unable to push to queue");
                    sq.push(&timeout).expect("unable to push timeout to queue");
                }
            } else {
                unsafe {
                    ring.submission()
                        .push(&work)
                        .expect("unable to push to queue");
                }
            }
        }
        metrics.pending_operations.set(waiters.len() as _);
        // Submit queued SQEs and wait for at least one completion (waking up
        // after `force_poll`, if configured, so new work can be accepted).
        submit_and_wait(&mut ring, 1, cfg.force_poll).expect("unable to submit to ring");
    }
}
#[allow(clippy::type_complexity)]
/// Flushes outstanding work during shutdown: submits everything still queued,
/// waits for the matching completions (bounded by `cfg.shutdown_timeout`, if
/// set), and dispatches each CQE to its waiter. Waiters left unsatisfied are
/// dropped, which cancels their oneshot receivers.
fn drain(
    ring: &mut IoUring,
    waiters: &mut HashMap<u64, (oneshot::Sender<(i32, Option<StableBuf>)>, Option<StableBuf>)>,
    cfg: &Config,
) {
    // With per-op timeouts enabled each waiter has two CQEs outstanding:
    // the operation itself plus its linked timeout.
    let cqes_per_op = if cfg.op_timeout.is_some() { 2 } else { 1 };
    let pending = waiters.len() * cqes_per_op;
    submit_and_wait(ring, pending, cfg.shutdown_timeout).expect("unable to submit to ring");
    loop {
        // Re-create the completion queue each pass so late arrivals are seen.
        let Some(cqe) = ring.completion().next() else {
            break;
        };
        handle_cqe(waiters, cqe, cfg);
    }
}
/// Submits queued SQEs and blocks until at least `want` completions arrive.
///
/// When `timeout` is set, waits at most that long: returns `Ok(false)` if the
/// deadline expired before `want` completions were available, `Ok(true)`
/// otherwise. Any other submission error is passed through unchanged.
fn submit_and_wait(
    ring: &mut IoUring,
    want: usize,
    timeout: Option<Duration>,
) -> Result<bool, std::io::Error> {
    let Some(timeout) = timeout else {
        // No deadline: block until `want` completions are ready.
        return ring.submit_and_wait(want).map(|_| true);
    };
    let ts = Timespec::new()
        .sec(timeout.as_secs())
        .nsec(timeout.subsec_nanos());
    let args = SubmitArgs::new().timespec(&ts);
    match ring.submitter().submit_with_args(want, &args) {
        // ETIME: the wait expired before enough completions arrived.
        Err(err) if err.raw_os_error() == Some(libc::ETIME) => Ok(false),
        Err(err) => Err(err),
        Ok(_) => Ok(true),
    }
}
/// Returns true when `return_value` (a negated errno, as io_uring reports
/// results) indicates a transient "try again" condition worth retrying.
pub fn should_retry(return_value: i32) -> bool {
    [-libc::EAGAIN, -libc::EWOULDBLOCK].contains(&return_value)
}
#[cfg(test)]
mod tests {
    use crate::iouring::{Config, Op};
    use futures::{
        channel::{
            mpsc::channel,
            oneshot::{self, Canceled},
        },
        executor::block_on,
        SinkExt as _,
    };
    use io_uring::{
        opcode,
        types::{Fd, Timespec},
    };
    use prometheus_client::registry::Registry;
    use std::{
        os::{fd::AsRawFd, unix::net::UnixStream},
        sync::Arc,
        time::Duration,
    };
    /// Submits a Recv that can only complete after a later Write is submitted
    /// on the other end of a socket pair. With `should_succeed` the helper
    /// asserts both ops finish; otherwise it just waits (callers wrap it in a
    /// timeout to show the loop never makes progress).
    async fn recv_then_send(cfg: Config, should_succeed: bool) {
        let (mut submitter, receiver) = channel(0);
        let metrics = Arc::new(super::Metrics::new(&mut Registry::default()));
        let handle = tokio::spawn(super::run(cfg, metrics.clone(), receiver));
        let (left_pipe, right_pipe) = UnixStream::pair().unwrap();
        let msg = b"hello".to_vec();
        let mut buf = vec![0; msg.len()];
        let recv =
            opcode::Recv::new(Fd(left_pipe.as_raw_fd()), buf.as_mut_ptr(), buf.len() as _).build();
        let (recv_tx, recv_rx) = oneshot::channel();
        submitter
            .send(crate::iouring::Op {
                work: recv,
                sender: recv_tx,
                buffer: Some(buf.into()),
            })
            .await
            .expect("failed to send work");
        // Wait until the event loop has actually submitted the Recv before
        // sending the Write, so the Recv is blocked in the ring first.
        while metrics.pending_operations.get() == 0 {
            tokio::time::sleep(Duration::from_millis(100)).await;
        }
        let write =
            opcode::Write::new(Fd(right_pipe.as_raw_fd()), msg.as_ptr(), msg.len() as _).build();
        let (write_tx, write_rx) = oneshot::channel();
        submitter
            .send(crate::iouring::Op {
                work: write,
                sender: write_tx,
                buffer: Some(msg.into()),
            })
            .await
            .expect("failed to send work");
        if should_succeed {
            let (result, _) = recv_rx.await.expect("failed to receive result");
            assert!(result > 0, "recv failed: {result}");
            let (result, _) = write_rx.await.expect("failed to receive result");
            assert!(result > 0, "write failed: {result}");
        } else {
            // Expected to hang; the caller enforces a timeout around us.
            let _ = recv_rx.await;
            let _ = write_rx.await;
        }
        drop(submitter);
        handle.await.unwrap();
    }
    /// With force_poll the loop periodically wakes, so the Write submitted
    /// after the blocking Recv is eventually picked up and both complete.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_force_poll_enabled() {
        let cfg = Config {
            force_poll: Some(Duration::from_millis(10)),
            ..Default::default()
        };
        recv_then_send(cfg, true).await;
    }
    /// Without force_poll the loop blocks on the Recv's completion and never
    /// accepts the unblocking Write, so the scenario must time out.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_force_poll_disabled() {
        let cfg = Config {
            force_poll: None,
            ..Default::default()
        };
        let timeout = tokio::time::timeout(Duration::from_secs(2), recv_then_send(cfg, false));
        assert!(
            timeout.await.is_err(),
            "recv_then_send completed unexpectedly"
        );
    }
    /// A Recv with no peer data and a 1s op_timeout should complete with
    /// -ETIMEDOUT via the linked-timeout cancellation path.
    #[tokio::test]
    async fn test_timeout() {
        let cfg = super::Config {
            op_timeout: Some(std::time::Duration::from_secs(1)),
            ..Default::default()
        };
        let (mut submitter, receiver) = channel(1);
        let metrics = Arc::new(super::Metrics::new(&mut Registry::default()));
        let handle = tokio::spawn(super::run(cfg, metrics, receiver));
        let (pipe_left, _pipe_right) = UnixStream::pair().unwrap();
        let mut buf = vec![0; 8];
        let work =
            opcode::Recv::new(Fd(pipe_left.as_raw_fd()), buf.as_mut_ptr(), buf.len() as _).build();
        let (tx, rx) = oneshot::channel();
        submitter
            .send(crate::iouring::Op {
                work,
                sender: tx,
                buffer: Some(buf.into()),
            })
            .await
            .expect("failed to send work");
        let (result, _) = rx.await.expect("failed to receive result");
        assert_eq!(result, -libc::ETIMEDOUT);
        drop(submitter);
        handle.await.unwrap();
    }
    /// With no shutdown_timeout, drain waits for the in-flight 3s Timeout op
    /// to complete (with -ETIME) before the loop exits.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_shutdown_no_timeout() {
        let cfg = super::Config {
            shutdown_timeout: None,
            ..Default::default()
        };
        let (mut submitter, receiver) = channel(1);
        let metrics = Arc::new(super::Metrics::new(&mut Registry::default()));
        let handle = tokio::spawn(super::run(cfg, metrics, receiver));
        let timeout = Timespec::new().sec(3);
        let timeout = opcode::Timeout::new(&timeout).build();
        let (tx, rx) = oneshot::channel();
        submitter
            .send(Op {
                work: timeout,
                sender: tx,
                buffer: None,
            })
            .await
            .unwrap();
        drop(submitter);
        let (result, _) = rx.await.unwrap();
        assert_eq!(result, -libc::ETIME);
        handle.await.unwrap();
    }
    /// With a 1s shutdown_timeout, drain gives up on the 5000s Timeout op and
    /// drops its sender, so the receiver observes Canceled.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_shutdown_timeout() {
        let cfg = super::Config {
            shutdown_timeout: Some(Duration::from_secs(1)),
            ..Default::default()
        };
        let (mut submitter, receiver) = channel(1);
        let metrics = Arc::new(super::Metrics::new(&mut Registry::default()));
        let handle = tokio::spawn(super::run(cfg, metrics, receiver));
        let timeout = Timespec::new().sec(5_000);
        let timeout = opcode::Timeout::new(&timeout).build();
        let (tx, rx) = oneshot::channel();
        submitter
            .send(Op {
                work: timeout,
                sender: tx,
                buffer: None,
            })
            .await
            .unwrap();
        drop(submitter);
        let err = rx.await.unwrap_err();
        assert!(matches!(err, Canceled { .. }));
        handle.await.unwrap();
    }
    /// Floods a small (size=8) ring configured with op_timeout — which makes
    /// every op occupy two SQEs — with 64 NOPs to show the ring is sized with
    /// enough capacity for the linked timeouts and all ops still succeed.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_linked_timeout_ensure_enough_capacity() {
        let cfg = super::Config {
            size: 8,
            op_timeout: Some(Duration::from_millis(5)),
            ..Default::default()
        };
        let (mut submitter, receiver) = channel(8);
        let metrics = Arc::new(super::Metrics::new(&mut Registry::default()));
        let handle = tokio::spawn(super::run(cfg, metrics, receiver));
        let total = 64usize;
        let mut rxs = Vec::with_capacity(total);
        for _ in 0..total {
            let nop = opcode::Nop::new().build();
            let (tx, rx) = oneshot::channel();
            submitter
                .send(Op {
                    work: nop,
                    sender: tx,
                    buffer: None,
                })
                .await
                .unwrap();
            rxs.push(rx);
        }
        for rx in rxs {
            let (res, _) = rx.await.unwrap();
            assert_eq!(res, 0, "NOP op failed: {res}");
        }
        drop(submitter);
        handle.await.unwrap();
    }
    /// Runs the event loop on a dedicated thread with single_issuer enabled,
    /// so all ring submissions come from exactly one thread as required.
    #[tokio::test]
    async fn test_single_issuer() {
        let cfg = super::Config {
            single_issuer: true,
            ..Default::default()
        };
        let (mut sender, receiver) = channel(1);
        let metrics = Arc::new(super::Metrics::new(&mut Registry::default()));
        let uring_thread = std::thread::spawn(move || block_on(super::run(cfg, metrics, receiver)));
        let (tx, rx) = oneshot::channel();
        sender
            .send(Op {
                work: opcode::Nop::new().build(),
                sender: tx,
                buffer: None,
            })
            .await
            .unwrap();
        let (result, _) = rx.await.unwrap();
        assert_eq!(result, 0);
        drop(sender);
        uring_thread.join().unwrap();
    }
}