use core::{
cell::RefCell,
mem,
task::{Context, Poll},
};
use std::{
io,
os::unix::io::{AsRawFd, RawFd},
rc::Rc,
};
use io_uring::{IoUring, cqueue, opcode::AsyncCancel, squeue};
use slab::Slab;
use crate::{
buf::fixed::FixedBuffers,
runtime::driver::op::{Completable, Lifecycle, MultiCQEFuture, Op, Updateable},
};
pub(crate) use handle::*;
mod handle;
pub(crate) mod op;
/// Owns the `io_uring` instance together with the bookkeeping for every
/// operation that has been submitted through it.
///
/// Field order is significant: fields are dropped in declaration order after
/// `Drop for Driver` has run.
pub(crate) struct Driver {
/// Per-operation state, keyed by the slab index that is stored in each
/// submission entry's `user_data`.
ops: Ops,
/// The ring used for submitting entries and reaping completions.
uring: IoUring,
/// The buffer collection most recently passed to `register_buffers`, or
/// `None` when nothing is registered.
fixed_buffers: Option<Rc<RefCell<dyn FixedBuffers>>>,
}
/// Slab-backed storage for the state of operations known to the driver.
struct Ops {
/// One lifecycle entry per operation; the slab key is the operation index.
lifecycle: Slab<op::Lifecycle>,
/// Storage handed to `into_list` / `Lifecycle::complete` when working with
/// a `Lifecycle::CompletionList`.
completions: Slab<op::Completion>,
}
impl Driver {
/// Creates a driver around a freshly built ring.
///
/// The ring is built from the builder's `urb` configuration with
/// `b.entries` as its argument. The new driver tracks no operations and has
/// no fixed buffers registered.
///
/// # Errors
///
/// Returns the error reported while building the ring.
pub(crate) fn new(b: &crate::Builder) -> io::Result<Driver> {
    let ring = b.urb.build(b.entries)?;
    let driver = Driver {
        ops: Ops::new(),
        uring: ring,
        fixed_buffers: None,
    };
    Ok(driver)
}
/// Submits pending entries and waits for at least one completion.
///
/// Returns whatever `IoUring::submit_and_wait` returns.
fn wait(&self) -> io::Result<usize> {
    // Minimum number of completions to wait for.
    let want = 1;
    self.uring.submit_and_wait(want)
}
/// Number of operations that currently have a lifecycle entry.
// Only referenced from tests, hence the lint exemption.
#[allow(unused)]
pub(super) fn num_operations(&self) -> usize {
    let tracked = &self.ops.lifecycle;
    tracked.len()
}
/// Submits the queued entries to the kernel, retrying on transient errors.
///
/// * `EBUSY`: completions are dispatched first, then the submit is retried.
/// * `EINTR`: the submit is simply retried.
///
/// After a successful submit the submission queue is synced.
///
/// # Errors
///
/// Any other error from the ring is returned unchanged.
pub(crate) fn submit(&mut self) -> io::Result<()> {
    loop {
        let Err(err) = self.uring.submit() else {
            break;
        };
        match err.raw_os_error() {
            Some(libc::EBUSY) => self.dispatch_completions(),
            Some(libc::EINTR) => {}
            _ => return Err(err),
        }
    }
    self.uring.submission().sync();
    Ok(())
}
/// Drains the completion queue and routes each entry to its operation.
///
/// Entries whose `user_data` is `u64::MAX` are skipped; that value is used
/// by the driver's drop logic to tag cancellation requests. Every other
/// `user_data` is interpreted as the operation's slab index.
pub(crate) fn dispatch_completions(&mut self) {
    let mut completion_queue = self.uring.completion();
    completion_queue.sync();
    for entry in completion_queue {
        let tag = entry.user_data();
        if tag != u64::MAX {
            self.ops.complete(tag as usize, entry);
        }
    }
}
/// Registers the iovecs of `buffers` with the ring and remembers the
/// collection as the currently registered one.
///
/// # Errors
///
/// Returns the ring's error if registration fails; in that case the
/// previously remembered collection is left untouched.
pub(crate) fn register_buffers(&mut self, buffers: Rc<RefCell<dyn FixedBuffers>>) -> io::Result<()> {
    let table = buffers.borrow();
    // SAFETY: NOTE(review) presumably sound because the driver keeps
    // `buffers` alive in `fixed_buffers` while registered; confirm against
    // the `register_buffers` contract.
    let outcome = unsafe { self.uring.submitter().register_buffers(table.iovecs()) };
    // Release the `RefCell` borrow before `buffers` is moved below.
    drop(table);
    outcome?;
    self.fixed_buffers = Some(buffers);
    Ok(())
}
/// Unregisters `buffers` from the ring, provided they are the collection
/// that is currently registered (compared by `Rc` pointer identity).
///
/// # Errors
///
/// * An `io::Error::other` if `buffers` is not the registered collection
///   (including when nothing is registered).
/// * The ring's error if the unregister call itself fails; the remembered
///   collection is then left in place.
pub(crate) fn unregister_buffers(&mut self, buffers: Rc<RefCell<dyn FixedBuffers>>) -> io::Result<()> {
    let is_current = matches!(
        &self.fixed_buffers,
        Some(registered) if Rc::ptr_eq(&buffers, registered)
    );
    if !is_current {
        return Err(io::Error::other("fixed buffers are not currently registered"));
    }
    self.uring.submitter().unregister_buffers()?;
    self.fixed_buffers = None;
    Ok(())
}
/// Allocates a lifecycle slot, tags `sqe` with its index and queues it.
///
/// Whenever the push onto the submission queue is rejected, the queue is
/// flushed via [`Driver::submit`] and the push is retried.
///
/// Returns the slot index that identifies the operation.
///
/// # Panics
///
/// Panics if flushing the submission queue fails.
pub(crate) fn submit_op_2(&mut self, sqe: squeue::Entry) -> usize {
    let index = self.ops.insert();
    let entry = sqe.user_data(index as u64);
    loop {
        // SAFETY: NOTE(review) the caller is presumably responsible for the
        // entry's parameters staying valid until completion; confirm.
        let pushed = unsafe { self.uring.submission().push(&entry) };
        if pushed.is_ok() {
            break index;
        }
        self.submit().expect("Internal error, failed to submit ops");
    }
}
/// Allocates a lifecycle slot, lets `f` build the submission entry from
/// `data`, and queues the entry tagged with the slot index.
///
/// The returned [`Op`] owns `data` and refers back to the driver through
/// `handle`. Whenever the push onto the submission queue is rejected, the
/// queue is flushed via [`Driver::submit`] and the push is retried.
///
/// # Errors
///
/// Returns the error from [`Driver::submit`] if flushing fails.
pub(crate) fn submit_op<T, S, F>(
    &mut self,
    mut data: T,
    f: F,
    handle: WeakHandle,
) -> io::Result<Op<T, S>>
where
    T: Completable,
    F: FnOnce(&mut T) -> squeue::Entry,
{
    let index = self.ops.insert();
    let entry = f(&mut data).user_data(index as u64);
    let op = Op::new(handle, data, index);
    loop {
        // SAFETY: NOTE(review) presumably `data`, now owned by `op`, keeps
        // the entry's parameters alive; confirm.
        let pushed = unsafe { self.uring.submission().push(&entry) };
        if pushed.is_ok() {
            break;
        }
        self.submit()?;
    }
    Ok(op)
}
/// Detaches `op` from the driver.
///
/// Does nothing when the op's index has no lifecycle entry. Otherwise:
///
/// * `Submitted` / `Waiting`: the entry becomes `Ignored` and takes over the
///   op's data (presumably so it outlives the in-flight kernel operation).
/// * `Completed`: the entry is removed.
/// * `CompletionList`: the list is consumed; if its last entry has the
///   `more` flag set the entry becomes `Ignored` as above, otherwise it is
///   removed.
///
/// # Panics
///
/// Panics if the entry is already `Ignored`.
pub(crate) fn remove_op<T, CqeType>(&mut self, op: &mut Op<T, CqeType>) {
    let index = op.index();
    let Some((lifecycle, completions)) = self.ops.get_mut(index) else {
        return;
    };
    // The old state is held in a local so it is dropped only after the
    // entry has been updated or removed.
    let previous = mem::replace(lifecycle, Lifecycle::Submitted);
    let completions_outstanding = match previous {
        Lifecycle::Submitted | Lifecycle::Waiting(_) => true,
        Lifecycle::Completed(..) => false,
        Lifecycle::CompletionList(indices) => {
            let mut list = indices.into_list(completions);
            cqueue::more(list.peek_end().unwrap().flags)
        }
        Lifecycle::Ignored(..) => unreachable!(),
    };
    if completions_outstanding {
        *lifecycle = Lifecycle::Ignored(Box::new(op.take_data()));
    } else {
        self.ops.remove(index);
    }
}
/// Detaches the operation stored at `index`, handing `data` to the driver
/// when completions are still outstanding.
///
/// Does nothing when `index` has no lifecycle entry. Otherwise:
///
/// * `Submitted` / `Waiting`: the entry becomes `Ignored` and owns `data`.
/// * `Completed`: the entry is removed.
/// * `CompletionList`: the list is consumed; if its last entry has the
///   `more` flag set the entry becomes `Ignored` owning `data`, otherwise it
///   is removed.
///
/// # Panics
///
/// Panics if the entry is already `Ignored`.
pub(crate) fn remove_op_2<T: 'static>(&mut self, index: usize, data: T) {
    let Some((lifecycle, completions)) = self.ops.get_mut(index) else {
        return;
    };
    // The old state is held in a local so it is dropped only after the
    // entry has been updated or removed.
    let previous = mem::replace(lifecycle, Lifecycle::Submitted);
    let completions_outstanding = match previous {
        Lifecycle::Submitted | Lifecycle::Waiting(_) => true,
        Lifecycle::Completed(..) => false,
        Lifecycle::CompletionList(indices) => {
            let mut list = indices.into_list(completions);
            cqueue::more(list.peek_end().unwrap().flags)
        }
        Lifecycle::Ignored(..) => unreachable!(),
    };
    if completions_outstanding {
        *lifecycle = Lifecycle::Ignored(Box::new(data));
    } else {
        self.ops.remove(index);
    }
}
/// Polls the single-completion operation stored at `index`.
///
/// Returns `Poll::Ready` with the completion entry and removes the slot once
/// the operation is `Completed`. Otherwise the task's waker is stored and
/// `Poll::Pending` is returned. An already stored waker is reused when it
/// would wake the same task.
///
/// # Panics
///
/// Panics if `index` has no entry, or if the entry is `Ignored` or a
/// `CompletionList`.
pub(crate) fn poll_op_2(&mut self, index: usize, cx: &mut Context<'_>) -> Poll<cqueue::Entry> {
    let (lifecycle, _) = self.ops.get_mut(index).expect("invalid internal state");
    match mem::replace(lifecycle, Lifecycle::Submitted) {
        Lifecycle::Completed(cqe) => {
            self.ops.remove(index);
            Poll::Ready(cqe)
        }
        Lifecycle::Submitted => {
            *lifecycle = Lifecycle::Waiting(cx.waker().clone());
            Poll::Pending
        }
        Lifecycle::Waiting(registered) => {
            let next = if registered.will_wake(cx.waker()) {
                registered
            } else {
                cx.waker().clone()
            };
            *lifecycle = Lifecycle::Waiting(next);
            Poll::Pending
        }
        Lifecycle::CompletionList(..) => {
            unreachable!("No `more` flag set for SingleCQE")
        }
        Lifecycle::Ignored(..) => unreachable!(),
    }
}
/// Polls a single-completion `op`.
///
/// Once the operation is `Completed`, its slot is removed, the op's data is
/// taken and turned into the output via `Completable::complete`. Otherwise
/// the task's waker is stored and `Poll::Pending` is returned. An already
/// stored waker is reused when it would wake the same task.
///
/// # Panics
///
/// Panics if the op's index has no entry, if the entry is `Ignored` or a
/// `CompletionList`, or if the op's data has already been taken.
pub(crate) fn poll_op<T>(&mut self, op: &mut Op<T>, cx: &mut Context<'_>) -> Poll<T::Output>
where
    T: Unpin + 'static + Completable,
{
    let index = op.index();
    let (lifecycle, _) = self.ops.get_mut(index).expect("invalid internal state");
    match mem::replace(lifecycle, Lifecycle::Submitted) {
        Lifecycle::Completed(cqe) => {
            self.ops.remove(index);
            let data = op.take_data().unwrap();
            Poll::Ready(data.complete(cqe.into()))
        }
        Lifecycle::Submitted => {
            *lifecycle = Lifecycle::Waiting(cx.waker().clone());
            Poll::Pending
        }
        Lifecycle::Waiting(registered) => {
            let next = if registered.will_wake(cx.waker()) {
                registered
            } else {
                cx.waker().clone()
            };
            *lifecycle = Lifecycle::Waiting(next);
            Poll::Pending
        }
        Lifecycle::CompletionList(..) => {
            unreachable!("No `more` flag set for SingleCQE")
        }
        Lifecycle::Ignored(..) => unreachable!(),
    }
}
/// Polls a multi-completion `op`.
///
/// * `Completed`: the slot is removed and the op's data is turned into the
///   output via `Completable::complete`.
/// * `CompletionList`: queued results are consumed in order. Each result
///   flagged `more` is fed to `Updateable::update`; the first result without
///   the flag ends the operation, removes the slot and produces the output.
///   If every queued result was flagged `more`, the data is put back into
///   the op, the waker is stored and `Poll::Pending` is returned.
/// * `Submitted` / `Waiting`: the waker is stored (an existing one is reused
///   when it would wake the same task) and `Poll::Pending` is returned.
///
/// # Panics
///
/// Panics if the op's index has no entry, if the entry is `Ignored`, or if
/// the op's data has already been taken.
pub(crate) fn poll_multishot_op<T>(
    &mut self,
    op: &mut Op<T, MultiCQEFuture>,
    cx: &mut Context<'_>,
) -> Poll<T::Output>
where
    T: Unpin + 'static + Completable + Updateable,
{
    let index = op.index();
    let (lifecycle, completions) = self.ops.get_mut(index).expect("invalid internal state");
    match mem::replace(lifecycle, Lifecycle::Submitted) {
        Lifecycle::Completed(cqe) => {
            self.ops.remove(index);
            let data = op.take_data().unwrap();
            Poll::Ready(data.complete(cqe.into()))
        }
        Lifecycle::CompletionList(indices) => {
            let mut data = op.take_data().unwrap();
            // Holds the first result that is not flagged `more`, if any.
            let mut terminal = None;
            for cqe in indices.into_list(completions) {
                if !cqueue::more(cqe.flags) {
                    terminal = Some(cqe);
                    break;
                }
                data.update(cqe);
            }
            if let Some(cqe) = terminal {
                self.ops.remove(index);
                Poll::Ready(data.complete(cqe))
            } else {
                // More results are expected: restore the op and wait again.
                op.insert_data(data);
                *lifecycle = Lifecycle::Waiting(cx.waker().clone());
                Poll::Pending
            }
        }
        Lifecycle::Submitted => {
            *lifecycle = Lifecycle::Waiting(cx.waker().clone());
            Poll::Pending
        }
        Lifecycle::Waiting(registered) => {
            let next = if registered.will_wake(cx.waker()) {
                registered
            } else {
                cx.waker().clone()
            };
            *lifecycle = Lifecycle::Waiting(next);
            Poll::Pending
        }
        Lifecycle::Ignored(..) => unreachable!(),
    }
}
}
/// Exposes the file descriptor of the underlying ring.
impl AsRawFd for Driver {
    fn as_raw_fd(&self) -> RawFd {
        let ring = &self.uring;
        ring.as_raw_fd()
    }
}
/// Tears the driver down: flushes queued submissions, requests cancellation
/// of every operation that has not finished, and waits until the lifecycle
/// slab is empty.
///
/// # Panics
///
/// Panics if flushing the submission queue or submitting a cancellation
/// request fails.
impl Drop for Driver {
fn drop(&mut self) {
// Phase 1: push every queued submission entry to the kernel.
while !self.uring.submission().is_empty() {
self.submit().expect("Internal error when dropping driver");
}
// Phase 2: classify every entry. Each entry is provisionally replaced by
// `Ignored`; after this loop every entry is either `Completed` or `Ignored`.
for (_, cycle) in self.ops.lifecycle.iter_mut() {
match std::mem::replace(cycle, Lifecycle::Ignored(Box::new(()))) {
lc @ Lifecycle::Completed(_) => {
// Already finished: put the original state back, nothing to cancel.
*cycle = lc;
}
Lifecycle::CompletionList(indices) => {
let mut list = indices.clone().into_list(&mut self.ops.completions);
if !io_uring::cqueue::more(list.peek_end().unwrap().flags) {
// The last queued result has no `more` flag, so the op is finished.
// NOTE(review): relies on an all-zero bit pattern being a valid
// completion entry; confirm against the io-uring crate.
*cycle = Lifecycle::Completed(unsafe { mem::zeroed() });
}
}
_ => {
// Every other state stays `Ignored` (set by the replace above) and
// is cancelled in the next phase.
}
}
}
// Phase 3: queue an `AsyncCancel` for every entry marked `Ignored`. The
// requests are tagged with `u64::MAX`, which `dispatch_completions` skips.
for (id, cycle) in self.ops.lifecycle.iter_mut() {
if let Lifecycle::Ignored(..) = cycle {
unsafe {
while self
.uring
.submission()
.push(&AsyncCancel::new(id as u64).build().user_data(u64::MAX))
.is_err()
{
// The push was rejected: submit, wait for one completion, retry.
self.uring
.submit_and_wait(1)
.expect("Internal error when dropping driver");
}
}
}
}
// Phase 4: walk the slab by index until it is empty. `Completed` entries
// are removed here; for an `Ignored` entry we wait and dispatch completions
// without advancing `id`, expecting the dispatch to remove it eventually.
let mut id = 0;
loop {
if self.ops.lifecycle.is_empty() {
break;
}
match self.ops.lifecycle.get(id) {
Some(Lifecycle::Ignored(..)) => {
// A failed wait is deliberately ignored; the loop simply tries again.
let _ = self.wait();
self.dispatch_completions();
}
Some(_) => {
let _ = self.ops.lifecycle.remove(id);
id += 1;
}
None => {
// Vacant slot: move on to the next index.
id += 1;
}
}
}
}
}
impl Ops {
    /// Creates empty storage with room for 64 entries in each slab.
    fn new() -> Ops {
        const INITIAL_CAPACITY: usize = 64;
        Ops {
            lifecycle: Slab::with_capacity(INITIAL_CAPACITY),
            completions: Slab::with_capacity(INITIAL_CAPACITY),
        }
    }

    /// Returns the lifecycle entry at `index` together with the completion
    /// slab, or `None` when `index` has no entry.
    fn get_mut(&mut self, index: usize) -> Option<(&mut op::Lifecycle, &mut Slab<op::Completion>)> {
        let lifecycle = self.lifecycle.get_mut(index)?;
        Some((lifecycle, &mut self.completions))
    }

    /// Inserts a new entry in the `Submitted` state and returns its index.
    fn insert(&mut self) -> usize {
        let initial = op::Lifecycle::Submitted;
        self.lifecycle.insert(initial)
    }

    /// Removes the entry at `index`, discarding its state.
    fn remove(&mut self, index: usize) {
        let _ = self.lifecycle.remove(index);
    }

    /// Feeds `cqe` to the entry at `index` and removes the entry when
    /// `Lifecycle::complete` reports that it should be dropped.
    ///
    /// # Panics
    ///
    /// Panics if `index` has no entry.
    fn complete(&mut self, index: usize, cqe: cqueue::Entry) {
        let finished = self.lifecycle[index].complete(&mut self.completions, cqe);
        if finished {
            self.lifecycle.remove(index);
        }
    }
}
/// Sanity check on teardown: every entry still present must be `Completed`
/// (an empty slab trivially passes).
///
/// # Panics
///
/// Panics if any remaining entry is in another state.
impl Drop for Ops {
fn drop(&mut self) {
assert!(
self.lifecycle
.iter()
.all(|(_, cycle)| matches!(cycle, Lifecycle::Completed(_)))
)
}
}
#[cfg(test)]
mod test {
use std::rc::Rc;
use crate::runtime::CONTEXT;
use crate::runtime::driver::op::{Completable, CqeResult, Op};
use tokio_test::{assert_pending, assert_ready, task};
use super::*;
/// Output type of the test operation: the completion's result and flags plus
/// the `Rc` that was carried as the op's data.
#[derive(Debug)]
pub(crate) struct Completion {
/// Result copied from the completion.
result: io::Result<u32>,
/// Flags copied from the completion.
flags: u32,
/// The op's data, returned so tests can track its reference count.
data: Rc<()>,
}
/// Test-only `Completable` implementation: the op's data is an `Rc<()>` whose
/// strong count lets tests observe who still holds it.
impl Completable for Rc<()> {
    type Output = Completion;

    /// Packs the completion's result and flags together with the data.
    fn complete(self, cqe: CqeResult) -> Self::Output {
        Completion {
            result: cqe.result,
            flags: cqe.flags,
            // `self` is owned here, so it is moved instead of cloned; the
            // strong count seen by the caller after the call is unchanged.
            data: self,
        }
    }
}
/// Dropping an op that never completed must keep both its slab entry and its
/// data alive.
#[test]
fn op_stays_in_slab_on_drop() {
let (op, data) = init();
drop(op);
// The data is still referenced elsewhere (count 2) and the entry remains.
assert_eq!(2, Rc::strong_count(&data));
assert_eq!(1, num_operations());
release();
}
/// Poll once (pending), complete, then poll again (ready), checking the
/// slab size and the data's reference count at every step.
#[test]
fn poll_op_once() {
let (op, data) = init();
let mut op = task::spawn(op);
assert_pending!(op.poll());
assert_eq!(2, Rc::strong_count(&data));
complete(&op);
// The completion is recorded but the entry stays until it is polled.
assert_eq!(1, num_operations());
assert_eq!(2, Rc::strong_count(&data));
assert!(op.is_woken());
let Completion { result, flags, data: d } = assert_ready!(op.poll());
// The data has moved out of the op into the output, so the count is unchanged.
assert_eq!(2, Rc::strong_count(&data));
assert_eq!(0, result.unwrap());
assert_eq!(0, flags);
drop(d);
assert_eq!(1, Rc::strong_count(&data));
drop(op);
assert_eq!(0, num_operations());
release();
}
/// Polling twice from the same task before completion must still wake the
/// task and yield the result afterwards.
#[test]
fn poll_op_twice() {
{
let (op, ..) = init();
let mut op = task::spawn(op);
assert_pending!(op.poll());
// Second poll uses the same task's waker.
assert_pending!(op.poll());
complete(&op);
assert!(op.is_woken());
let Completion { result, flags, .. } = assert_ready!(op.poll());
assert_eq!(0, result.unwrap());
assert_eq!(0, flags);
}
release();
}
/// Polling from a second task after a first one has polled must wake the
/// second task on completion.
#[test]
fn poll_change_task() {
{
let (op, ..) = init();
let mut op = task::spawn(op);
assert_pending!(op.poll());
// Move the op into a freshly spawned task and poll it from there.
let op = op.into_inner();
let mut op = task::spawn(op);
assert_pending!(op.poll());
complete(&op);
assert!(op.is_woken());
let Completion { result, flags, .. } = assert_ready!(op.poll());
assert_eq!(0, result.unwrap());
assert_eq!(0, flags);
}
release();
}
/// A completion that arrives before the first poll must be delivered by that
/// first poll.
#[test]
fn complete_before_poll() {
let (op, data) = init();
let mut op = task::spawn(op);
complete(&op);
// The entry stays in the slab until the op is polled.
assert_eq!(1, num_operations());
assert_eq!(2, Rc::strong_count(&data));
let Completion { result, flags, .. } = assert_ready!(op.poll());
assert_eq!(0, result.unwrap());
assert_eq!(0, flags);
drop(op);
assert_eq!(0, num_operations());
release();
}
/// A completion that arrives after the op was dropped must release both the
/// slab entry and the data that was kept alive for it.
#[test]
fn complete_after_drop() {
let (op, data) = init();
let index = op.index();
drop(op);
assert_eq!(2, Rc::strong_count(&data));
assert_eq!(1, num_operations());
// Deliver a zeroed completion directly to the slot of the dropped op.
CONTEXT.with(|cx| {
cx.handle()
.unwrap()
.inner
.borrow_mut()
.ops
.complete(index, unsafe { mem::zeroed() })
});
assert_eq!(1, Rc::strong_count(&data));
assert_eq!(0, num_operations());
release();
}
/// Installs a fresh driver in the thread's runtime context and creates an op
/// that carries a clone of a new `Rc<()>`.
///
/// Only a lifecycle slot is allocated; nothing is submitted to the ring.
/// Returns the op together with the original `Rc`.
fn init() -> (Op<Rc<()>>, Rc<()>) {
    let fresh_driver = Driver::new(&crate::builder()).unwrap();
    let payload = Rc::new(());
    let op = CONTEXT.with(|cx| {
        cx.set_handle(fresh_driver.into());
        let handle = cx.handle().unwrap();
        let slot = handle.inner.borrow_mut().ops.insert();
        Op::new((&handle).into(), Rc::clone(&payload), slot)
    });
    (op, payload)
}
/// Number of lifecycle entries held by the driver installed in the context.
fn num_operations() -> usize {
    CONTEXT.with(|cx| {
        let handle = cx.handle().unwrap();
        let driver = handle.inner.borrow();
        driver.num_operations()
    })
}
fn complete(op: &Op<Rc<()>>) {
let cqe = unsafe { mem::zeroed() };
CONTEXT.with(|cx| {
let driver = cx.handle().unwrap();
driver.inner.borrow_mut().ops.complete(op.index(), cqe);
});
}
/// Empties both slabs of the installed driver and then removes the driver
/// from the context, so the next test starts from a clean state.
fn release() {
    CONTEXT.with(|cx| {
        let handle = cx.handle().unwrap();
        {
            let mut driver = handle.inner.borrow_mut();
            driver.ops.lifecycle.clear();
            driver.ops.completions.clear();
        }
        // The mutable borrow above has ended before the driver is unset.
        cx.unset_driver();
    });
}
}