#![cfg(all(target_os = "linux", feature = "async"))]
#![allow(dead_code)]
use crate::{Error, Result};
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::io::unix::AsyncFd;
use tokio::sync::{mpsc, oneshot};
/// A request sent by `AsyncIoUring::submit` to the owner task that drives the ring.
///
/// Each I/O variant carries a oneshot `reply`. The owner task delivers the raw CQE
/// result on it: bytes transferred, or a negative errno.
///
/// `buf_ptr` / `buf_len` describe a caller-owned buffer by raw address. The kernel
/// accesses it asynchronously, so it must stay valid (and, for `Read`, writable) until
/// the reply arrives. NOTE(review): nothing in this type enforces that.
pub(crate) enum Op {
/// Positional write of `buf_len` bytes from `buf_ptr` to `fd` at `offset`.
Write {
fd: RawFd,
/// Address of the source buffer, stored as an integer.
buf_ptr: usize,
buf_len: usize,
offset: u64,
reply: oneshot::Sender<i32>,
},
/// Positional read of up to `buf_len` bytes from `fd` at `offset` into `buf_ptr`.
Read {
fd: RawFd,
/// Address of the destination buffer, stored as an integer.
buf_ptr: usize,
buf_len: usize,
offset: u64,
reply: oneshot::Sender<i32>,
},
/// Data-only flush of `fd`: an fsync SQE with the DATASYNC flag.
Fdatasync {
fd: RawFd,
reply: oneshot::Sender<i32>,
},
/// Asks the owner task to stop. Carries no reply channel.
Shutdown,
}
/// Handle to an io_uring instance that is owned and driven by a dedicated Tokio task.
///
/// `submit` sends an [`Op`] to that task over an unbounded channel and awaits the op's
/// oneshot reply. The task pushes and submits SQEs and routes CQE results back.
pub(crate) struct AsyncIoUring {
/// Sending half of the channel to the owner task.
submit_tx: mpsc::UnboundedSender<Op>,
/// Set by `shutdown`; later `submit` calls fail fast with `CompletionDriverDead`.
shutdown: AtomicBool,
/// Set when `submit` observes the driver failing: the channel is closed, the reply is
/// dropped, or the `i32::MIN` sentinel arrives. Later `submit` calls then fail fast
/// with `HandlePoisoned`.
poisoned: Arc<AtomicBool>,
/// Join handle of the owner task. `shutdown` takes it and awaits it; `Drop` takes it
/// and aborts it.
join: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
}
impl AsyncIoUring {
pub(crate) fn new(queue_depth: u32) -> Result<Self> {
let mut probe_builder = io_uring::IoUring::builder();
crate::platform::iouring_features::apply(
&mut probe_builder,
crate::platform::iouring_features::RingMode::Async,
);
match probe_builder.build(queue_depth) {
Ok(_probe) => {}
Err(source) => return Err(Error::IoUringSetupFailed { source }),
}
let eventfd = create_eventfd()?;
let eventfd_raw = eventfd.into_raw_fd();
let (tx, rx) = mpsc::unbounded_channel::<Op>();
let poisoned = Arc::new(AtomicBool::new(false));
let join = tokio::task::spawn(async move {
owner_main(queue_depth, eventfd_raw, rx).await;
});
Ok(Self {
submit_tx: tx,
shutdown: AtomicBool::new(false),
poisoned,
join: std::sync::Mutex::new(Some(join)),
})
}
pub(crate) fn is_poisoned(&self) -> bool {
self.poisoned.load(Ordering::Acquire)
}
pub(crate) async fn submit(&self, op: Op, reply: oneshot::Receiver<i32>) -> Result<i32> {
if self.is_poisoned() {
return Err(Error::HandlePoisoned {
reason: "io_uring completion driver panicked".to_string(),
});
}
if self.shutdown.load(Ordering::Acquire) {
return Err(Error::CompletionDriverDead);
}
if self.submit_tx.send(op).is_err() {
self.poisoned.store(true, Ordering::Release);
return Err(Error::CompletionDriverDead);
}
match reply.await {
Ok(code) if code == i32::MIN => {
self.poisoned.store(true, Ordering::Release);
Err(Error::HandlePoisoned {
reason: "io_uring completion driver panicked mid-op".to_string(),
})
}
Ok(code) => Ok(code),
Err(_recv_err) => {
self.poisoned.store(true, Ordering::Release);
Err(Error::HandlePoisoned {
reason: "io_uring completion driver dropped sender".to_string(),
})
}
}
}
pub(crate) async fn shutdown(&self) {
self.shutdown.store(true, Ordering::Release);
let _ = self.submit_tx.send(Op::Shutdown);
let join_opt = match self.join.lock() {
Ok(mut g) => g.take(),
Err(p) => p.into_inner().take(),
};
if let Some(join) = join_opt {
let abort_handle = join.abort_handle();
if tokio::time::timeout(std::time::Duration::from_secs(5), join)
.await
.is_err()
{
abort_handle.abort();
}
}
}
}
impl Drop for AsyncIoUring {
    /// Aborts the owner task if `shutdown` has not already taken its join handle.
    ///
    /// A poisoned `join` mutex is recovered rather than skipped. Otherwise a panic
    /// elsewhere while holding the lock would leave the owner task, and its ring and
    /// eventfd, running detached forever.
    fn drop(&mut self) {
        let handle = self
            .join
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .take();
        if let Some(handle) = handle {
            handle.abort();
        }
    }
}
/// Body of the owner task spawned by `AsyncIoUring::new`.
///
/// Simply runs the ring event loop (`owner_loop`) to completion.
async fn owner_main(queue_depth: u32, eventfd_raw: RawFd, rx: mpsc::UnboundedReceiver<Op>) {
    let event_loop = owner_loop(queue_depth, eventfd_raw, rx);
    event_loop.await;
}
async fn owner_loop(queue_depth: u32, eventfd_raw: RawFd, mut rx: mpsc::UnboundedReceiver<Op>) {
use std::collections::HashMap;
let owned_fd = unsafe { OwnedFd::from_raw_fd(eventfd_raw) };
let mut builder = io_uring::IoUring::builder();
crate::platform::iouring_features::apply(
&mut builder,
crate::platform::iouring_features::RingMode::Async,
);
let mut ring = match builder.build(queue_depth) {
Ok(r) => r,
Err(_) => return, };
let mut fd_registry = FdRegistry::new();
let _ = fd_registry.initial_register(&ring.submitter());
if register_eventfd_with_ring(&mut ring, owned_fd.as_raw_fd()).is_err() {
return; }
let async_fd = match AsyncFd::with_interest(owned_fd, tokio::io::Interest::READABLE) {
Ok(f) => f,
Err(_) => return,
};
let mut pending: HashMap<u64, oneshot::Sender<i32>> = HashMap::new();
let mut next_id: u64 = 1;
loop {
tokio::select! {
biased;
maybe_op = rx.recv() => {
match maybe_op {
Some(Op::Shutdown) | None => {
drain_completions_into(&mut ring, &mut pending);
return;
}
Some(op) => {
let id = next_id;
next_id = next_id.wrapping_add(1);
if id == 0 { next_id = 1; } push_sqe_for(&mut ring, id, &op, &mut fd_registry);
match op {
Op::Write { reply, .. }
| Op::Read { reply, .. }
| Op::Fdatasync { reply, .. } => {
let _ = pending.insert(id, reply);
}
Op::Shutdown => {}
}
let _ = ring.submit();
}
}
}
ready_result = async_fd.readable() => {
let mut ready_guard = match ready_result {
Ok(g) => g,
Err(_) => continue,
};
clear_eventfd(async_fd.get_ref().as_raw_fd());
drain_completions_into(&mut ring, &mut pending);
ready_guard.clear_ready();
}
}
}
}
/// Reaps every CQE currently in the completion queue.
///
/// Each CQE's result is sent to the waiter registered under its `user_data`. CQEs with
/// no registered waiter are discarded, and send failures (the waiter went away) are
/// ignored.
fn drain_completions_into(
    ring: &mut io_uring::IoUring,
    pending: &mut std::collections::HashMap<u64, oneshot::Sender<i32>>,
) {
    // A fresh completion-queue view is taken for every entry, so CQEs posted while
    // draining are picked up as well.
    for entry in std::iter::from_fn(|| ring.completion().next()) {
        let (token, outcome) = (entry.user_data(), entry.result());
        if let Some(waiter) = pending.remove(&token) {
            let _ = waiter.send(outcome);
        }
    }
}
struct FdRegistry {
slots: Vec<RawFd>,
fd_to_slot: std::collections::HashMap<RawFd, u32>,
registered: bool,
}
const SLOT_TABLE_SIZE: usize = 16;
impl FdRegistry {
fn new() -> Self {
Self {
slots: vec![-1; SLOT_TABLE_SIZE],
fd_to_slot: std::collections::HashMap::new(),
registered: false,
}
}
fn initial_register(&mut self, submitter: &io_uring::Submitter<'_>) -> std::io::Result<()> {
submitter.register_files(&self.slots)?;
self.registered = true;
Ok(())
}
fn try_get_or_register(
&mut self,
submitter: &io_uring::Submitter<'_>,
fd: RawFd,
) -> Option<u32> {
if !self.registered {
return None;
}
if let Some(&slot) = self.fd_to_slot.get(&fd) {
return Some(slot);
}
let slot_idx = self.slots.iter().position(|&s| s == -1)?;
let update = [fd];
let updated = submitter
.register_files_update(slot_idx as u32, &update)
.ok()?;
if updated == 0 {
return None;
}
self.slots[slot_idx] = fd;
let _ = self.fd_to_slot.insert(fd, slot_idx as u32);
Some(slot_idx as u32)
}
}
/// Builds the SQE for `op`, tags it with `id` as `user_data`, and pushes it onto the
/// submission queue.
///
/// A registered fixed-file slot is used for the fd when one is available. The SQE is
/// normally left for the caller to `submit`; it is submitted here only if the queue was
/// full and had to be flushed to make room.
///
/// Returns `false` if an SQE was needed but could not be queued even after that flush.
/// `Op::Shutdown` needs no SQE and returns `true`.
///
/// Buffer lengths above `u32::MAX` are clamped, producing a short transfer, instead of
/// wrapping. Wrapping could turn exactly 4 GiB into a 0-length op that looks like EOF.
///
/// NOTE(review): for `Write`/`Read`, the kernel accesses `buf_ptr..buf_ptr + buf_len`
/// asynchronously. Whoever created the op must keep that memory valid (and writable for
/// reads) until the completion is delivered; nothing here can check that.
fn push_sqe_for(
    ring: &mut io_uring::IoUring,
    id: u64,
    op: &Op,
    fd_registry: &mut FdRegistry,
) -> bool {
    use io_uring::{opcode, types};

    let entry = match op {
        Op::Write {
            fd,
            buf_ptr,
            buf_len,
            offset,
            ..
        } => {
            let (buf, len) = (*buf_ptr as *const u8, clamp_io_len(*buf_len));
            match fd_registry.try_get_or_register(&ring.submitter(), *fd) {
                Some(slot) => opcode::Write::new(types::Fixed(slot), buf, len)
                    .offset(*offset)
                    .build(),
                None => opcode::Write::new(types::Fd(*fd), buf, len)
                    .offset(*offset)
                    .build(),
            }
        }
        Op::Read {
            fd,
            buf_ptr,
            buf_len,
            offset,
            ..
        } => {
            let (buf, len) = (*buf_ptr as *mut u8, clamp_io_len(*buf_len));
            match fd_registry.try_get_or_register(&ring.submitter(), *fd) {
                Some(slot) => opcode::Read::new(types::Fixed(slot), buf, len)
                    .offset(*offset)
                    .build(),
                None => opcode::Read::new(types::Fd(*fd), buf, len)
                    .offset(*offset)
                    .build(),
            }
        }
        Op::Fdatasync { fd, .. } => match fd_registry.try_get_or_register(&ring.submitter(), *fd)
        {
            Some(slot) => opcode::Fsync::new(types::Fixed(slot))
                .flags(types::FsyncFlags::DATASYNC)
                .build(),
            None => opcode::Fsync::new(types::Fd(*fd))
                .flags(types::FsyncFlags::DATASYNC)
                .build(),
        },
        Op::Shutdown => return true,
    };
    push_entry(ring, &entry.user_data(id))
}

/// Converts a buffer length to the `u32` an SQE carries, saturating at `u32::MAX`.
fn clamp_io_len(len: usize) -> u32 {
    u32::try_from(len).unwrap_or(u32::MAX)
}

/// Pushes `entry` onto the submission queue.
///
/// If the queue is full, it is flushed to the kernel once and the push is retried.
/// Returns whether the entry was queued.
fn push_entry(ring: &mut io_uring::IoUring, entry: &io_uring::squeue::Entry) -> bool {
    // SAFETY: pushing copies the entry into the SQ. The validity of the fd and buffer it
    // references is the documented contract of `push_sqe_for`'s caller.
    let first = unsafe { ring.submission().push(entry) };
    if first.is_ok() {
        return true;
    }
    // Queue full: hand the queued SQEs to the kernel to free space, then retry once.
    let _ = ring.submit();
    // SAFETY: same as above.
    let retry = unsafe { ring.submission().push(entry) };
    retry.is_ok()
}
fn create_eventfd() -> Result<OwnedFd> {
let fd = unsafe { libc::eventfd(0, libc::EFD_NONBLOCK | libc::EFD_CLOEXEC) };
if fd < 0 {
return Err(Error::Io(std::io::Error::last_os_error()));
}
Ok(unsafe { OwnedFd::from_raw_fd(fd) })
}
/// Registers `eventfd_raw` with `ring` so the eventfd is notified of completion events.
fn register_eventfd_with_ring(
    ring: &mut io_uring::IoUring,
    eventfd_raw: RawFd,
) -> std::io::Result<()> {
    let submitter = ring.submitter();
    submitter.register_eventfd(eventfd_raw)
}
/// Resets an eventfd's counter by reading its 8-byte value (best effort).
///
/// The outcome of the read is ignored. For an eventfd created non-blocking (as
/// `create_eventfd` does), a zero counter just makes the read fail with `EAGAIN`.
fn clear_eventfd(fd: RawFd) {
    let mut counter = 0u64;
    // SAFETY: `counter` is a live, writable 8-byte buffer for the duration of the call.
    let _bytes_read = unsafe {
        libc::read(
            fd,
            std::ptr::addr_of_mut!(counter).cast::<libc::c_void>(),
            std::mem::size_of::<u64>(),
        )
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small ring, or returns `None` when io_uring is unavailable on this host
    /// (old kernel, sandbox, ...), so the calling test can skip rather than fail.
    fn ring_or_skip() -> Option<AsyncIoUring> {
        AsyncIoUring::new(8).ok()
    }

    /// Awaits `fut`, panicking if it takes longer than a generous timeout.
    ///
    /// This makes a hung completion driver fail the test instead of wedging the run.
    async fn with_timeout<F, T>(fut: F) -> T
    where
        F: std::future::Future<Output = T>,
    {
        const TIMEOUT_SECS: u64 = 15;
        match tokio::time::timeout(std::time::Duration::from_secs(TIMEOUT_SECS), fut).await {
            Ok(v) => v,
            Err(_) => panic!(
                "test exceeded {TIMEOUT_SECS}s timeout — likely a hang in the completion driver"
            ),
        }
    }

    /// Aborts the owner task, if its join handle is still held, and waits for it to end.
    ///
    /// The handle is taken out under the mutex and the guard is released before awaiting,
    /// so no `std::sync::Mutex` guard is held across an `.await`.
    async fn abort_owner(ring: &AsyncIoUring) {
        let handle = ring.join.lock().expect("ring.join mutex poisoned").take();
        if let Some(handle) = handle {
            handle.abort();
            let _ = handle.await;
        }
    }

    #[tokio::test]
    async fn construction_returns_or_skips() {
        with_timeout(async {
            let _ring = ring_or_skip();
        })
        .await;
    }

    #[tokio::test]
    async fn shutdown_is_clean() {
        with_timeout(async {
            let Some(ring) = ring_or_skip() else { return };
            ring.shutdown().await;
            let (rt, rr) = oneshot::channel();
            let result = ring.submit(Op::Fdatasync { fd: -1, reply: rt }, rr).await;
            assert!(matches!(result, Err(Error::CompletionDriverDead)));
        })
        .await;
    }

    #[tokio::test]
    async fn poisoned_flag_short_circuits_submit() {
        with_timeout(async {
            let Some(ring) = ring_or_skip() else { return };
            ring.poisoned.store(true, Ordering::Release);
            let (rt, rr) = oneshot::channel();
            let result = ring.submit(Op::Fdatasync { fd: -1, reply: rt }, rr).await;
            assert!(matches!(result, Err(Error::HandlePoisoned { .. })));
        })
        .await;
    }

    #[tokio::test]
    async fn dropped_receiver_is_handled_gracefully() {
        with_timeout(async {
            let Some(ring) = ring_or_skip() else { return };
            // Hand the driver an op whose waiter is already gone. Delivering that
            // completion must not disturb the driver.
            let (orphan_tx, orphan_rx) = oneshot::channel::<i32>();
            drop(orphan_rx);
            assert!(
                ring.submit_tx
                    .send(Op::Fdatasync {
                        fd: -1,
                        reply: orphan_tx,
                    })
                    .is_ok(),
                "owner task exited unexpectedly"
            );
            // The driver must still service an ordinary op afterwards.
            let (rt, rr) = oneshot::channel::<i32>();
            let result = ring.submit(Op::Fdatasync { fd: -1, reply: rt }, rr).await;
            assert!(
                matches!(result, Ok(rc) if rc < 0),
                "expected a negative errno for fd -1, got {result:?}"
            );
            assert!(!ring.is_poisoned(), "orphaned completion poisoned the handle");
            ring.shutdown().await;
        })
        .await;
    }

    #[tokio::test]
    async fn aborted_owner_task_translates_to_clean_error() {
        let Some(ring) = ring_or_skip() else { return };
        abort_owner(&ring).await;
        let (rt, rr) = oneshot::channel::<i32>();
        let result = tokio::time::timeout(
            std::time::Duration::from_secs(2),
            ring.submit(Op::Fdatasync { fd: -1, reply: rt }, rr),
        )
        .await;
        assert!(result.is_ok(), "submit hung after owner abort");
        let inner = result.expect("not timeout");
        assert!(
            matches!(
                inner,
                Err(Error::CompletionDriverDead) | Err(Error::HandlePoisoned { .. })
            ),
            "expected poisoned/dead error, got {inner:?}"
        );
    }

    #[tokio::test]
    async fn fdatasync_against_invalid_fd_returns_error_not_hang() {
        let Some(ring) = ring_or_skip() else { return };
        let (rt, rr) = oneshot::channel();
        let result = tokio::time::timeout(
            std::time::Duration::from_secs(2),
            ring.submit(Op::Fdatasync { fd: -1, reply: rt }, rr),
        )
        .await;
        let inner =
            result.expect("submit on invalid fd hung — driver isn't draining CQ correctly");
        assert!(
            matches!(inner, Ok(rc) if rc < 0),
            "expected a negative errno for fd -1, got {inner:?}"
        );
        ring.shutdown().await;
    }

    #[tokio::test]
    async fn concurrent_submits_resolve_cleanly_on_owner_abort() {
        let Some(ring) = ring_or_skip() else { return };
        let ring = std::sync::Arc::new(ring);
        const SUBMITTERS: usize = 16;
        let mut handles = Vec::with_capacity(SUBMITTERS);
        for _ in 0..SUBMITTERS {
            let ring = std::sync::Arc::clone(&ring);
            handles.push(tokio::spawn(async move {
                let (rt, rr) = oneshot::channel::<i32>();
                tokio::time::timeout(
                    std::time::Duration::from_secs(5),
                    ring.submit(Op::Fdatasync { fd: -1, reply: rt }, rr),
                )
                .await
            }));
        }
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        abort_owner(&ring).await;
        for h in handles {
            let outer = h.await.expect("submitter task panicked");
            let inner = outer.expect("submitter timeout — owner abort didn't propagate within 5s");
            match inner {
                Ok(rc) => {
                    assert!(rc < 0, "expected -EBADF or error result, got rc={rc}");
                }
                Err(Error::CompletionDriverDead) | Err(Error::HandlePoisoned { .. }) => {}
                other => panic!("unexpected submitter result: {other:?}"),
            }
        }
    }
}