use io_uring::{opcode::AsyncCancel, IoUring};
use std::cell::RefCell;
use std::future::Future;
use std::os::unix::prelude::{AsRawFd, RawFd};
use std::rc::Rc;
use tokio::io::unix::AsyncFd;
thread_local! {
    /// Ring owned by the current thread, created through
    /// [`IoUringAsync::new_default`] on first access.
    ///
    /// Initialization panics if `new_default` returns an error.
    static URING: Rc<IoUringAsync> = Rc::new(IoUringAsync::new_default().unwrap());
}
/// Returns another `Rc` handle to the calling thread's [`IoUringAsync`].
///
/// The ring is created on first use; see the `URING` thread-local.
pub fn get_ring() -> Rc<IoUringAsync> {
    URING.with(|ring| Rc::clone(ring))
}
/// Progress of one pushed operation, stored in the slab under the operation's index.
enum Lifecycle<C: cqueue::Entry> {
/// Initial state written by `push`: no waker registered, no completion recorded.
Submitted,
/// The operation's future was polled before completion; holds the waker from the latest poll.
Waiting(std::task::Waker),
/// `handle_cqe` recorded the completion queue entry for this operation.
Completed(C),
}
/// Future returned by `IoUringAsync::push`; resolves to the operation's completion entry.
///
/// Dropping it while the operation is not yet completed queues an `AsyncCancel`
/// via `get_ring()` and spawns a local task that awaits both the cancel and the
/// original operation (see the `Drop` impl).
pub struct Op<C: cqueue::Entry> {
// Set to `Some` by `push`; taken out in `Drop` so the value can be moved into
// the spawned task.
inner: Option<OpInner<C>>,
}
impl<C: cqueue::Entry> Future for Op<C> {
    type Output = C;

    /// Delegates to the wrapped [`OpInner`] future.
    ///
    /// # Panics
    /// Panics if the inner future has already been taken out of this handle.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let this = self.get_mut();
        let inner = this.inner.as_mut().unwrap();
        std::pin::Pin::new(inner).poll(cx)
    }
}
impl<C: cqueue::Entry> Drop for Op<C> {
    /// Releases the operation.
    ///
    /// If the slab already holds a completion for it, nothing special happens:
    /// the inner value is simply dropped. Otherwise an `AsyncCancel` targeting
    /// this operation's index is pushed and a local task is spawned that awaits
    /// the cancel and then the original operation, keeping the slab slot alive
    /// until both have finished.
    ///
    /// NOTE(review): the cancel goes onto the ring returned by `get_ring()`,
    /// which is not necessarily the ring this operation was pushed on — confirm.
    /// NOTE(review): `tokio::task::spawn_local` presumably requires a running
    /// `LocalSet`; dropping an unfinished `Op` elsewhere likely panics — confirm.
    fn drop(&mut self) {
        let inner = self.inner.take().unwrap();
        let index = inner.index;

        // Inspect the slot once; a completed operation needs no cancellation.
        let slab = inner.slab.borrow();
        let state_name = match &slab[index] {
            Lifecycle::Completed(_) => return,
            Lifecycle::Submitted => "Submitted",
            Lifecycle::Waiting(_) => "Waiting",
        };
        tracing::debug!(%index, "dropping op in state {state_name}");
        // Release the shared borrow before `push` needs the slab mutably.
        drop(slab);

        let cancel_entry = AsyncCancel::new(index.try_into().unwrap()).build();
        let mut cancel_op = get_ring().push(cancel_entry);
        // Move the inner future out and skip `Op::drop` for the emptied handle,
        // which would otherwise unwrap a `None`.
        let cancel_inner = cancel_op.inner.take().unwrap();
        std::mem::forget(cancel_op);

        tokio::task::spawn_local(async move {
            cancel_inner.await;
            inner.await;
        });
    }
}
/// The part of an operation that owns its slab slot: polling reads the slot,
/// dropping removes it.
pub struct OpInner<C: cqueue::Entry> {
// Lifecycle table shared with the `IoUringAsync` that created this operation.
slab: Rc<RefCell<slab::Slab<Lifecycle<C>>>>,
// Key of this operation's slot in `slab`; `push` also sets it as the entry's `user_data`.
index: usize,
}
impl<C: cqueue::Entry> Future for OpInner<C> {
    type Output = C;

    /// Checks this operation's slab slot.
    ///
    /// Returns a clone of the stored completion entry once one is recorded.
    /// Until then the slot is set to `Waiting` with the current task's waker
    /// and `Pending` is returned.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        use std::task::Poll;

        let index = self.index;
        let mut slab = self.slab.borrow_mut();
        let slot = &mut slab[index];

        if let Lifecycle::Completed(cqe) = &*slot {
            tracing::trace!(index = %index, "poll: completed!");
            return Poll::Ready(cqe.clone());
        }

        // Both pending states are handled the same way; only the trace differs.
        if matches!(&*slot, Lifecycle::Submitted) {
            tracing::trace!(index = %index, "poll: submitted!");
        } else {
            tracing::trace!(index = %index, "poll: waiting!");
        }
        // Always store the latest waker so the most recent poller is the one woken.
        *slot = Lifecycle::Waiting(cx.waker().clone());
        Poll::Pending
    }
}
impl<C: cqueue::Entry> Drop for OpInner<C> {
    /// Removes this operation's slot from the slab.
    ///
    /// # Panics
    /// Panics if the removed slot was not `Completed`, unless the thread is
    /// already unwinding from another panic, in which case it returns quietly.
    fn drop(&mut self) {
        let index = self.index;
        let mut slab = self.slab.borrow_mut();
        let removed = slab.remove(index);

        let lifecycle_name = match &removed {
            Lifecycle::Completed(_) => return,
            Lifecycle::Submitted => "Submitted",
            Lifecycle::Waiting(_) => "Waiting",
        };

        // Never panic while the thread is already panicking.
        if std::thread::panicking() {
            return;
        }

        tracing::debug!("dropping op inner {index} ({})", lifecycle_name);
        panic!("Op drop occured before completion (index {})", index)
    }
}
pub mod cqueue;
pub mod squeue;
/// An `io_uring` instance paired with a slab that tracks the lifecycle of every
/// pushed operation.
///
/// `S` and `C` are the submission and completion entry types; they default to
/// `io_uring::squeue::Entry` and `io_uring::cqueue::Entry`.
pub struct IoUringAsync<
S: squeue::Entry = io_uring::squeue::Entry,
C: cqueue::Entry = io_uring::cqueue::Entry,
> {
// The underlying ring.
uring: Rc<IoUring<S, C>>,
// One `Lifecycle` slot per pushed operation, keyed by the value `push` stores as `user_data`.
slab: Rc<RefCell<slab::Slab<Lifecycle<C>>>>,
}
impl<S: squeue::Entry, C: cqueue::Entry> AsRawFd for IoUringAsync<S, C> {
fn as_raw_fd(&self) -> RawFd {
self.uring.as_raw_fd()
}
}
impl IoUringAsync<io_uring::squeue::Entry, io_uring::cqueue::Entry> {
    /// Reads the environment variable `name` as a `u32`.
    ///
    /// Returns `default` when the variable cannot be read (unset or not valid
    /// Unicode).
    ///
    /// # Errors
    /// Returns an [`std::io::ErrorKind::InvalidInput`] error when the variable
    /// is set but does not parse as a `u32`.
    fn env_u32_or(name: &str, default: u32) -> std::io::Result<u32> {
        match std::env::var(name) {
            Ok(raw) => raw.parse().map_err(|_| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("${name} must be a number"),
                )
            }),
            Err(_) => Ok(default),
        }
    }

    /// Creates a ring with 512 entries, or with the size given in
    /// `$IO_URING_ENTRIES` when that variable is set.
    ///
    /// The chosen size is reported on stderr.
    ///
    /// # Errors
    /// Returns `InvalidInput` if `$IO_URING_ENTRIES` is set but is not a
    /// number, and any error produced by [`Self::new`].
    pub fn new_default() -> std::io::Result<Self> {
        let entries = Self::env_u32_or("IO_URING_ENTRIES", 512)?;
        eprintln!(
            "==== IO_URING RING SIZE: {} (override with $IO_URING_ENTRIES)",
            entries
        );
        Self::new(entries)
    }

    /// Creates a ring with `entries` entries.
    ///
    /// Environment overrides:
    /// - `$IO_URING_SQPOLL` set to `1` or `true` enables SQPOLL.
    /// - `$IO_URING_SQPOLL_IDLE_MS` sets the SQPOLL idle time (default 200);
    ///   it is parsed even when SQPOLL is disabled.
    ///
    /// Both settings are reported on stderr.
    ///
    /// # Errors
    /// Returns `InvalidInput` if `$IO_URING_SQPOLL_IDLE_MS` is set but is not a
    /// number, and any error returned while building the ring.
    pub fn new(entries: u32) -> std::io::Result<Self> {
        let mut builder = io_uring::IoUring::builder();

        let sqpoll_enabled = matches!(
            std::env::var("IO_URING_SQPOLL").as_deref(),
            Ok("1") | Ok("true")
        );
        eprintln!("==== SQPOLL: {sqpoll_enabled} (override with $IO_URING_SQPOLL=1)");

        let sqpoll_idle_ms = Self::env_u32_or("IO_URING_SQPOLL_IDLE_MS", 200)?;
        eprintln!(
            "==== SQPOLL_IDLE_MS: {} (override with $IO_URING_SQPOLL_IDLE_MS)",
            sqpoll_idle_ms
        );

        if sqpoll_enabled {
            builder.setup_sqpoll(sqpoll_idle_ms);
        }

        Ok(Self {
            uring: Rc::new(builder.build(entries)?),
            slab: Rc::new(RefCell::new(slab::Slab::new())),
        })
    }
}
impl<S: squeue::Entry, C: cqueue::Entry> IoUringAsync<S, C> {
    /// Processes completions forever.
    ///
    /// Registers the ring's file descriptor with tokio, then loops: wait until
    /// it is readable, run [`Self::handle_cqe`], clear the readiness flag.
    /// This future never resolves.
    ///
    /// # Panics
    /// Panics if the descriptor cannot be registered or waiting for readiness fails.
    pub async fn listen(uring: Rc<IoUringAsync<S, C>>) {
        let ring_fd = AsyncFd::new(uring).unwrap();
        loop {
            let mut readiness = ring_fd.readable().await.unwrap();
            readiness.get_inner().handle_cqe();
            readiness.clear_ready();
        }
    }

    /// Creates a ring with `entries` entries using the builder's default
    /// settings, for any entry types `S` and `C`.
    ///
    /// # Errors
    /// Returns the error produced while building the ring.
    pub fn generic_new(entries: u32) -> std::io::Result<Self> {
        let ring = io_uring::IoUring::builder().build(entries)?;
        Ok(Self {
            uring: Rc::new(ring),
            slab: Rc::new(RefCell::new(slab::Slab::new())),
        })
    }

    /// Queues `entry` on the submission queue and returns a future for its completion.
    ///
    /// A slab slot is reserved in the `Submitted` state and its index is
    /// stored as the entry's `user_data`. While the submission queue rejects
    /// the entry, the ring is submitted to make room and the push is retried.
    /// The entry is only queued here, not submitted.
    ///
    /// # Panics
    /// Panics if the slab index does not fit the `user_data` type or if
    /// submitting the ring fails while making room.
    pub fn push(&self, entry: impl Into<S>) -> Op<C> {
        let mut ops = self.slab.borrow_mut();
        let index = ops.insert(Lifecycle::Submitted);
        tracing::trace!(%index, "pushing op with index");

        let sqe = entry.into().user_data(index.try_into().unwrap());
        loop {
            // SAFETY: NOTE(review): assumes no other submission-queue handle for
            // this ring is alive here, and that the caller keeps any memory the
            // entry refers to valid until the operation completes — confirm.
            let pushed = unsafe { self.uring.submission_shared().push(&sqe) };
            if pushed.is_ok() {
                break;
            }
            // Queue is full: hand the queued entries to the kernel and retry.
            self.uring.submit().unwrap();
        }

        Op {
            inner: Some(OpInner {
                slab: self.slab.clone(),
                index,
            }),
        }
    }

    /// Drains the completion queue, recording each entry in the slab slot named
    /// by its `user_data` and waking the task waiting on it, if any.
    ///
    /// A completion that arrives for a slot which already holds one is not
    /// stored; a message is printed to stdout instead.
    ///
    /// # Panics
    /// Panics if a completion's `user_data` does not name an occupied slab slot.
    pub fn handle_cqe(&self) {
        let mut ops = self.slab.borrow_mut();
        // SAFETY: NOTE(review): assumes no other completion-queue handle for this
        // ring is alive while this loop runs — confirm.
        while let Some(cqe) = unsafe { self.uring.completion_shared() }.next() {
            let index = cqe.user_data();
            tracing::trace!(%index, "received cqe for index");

            let slot = &mut ops[index.try_into().unwrap()];

            if let Lifecycle::Completed(stored) = &*slot {
                // The values printed come from the completion already stored in
                // the slot, not from the one that just arrived.
                println!(
                    "multishot operations not implemented: {}, {}",
                    stored.user_data(),
                    stored.result()
                );
                continue;
            }

            if let Lifecycle::Waiting(waker) = &*slot {
                waker.wake_by_ref();
            }
            *slot = Lifecycle::Completed(cqe);
        }
    }

    /// Submits the queued entries by calling `submit` on the underlying ring
    /// and returns its result.
    pub fn submit(&self) -> std::io::Result<usize> {
        self.uring.submit()
    }
}
// Tests that push `Nop` operations through a ring on a current-thread tokio runtime.
#[cfg(test)]
mod tests {
use super::IoUringAsync;
use io_uring::opcode::Nop;
use send_wrapper::SendWrapper;
use std::rc::Rc;
// Pushes two `Nop`s, submits them explicitly, and checks that both complete
// with a non-negative result.
#[test]
fn example1() {
let uring = Rc::new(IoUringAsync::new(8).unwrap());
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
runtime.block_on(async move {
tokio::task::LocalSet::new()
.run_until(async {
// Background task that turns completion queue entries into wake-ups.
tokio::task::spawn_local(IoUringAsync::listen(uring.clone()));
let fut1 = uring.push(Nop::new().build());
let fut2 = uring.push(Nop::new().build());
// `push` only queues the entries; submit them before awaiting.
uring.submit().unwrap();
let cqe1 = fut1.await;
let cqe2 = fut2.await;
assert!(cqe1.result() >= 0, "nop error: {}", cqe1.result());
assert!(cqe2.result() >= 0, "nop error: {}", cqe2.result());
})
.await;
});
}
// Same as `example1` with a single `Nop`, but submission happens in the
// runtime's `on_thread_park` hook instead of an explicit `submit()` call.
#[test]
fn example2() {
let uring = IoUringAsync::new(8).unwrap();
let uring = Rc::new(uring);
// Presumably needed because the park hook must be `Send` while `Rc` is not — confirm.
let uring_clone = SendWrapper::new(uring.clone());
let runtime = tokio::runtime::Builder::new_current_thread()
.on_thread_park(move || {
uring_clone.submit().unwrap();
})
.enable_all()
.build()
.unwrap();
runtime.block_on(async move {
tokio::task::LocalSet::new()
.run_until(async {
tokio::task::spawn_local(IoUringAsync::listen(uring.clone()));
let cqe = uring.push(Nop::new().build()).await;
assert!(cqe.result() >= 0, "nop error: {}", cqe.result());
})
.await;
});
}
}