use std::cell::Cell;
use std::fmt;
use std::io;
use std::mem;
use std::ops::{Deref, DerefMut};
use std::ptr;
use std::rc::Rc;
use std::sync::atomic;
use std::sync::atomic::AtomicU16;
use io_uring::types::BufRingEntry;
/// Buffer-group id, as registered with the kernel.
type Bgid = u16;

/// Buffer id within a buffer group.
type Bid = u16;
/// Configures and builds a [`BufRing`].
///
/// `ring_entries` is rounded up to a power of two by [`Builder::build`],
/// which also reconciles `buf_cnt` with `ring_entries` when inconsistent.
#[derive(Copy, Clone)]
pub(crate) struct Builder {
    // Buffer-group id the ring will be created for.
    bgid: Bgid,
    // Requested number of ring entries; normalized in `build`.
    ring_entries: u16,
    // Number of buffers to allocate; 0 means "match ring_entries" (see `build`).
    buf_cnt: u16,
    // Byte length of each buffer.
    buf_len: usize,
}
impl Builder {
    /// Create a builder for buffer group `bgid` with defaults:
    /// 128 ring entries, a buf_cnt of 0 (meaning "match ring entries"),
    /// and 4096-byte buffers.
    pub fn new(bgid: Bgid) -> Builder {
        Builder {
            bgid,
            ring_entries: 128,
            buf_cnt: 0,
            buf_len: 4096,
        }
    }

    /// Set the number of ring entries (rounded up to a power of two by `build`).
    pub fn ring_entries(mut self, ring_entries: u16) -> Builder {
        self.ring_entries = ring_entries;
        self
    }

    /// Set the number of buffers to allocate.
    pub fn buf_cnt(mut self, buf_cnt: u16) -> Builder {
        self.buf_cnt = buf_cnt;
        self
    }

    /// Set the byte length of each buffer.
    pub fn buf_len(mut self, buf_len: usize) -> Builder {
        self.buf_len = buf_len;
        self
    }

    /// Normalize the settings and allocate the ring plus its buffers.
    ///
    /// # Errors
    /// Fails when the reconciled entry count exceeds 32768, or when the
    /// inner ring allocation fails.
    pub fn build(&self) -> io::Result<BufRing> {
        let mut cfg: Builder = *self;

        // An unset buf_cnt, or a buf_cnt larger than the ring, is resolved
        // by growing both values to the larger of the two.
        if cfg.buf_cnt == 0 || cfg.ring_entries < cfg.buf_cnt {
            let reconciled = cfg.ring_entries.max(cfg.buf_cnt);
            cfg.ring_entries = reconciled;
            cfg.buf_cnt = reconciled;
        }

        // Reject anything above 2^15 entries before power-of-two rounding,
        // so the rounded value can never exceed 32768 either.
        if cfg.ring_entries > (1 << 15) {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "ring_entries exceeded 32768",
            ));
        }

        // The ring size must be a power of two.
        cfg.ring_entries = cfg.ring_entries.next_power_of_two();

        let inner = InnerBufRing::new(cfg.bgid, cfg.ring_entries, cfg.buf_cnt, cfg.buf_len)?;
        Ok(BufRing {
            inner: Rc::new(inner),
        })
    }
}
/// Cheaply cloneable handle to a shared provided-buffer ring.
///
/// Clones share the same [`InnerBufRing`] via `Rc`, so the ring and its
/// buffers stay alive until the last handle (and last outstanding `Buf`)
/// is dropped.
#[derive(Clone)]
pub(crate) struct BufRing {
    inner: Rc<InnerBufRing>,
}
impl BufRing {
    /// Wrap buffer `bid` (holding `len` valid bytes) in an owned [`Buf`]
    /// that returns itself to the ring on drop.
    pub fn get_buf(&self, len: usize, bid: u16) -> Buf {
        let handle = self.clone();
        self.inner.get_buf(handle, len, bid)
    }

    /// Buffer-group id this ring was created with.
    pub fn bgid(&self) -> Bgid {
        self.inner.bgid()
    }

    /// Number of entries in the shared ring.
    pub fn ring_entries(&self) -> u16 {
        self.inner.ring_entries()
    }

    /// Raw pointer to the start of the mapped ring.
    pub fn as_ptr(&self) -> *const libc::c_void {
        self.inner.ring_start.as_ptr()
    }

    /// Return buffer `bid` to the ring and publish the new tail.
    pub fn drop_buf(&self, bid: Bid) {
        self.inner.drop_buf(bid);
    }
}
/// An owned view of one buffer from a [`BufRing`].
///
/// Derefs to `[u8]` of length `len`; dropping it pushes the buffer id
/// back onto the ring so it can be handed out again.
pub(crate) struct Buf {
    // Keeps the ring (and the backing allocation) alive.
    buf_ring: BufRing,
    // Number of valid bytes (<= buffer capacity, checked in `Buf::new`).
    len: usize,
    // Id of the underlying buffer within the group.
    bid: Bid,
}
impl Buf {
    /// Wrap buffer `bid` with `len` valid bytes.
    ///
    /// # Panics
    /// Panics when `len` exceeds the fixed per-buffer capacity.
    fn new(buf_ring: BufRing, bid: Bid, len: usize) -> Self {
        assert!(len <= buf_ring.inner.buf_capacity());
        Self { buf_ring, len, bid }
    }
    /// Mutable view of the valid bytes.
    // NOTE(review): `stable_ptr` returns a *const u8 borrowed from the
    // shared `buf_list`; the `as *mut _` cast below relies on this Buf
    // being the only live user of buffer `bid` — consider an
    // UnsafeCell-based design to make that aliasing contract explicit.
    fn as_slice_mut(&mut self) -> &mut [u8] {
        let p = self.buf_ring.inner.stable_ptr(self.bid);
        // SAFETY: `p` points at the start of an allocation of at least
        // `buf_capacity()` bytes and `len <= buf_capacity()` (asserted in
        // `new`); the allocation outlives `self` via `buf_ring`.
        unsafe { std::slice::from_raw_parts_mut(p as *mut _, self.len) }
    }
    /// Shared view of the valid bytes.
    fn as_slice(&self) -> &[u8] {
        let p = self.buf_ring.inner.stable_ptr(self.bid);
        // SAFETY: same bounds/lifetime argument as `as_slice_mut`.
        unsafe { std::slice::from_raw_parts(p, self.len) }
    }
}
impl fmt::Debug for Buf {
    /// Show group/buffer ids along with the valid length and capacity.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ring = &self.buf_ring.inner;
        let mut dbg = f.debug_struct("Buf");
        dbg.field("bgid", &ring.bgid());
        dbg.field("bid", &self.bid);
        dbg.field("len", &self.len);
        dbg.field("cap", &ring.buf_capacity());
        dbg.finish()
    }
}
// Read access: lets a `Buf` be used wherever a `&[u8]` of the valid
// bytes is expected.
impl Deref for Buf {
    type Target = [u8];
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}
// Write access: lets a `Buf` be used wherever a `&mut [u8]` of the
// valid bytes is expected.
impl DerefMut for Buf {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_slice_mut()
    }
}
impl Drop for Buf {
    /// Recycle the buffer: push its id back onto the ring and publish
    /// the new tail so the buffer becomes available again.
    fn drop(&mut self) {
        self.buf_ring.inner.drop_buf(self.bid);
    }
}
/// Shared state behind a [`BufRing`]: the mmapped ring of
/// `BufRingEntry` records plus the owned buffers they point into.
struct InnerBufRing {
    // Buffer-group id this ring is registered under.
    bgid: Bgid,
    // ring_entries - 1; ring_entries is a power of two, so ANDing with
    // this masks a tail value into a ring index.
    ring_entries_mask: u16,
    // Number of buffers in `buf_list`.
    buf_cnt: u16,
    // Capacity in bytes of each buffer.
    buf_len: usize,
    // Anonymous shared mapping holding the ring entries.
    ring_start: Mmap,
    // Backing storage; `buf_list[bid]` is the buffer for id `bid`.
    buf_list: Vec<Vec<u8>>,
    // Private copy of the ring tail; published to `shared_tail` by `sync`.
    local_tail: Cell<u16>,
    // Tail field inside the mapped ring, shared with the kernel.
    shared_tail: *const AtomicU16,
}
impl InnerBufRing {
    /// Allocate the ring mapping and `buf_cnt` buffers of `buf_len` bytes,
    /// then fill the ring so every buffer starts out available.
    ///
    /// # Errors
    /// `InvalidInput` when `buf_cnt` is zero or exceeds `ring_entries`,
    /// when `buf_len` is zero, or when `ring_entries` is not a power of
    /// two; mmap/madvise failures are returned as-is.
    fn new(
        bgid: Bgid,
        ring_entries: u16,
        buf_cnt: u16,
        buf_len: usize,
    ) -> io::Result<InnerBufRing> {
        if (buf_cnt == 0)
            || (buf_cnt > ring_entries)
            || (buf_len == 0)
            || ((ring_entries & (ring_entries - 1)) != 0)
        {
            return Err(io::Error::from(io::ErrorKind::InvalidInput));
        }

        let entry_size = mem::size_of::<BufRingEntry>();
        // The kernel ABI fixes the entry layout at 16 bytes.
        assert_eq!(entry_size, 16);
        let ring_size = entry_size * (ring_entries as usize);
        let ring_start = Mmap::new(ring_size)?;
        // The kernel holds this mapping; a forked child must not share it.
        ring_start.dontfork()?;

        // One owned Vec per buffer id; ring entries point into these.
        let buf_list: Vec<Vec<u8>> = (0..buf_cnt).map(|_| vec![0; buf_len]).collect();

        // The shared tail word lives inside the first ring entry.
        let shared_tail = unsafe { BufRingEntry::tail(ring_start.as_ptr() as *const BufRingEntry) }
            as *const AtomicU16;

        let buf_ring = InnerBufRing {
            bgid,
            ring_entries_mask: ring_entries - 1,
            buf_cnt,
            buf_len,
            ring_start,
            buf_list,
            local_tail: Cell::new(0),
            shared_tail,
        };

        // Make every buffer available, then publish the tail once.
        for bid in 0..buf_cnt {
            buf_ring.push(bid);
        }
        buf_ring.sync();

        Ok(buf_ring)
    }

    /// Write buffer `bid` into the next ring slot and advance the local
    /// tail. Not visible to the consumer until `sync` is called.
    fn push(&self, bid: Bid) {
        assert!(bid < self.buf_cnt);
        let old_tail = self.local_tail.get();
        // FIX: the tail is a free-running u16 that is always masked before
        // use as an index; `old_tail + 1` would panic on overflow in debug
        // builds after 65536 total pushes, so advance with wrapping_add.
        self.local_tail.set(old_tail.wrapping_add(1));
        let ring_idx = old_tail & self.mask();
        let entries = self.ring_start.as_mut_ptr() as *mut BufRingEntry;
        // SAFETY: ring_idx < ring_entries because mask() is
        // ring_entries - 1, and the mapping holds ring_entries entries.
        let re = unsafe { &mut *entries.add(ring_idx as usize) };
        re.set_addr(self.stable_ptr(bid) as _);
        re.set_len(self.buf_len as _);
        re.set_bid(bid);
    }

    /// Publish the local tail with Release ordering so the entry writes
    /// performed in `push` are visible before the new tail value.
    fn sync(&self) {
        // SAFETY: `shared_tail` points into the live `ring_start` mapping.
        unsafe {
            (*self.shared_tail).store(self.local_tail.get(), atomic::Ordering::Release);
        }
    }

    /// Return buffer `bid` to the ring and publish immediately.
    fn drop_buf(&self, bid: Bid) {
        self.push(bid);
        self.sync();
    }

    /// Pointer to the start of buffer `bid`; stable because `buf_list`
    /// is never resized after construction.
    fn stable_ptr(&self, bid: Bid) -> *const u8 {
        self.buf_list[bid as usize].as_ptr()
    }

    /// Number of ring entries (a power of two).
    fn ring_entries(&self) -> u16 {
        self.ring_entries_mask + 1
    }

    /// Mask for converting a tail value into a ring index.
    fn mask(&self) -> u16 {
        self.ring_entries_mask
    }

    /// Capacity in bytes of each individual buffer.
    fn buf_capacity(&self) -> usize {
        self.buf_len
    }

    /// Buffer-group id this ring belongs to.
    fn bgid(&self) -> Bgid {
        self.bgid
    }

    /// Build a [`Buf`] for `bid` holding `len` valid bytes; the passed
    /// `buf_ring` handle keeps this ring alive for the `Buf`'s lifetime.
    fn get_buf(&self, buf_ring: BufRing, len: usize, bid: u16) -> Buf {
        assert!(len <= self.buf_len);
        Buf::new(buf_ring, bid, len)
    }
}
/// Owned anonymous memory mapping; unmapped on drop.
struct Mmap {
    // Start of the mapping (never null; mmap failure is checked in `new`).
    addr: ptr::NonNull<libc::c_void>,
    // Length in bytes, as passed to mmap and later munmap.
    len: usize,
}
impl Mmap {
    /// Map `len` bytes of zeroed, shared, pre-faulted anonymous memory.
    ///
    /// # Errors
    /// Returns the OS error when `mmap` fails.
    fn new(len: usize) -> io::Result<Mmap> {
        // SAFETY: anonymous mapping (fd = -1, offset 0); no existing
        // memory is aliased.
        let addr = unsafe {
            libc::mmap(
                ptr::null_mut(),
                len,
                libc::PROT_READ | libc::PROT_WRITE,
                libc::MAP_ANONYMOUS | libc::MAP_SHARED | libc::MAP_POPULATE,
                -1,
                0,
            )
        };
        if addr == libc::MAP_FAILED {
            return Err(io::Error::last_os_error());
        }
        // SAFETY: a successful mmap does not return a null pointer.
        let addr = unsafe { ptr::NonNull::new_unchecked(addr) };
        Ok(Mmap { addr, len })
    }

    /// Mark the mapping MADV_DONTFORK so forked children do not inherit it.
    ///
    /// # Errors
    /// Returns the OS error when `madvise` fails.
    fn dontfork(&self) -> io::Result<()> {
        // SAFETY: `addr`/`len` describe our own live mapping.
        let rc = unsafe { libc::madvise(self.addr.as_ptr(), self.len, libc::MADV_DONTFORK) };
        if rc == 0 {
            Ok(())
        } else {
            Err(io::Error::last_os_error())
        }
    }

    /// Start of the mapping as a const pointer.
    #[inline]
    pub fn as_ptr(&self) -> *const libc::c_void {
        self.addr.as_ptr()
    }

    /// Start of the mapping as a mut pointer.
    #[inline]
    fn as_mut_ptr(&self) -> *mut libc::c_void {
        self.addr.as_ptr()
    }
}
impl Drop for Mmap {
    fn drop(&mut self) {
        // SAFETY: `addr`/`len` are exactly what mmap returned in `new`.
        // munmap's error return is ignored because drop cannot fail.
        unsafe {
            libc::munmap(self.addr.as_ptr(), self.len);
        }
    }
}