use std::sync::Arc;
use std::time::Duration;
use crate::async_qp::AsyncQp;
use crate::mr::{AccessFlags, OwnedMemoryRegion};
use crate::mw::MemoryWindow;
use crate::pd::ProtectionDomain;
use crate::wc::WorkCompletion;
use crate::wr::{RecvWr, SendFlags, SendWr, Sge, WrOpcode};
/// Bit 63 of a 64-bit work-request ID.
///
/// NOTE(review): by its name this marks `wr_id`s that carry flow-control credits; the code that
/// sets or tests the flag is not in this part of the file, so confirm against its users. Be
/// aware that every `u64::MAX - k` sentinel `wr_id` used in this file also has bit 63 set, so a
/// plain mask test cannot tell a credit ID from a sentinel.
pub(crate) const WR_ID_CREDIT_FLAG: u64 = 1 << 63;
/// Reserved work-request ID with the value `u64::MAX - 20`.
///
/// NOTE(review): by its name this tags work requests that only cover ring padding; confirm
/// against the code that posts them.
pub(crate) const WR_ID_PADDING_SENTINEL: u64 = u64::MAX - 20;
/// Version number written into, and expected from, every [`RingToken`].
pub(crate) const RING_TOKEN_VERSION: u8 = 1;
/// Size in bytes of an encoded [`RingToken`].
pub(crate) const RING_TOKEN_SIZE: usize = 20;

/// Fixed-size descriptor of a ring that the two sides of a connection exchange.
///
/// Encoded layout (all integers little-endian):
///
/// | bytes    | field       |
/// |----------|-------------|
/// | `0`      | `version`   |
/// | `1..4`   | `_reserved` |
/// | `4..12`  | `ring_va`   |
/// | `12..16` | `mw_rkey`   |
/// | `16..20` | `capacity`  |
///
/// The struct is `packed`, so never take a reference to a multi-byte field; copy the field into
/// a local first.
#[repr(C, packed)]
pub(crate) struct RingToken {
    pub(crate) version: u8,
    pub(crate) _reserved: [u8; 3],
    pub(crate) ring_va: u64,
    pub(crate) mw_rkey: u32,
    pub(crate) capacity: u32,
}

// The in-memory size must match the encoded size.
const _: () = assert!(std::mem::size_of::<RingToken>() == RING_TOKEN_SIZE);

impl RingToken {
    /// Encodes the token into its `RING_TOKEN_SIZE`-byte form.
    ///
    /// The reserved bytes are always written as zero, whatever `_reserved` holds.
    pub(crate) fn to_bytes(&self) -> [u8; RING_TOKEN_SIZE] {
        let mut encoded = [0u8; RING_TOKEN_SIZE];
        {
            // Carve the output into one sub-slice per field.
            let (header, body) = encoded.split_at_mut(4);
            let (va, rest) = body.split_at_mut(8);
            let (rkey, cap) = rest.split_at_mut(4);

            header[0] = self.version;
            va.copy_from_slice(&self.ring_va.to_le_bytes());
            rkey.copy_from_slice(&self.mw_rkey.to_le_bytes());
            cap.copy_from_slice(&self.capacity.to_le_bytes());
        }
        encoded
    }

    /// Decodes a token from its `RING_TOKEN_SIZE`-byte form.
    ///
    /// No validation is performed; the reserved bytes are copied through unchanged.
    pub(crate) fn from_bytes(buf: &[u8; RING_TOKEN_SIZE]) -> Self {
        let mut va = [0u8; 8];
        let mut rkey = [0u8; 4];
        let mut cap = [0u8; 4];
        va.copy_from_slice(&buf[4..12]);
        rkey.copy_from_slice(&buf[12..16]);
        cap.copy_from_slice(&buf[16..20]);

        Self {
            version: buf[0],
            _reserved: [buf[1], buf[2], buf[3]],
            ring_va: u64::from_le_bytes(va),
            mw_rkey: u32::from_le_bytes(rkey),
            capacity: u32::from_le_bytes(cap),
        }
    }
}
/// Index bookkeeping for a circular byte buffer of `capacity` bytes.
///
/// `head` is the offset of the oldest byte still in use and `tail` the offset where the next
/// reservation starts. `head == tail` means "empty", so at most `capacity - 1` bytes can be in
/// use at once. None of the methods here touch `mr`; they only move the two indices.
pub(crate) struct RingBuffer {
    /// Memory region associated with this ring.
    pub(crate) mr: OwnedMemoryRegion,
    /// Total size of the ring in bytes.
    pub(crate) capacity: usize,
    /// Offset of the oldest in-use byte.
    pub(crate) head: usize,
    /// Offset at which the next reservation begins.
    pub(crate) tail: usize,
}

impl RingBuffer {
    /// Creates an empty ring (`head == tail == 0`) over `mr`.
    pub(crate) fn new(mr: OwnedMemoryRegion, capacity: usize) -> Self {
        let (head, tail) = (0, 0);
        Self {
            mr,
            capacity,
            head,
            tail,
        }
    }

    /// Total number of free bytes, counting both sides of the wrap point.
    ///
    /// One byte is always held back so a full ring is distinguishable from an empty one.
    pub(crate) fn available(&self) -> usize {
        let in_use = self.used();
        self.capacity - in_use - 1
    }

    /// Number of bytes between `head` and `tail`, following the ring around the wrap point.
    pub(crate) fn used(&self) -> usize {
        match self.tail.checked_sub(self.head) {
            // tail >= head: the in-use span is one contiguous run.
            Some(span) => span,
            // tail < head: the in-use span wraps past the end of the ring.
            None => self.capacity - self.head + self.tail,
        }
    }

    /// Number of free bytes that can be taken starting at `tail` without wrapping.
    pub(crate) fn contiguous_free(&self) -> usize {
        if self.tail < self.head {
            // Free space is the gap up to head, minus the held-back byte.
            return self.head - self.tail - 1;
        }
        let to_end = self.capacity - self.tail;
        if self.head == 0 {
            // Filling to the very end would wrap tail onto head and make the ring look empty.
            to_end - 1
        } else {
            to_end
        }
    }

    /// Reserves `len` contiguous bytes and advances `tail`.
    ///
    /// Returns `Some((offset, padding))`, where `offset` is the start of the reservation and
    /// `padding` is the number of bytes skipped at the end of the ring when the reservation had
    /// to start over at offset 0 (`0` when no wrap was needed). Skipped bytes count as used
    /// from then on.
    ///
    /// Returns `None`, leaving the ring untouched, when `len` is 0 or no contiguous run of
    /// `len` bytes can be handed out.
    pub(crate) fn reserve(&mut self, len: usize) -> Option<(usize, usize)> {
        if len == 0 {
            return None;
        }
        let free = self.available();
        if len > free {
            return None;
        }

        let start = self.tail;
        if len <= self.contiguous_free() {
            self.tail = (start + len) % self.capacity;
            return Some((start, 0));
        }

        // Not enough room before the end of the ring: skip the remainder and restart at 0.
        let padding = self.capacity - start;
        if padding + len > free + 1 {
            return None;
        }
        // The reservation must end strictly before head so that tail never lands on it.
        if self.head <= len {
            return None;
        }
        self.tail = len;
        Some((0, padding))
    }

    /// Advances `head` by `len` bytes, wrapping at `capacity`.
    ///
    /// `len` is not checked against [`used`](Self::used).
    pub(crate) fn release(&mut self, len: usize) {
        let advanced = self.head + len;
        self.head = advanced % self.capacity;
    }
}
/// Tracks slots that may be released out of order and reports how far the in-order head can
/// advance after each release.
pub(crate) struct CompletionTracker {
    /// Per-slot flag: `true` once a slot was released but the head has not yet passed it.
    released: Box<[bool]>,
    /// Oldest slot that has not been released yet.
    head_slot: usize,
    /// Number of slots; set to `released.len()` by [`CompletionTracker::new`].
    capacity: usize,
}

impl CompletionTracker {
    /// Creates a tracker with `capacity` slots, none released, and the head at slot 0.
    pub(crate) fn new(capacity: usize) -> Self {
        Self {
            released: vec![false; capacity].into_boxed_slice(),
            head_slot: 0,
            capacity,
        }
    }

    /// Marks `slot` as released and advances the head over every consecutive released slot.
    ///
    /// Returns the number of slots the head moved past (0 when the head slot itself is still
    /// outstanding). A `slot` outside the tracker is ignored, and a tracker created with
    /// capacity 0 always returns 0 instead of panicking.
    pub(crate) fn release(&mut self, slot: usize) -> usize {
        if let Some(flag) = self.released.get_mut(slot) {
            *flag = true;
        }

        let mut advanced = 0;
        // Checked access: on an empty tracker there is no head slot to inspect.
        while let Some(flag) = self.released.get_mut(self.head_slot) {
            if !*flag {
                break;
            }
            // Clear the flag so the slot can be reused after the head wraps around.
            *flag = false;
            self.head_slot = (self.head_slot + 1) % self.capacity;
            advanced += 1;
        }
        advanced
    }
}
/// Allocates a Type 2 memory window and posts the work request that binds it over `recv_mr`.
///
/// The bind covers `ring_capacity` bytes starting at `recv_mr.addr()` and requests
/// `IBV_ACCESS_REMOTE_WRITE`. The work request is posted signaled with
/// `wr_id == u64::MAX - 10`; this function does not wait for its completion.
///
/// Returns the window together with the rkey read from it right after allocation.
///
/// # Panics
/// Panics if `crate::device::supports_mw_type2(pd)` returns `false`.
/// NOTE(review): this is a device capability rather than a caller bug; consider returning an
/// error instead.
///
/// # Errors
/// Propagates failures from `MemoryWindow::alloc` and `AsyncQp::post_send_wr`.
pub(crate) fn bind_recv_mw(
    qp: &AsyncQp,
    pd: &Arc<ProtectionDomain>,
    recv_mr: &OwnedMemoryRegion,
    ring_capacity: usize,
) -> crate::Result<(MemoryWindow, u32)> {
    const BIND_MW_WR_ID: u64 = u64::MAX - 10;

    let type2_supported = crate::device::supports_mw_type2(pd);
    assert!(
        type2_supported,
        "ring transport requires Memory Window Type 2 support \
         (device does not support ibv_alloc_mw Type 2)"
    );

    let mw = MemoryWindow::alloc(pd, crate::mw::MwType::Type2)?;
    let mw_rkey = mw.rkey();

    let window_len = ring_capacity as u64;
    let remote_access = rdma_io_sys::ibverbs::IBV_ACCESS_REMOTE_WRITE;
    let mut bind_wr = SendWr::new(BIND_MW_WR_ID, WrOpcode::BindMw)
        .flags(SendFlags::SIGNALED)
        .bind_mw(
            mw.as_raw(),
            mw_rkey,
            recv_mr.as_raw(),
            recv_mr.addr(),
            window_len,
            remote_access,
        );
    qp.post_send_wr(&mut bind_wr)?;

    Ok((mw, mw_rkey))
}
/// Registers a `RING_TOKEN_SIZE`-byte receive buffer and posts a receive work request for it.
///
/// The buffer is zero-initialised and registered with `AccessFlags::LOCAL_WRITE`; the receive
/// is posted with `wr_id == u64::MAX`. The returned memory region owns the buffer, so the
/// caller has to keep it alive while the receive is outstanding.
///
/// # Errors
/// Propagates failures from `reg_mr_owned` and `AsyncQp::post_recv_wr`.
pub(crate) fn post_token_recv(
    qp: &AsyncQp,
    pd: &Arc<ProtectionDomain>,
) -> crate::Result<OwnedMemoryRegion> {
    const TOKEN_RECV_WR_ID: u64 = u64::MAX;

    let landing = vec![0u8; RING_TOKEN_SIZE];
    let token_mr = pd.reg_mr_owned(landing, AccessFlags::LOCAL_WRITE)?;

    let sge = Sge::new(token_mr.addr(), RING_TOKEN_SIZE as u32, token_mr.lkey());
    let mut wr = RecvWr::new(TOKEN_RECV_WR_ID).sg(sge);
    qp.post_recv_wr(&mut wr)?;

    Ok(token_mr)
}
pub(crate) async fn complete_token_exchange(
qp: &AsyncQp,
pd: &Arc<ProtectionDomain>,
recv_mr: &OwnedMemoryRegion,
mw_rkey: u32,
token_recv_mr: &OwnedMemoryRegion,
_token_timeout: Duration,
ring_capacity: usize,
) -> crate::Result<(u64, u32, usize)> {
let our_token = RingToken {
version: RING_TOKEN_VERSION,
_reserved: [0; 3],
ring_va: recv_mr.addr(),
mw_rkey,
capacity: ring_capacity as u32,
};
let token_bytes = our_token.to_bytes();
let token_send_mr = pd.reg_mr_owned(token_bytes.to_vec(), AccessFlags::LOCAL_WRITE)?;
let send_sge = Sge::new(
token_send_mr.addr(),
RING_TOKEN_SIZE as u32,
token_send_mr.lkey(),
);
let mut send_wr = SendWr::new(u64::MAX - 2, WrOpcode::Send)
.flags(SendFlags::SIGNALED | SendFlags::INLINE)
.sg(send_sge);
qp.post_send_wr(&mut send_wr)?;
let mut wc_buf = [WorkCompletion::default(); 4];
let n = qp.recv_cq().poll(&mut wc_buf).await?;
if n > 0 && !wc_buf[0].is_success() {
return Err(crate::Error::WorkCompletion {
status: wc_buf[0].status_raw(),
vendor_err: wc_buf[0].vendor_err(),
});
}
let recv_buf: &[u8; RING_TOKEN_SIZE] = token_recv_mr
.as_slice()
.try_into()
.expect("token recv MR is exactly RING_TOKEN_SIZE");
let peer_token = RingToken::from_bytes(recv_buf);
let peer_ver = peer_token.version;
if peer_ver != RING_TOKEN_VERSION {
return Err(crate::Error::InvalidArg(format!(
"unsupported ring token version: {peer_ver}",
)));
}
let peer_cap = peer_token.capacity;
if peer_cap == 0 {
return Err(crate::Error::InvalidArg("peer ring capacity is 0".into()));
}
if peer_cap as usize > 65536 {
return Err(crate::Error::InvalidArg(format!(
"peer ring capacity too large: {peer_cap}",
)));
}
Ok((peer_token.ring_va, peer_token.mw_rkey, peer_cap as usize))
}
/// Polls the send completion queue synchronously and returns how many completions were
/// removed.
///
/// Makes at most 100 rounds. Each round calls `req_notify(false)` and then polls up to 16
/// completions. The loop ends early on the first empty poll after at least one completion has
/// been seen; while nothing has been seen yet, an empty poll only issues a spin-loop hint and
/// tries again.
///
/// The status of the drained completions is not inspected.
///
/// # Errors
/// Propagates the first error from `req_notify` or `poll`.
pub(crate) fn drain_send_cq(qp: &AsyncQp) -> crate::Result<usize> {
    const MAX_ROUNDS: usize = 100;

    let mut scratch = [WorkCompletion::default(); 16];
    let mut drained = 0;
    for _round in 0..MAX_ROUNDS {
        qp.send_cq().cq().req_notify(false)?;
        let polled = qp.send_cq().cq().poll(&mut scratch)?;
        if polled != 0 {
            drained += polled;
            continue;
        }
        if drained > 0 {
            // The queue ran dry after yielding something: done.
            break;
        }
        std::hint::spin_loop();
    }
    Ok(drained)
}