use std::collections::VecDeque;
use std::net::SocketAddr;
use std::os::unix::io::RawFd;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::unix::AsyncFd;
use crate::async_cm::{AsyncCmId, AsyncCmListener};
use crate::async_cq::{AsyncCq, CqPollState};
use crate::async_qp::AsyncQp;
use crate::cm::{CmId, ConnParam, EventChannel, PortSpace};
use crate::mr::{AccessFlags, OwnedMemoryRegion};
use crate::mw::MemoryWindow;
use crate::pd::ProtectionDomain;
use crate::qp::QpInitAttr;
use crate::transport::{RecvCompletion, Transport, TransportBuilder};
use crate::transport_common::*;
use crate::wc::{WcOpcode, WorkCompletion};
use crate::wr::{QpType, RecvWr, SendFlags, SendWr, Sge, WrOpcode};
/// Wire-format version stamped into every [`ReadRingToken`]; the token
/// exchange rejects any peer token carrying a different value.
const READ_RING_TOKEN_VERSION: u8 = 2;
/// Encoded length of a [`ReadRingToken`] in bytes.
const READ_RING_TOKEN_SIZE: usize = 32;
/// Work-request id that tags the RDMA read of the peer's head offset.
const WR_ID_READ_SENTINEL: u64 = u64::MAX - 30;

/// Parameters each side sends the other after the CM connection is up: where
/// its receive ring and its head-offset word live, and which rkeys grant
/// access to them.
///
/// The in-memory layout is never sent directly; encoding always goes through
/// [`ReadRingToken::to_bytes`] / [`ReadRingToken::from_bytes`]. The packed
/// representation lets the compile-time assertion below pin the struct size
/// to the wire size.
#[repr(C, packed)]
struct ReadRingToken {
    version: u8,
    _reserved: [u8; 3],
    ring_va: u64,
    ring_capacity: u32,
    ring_rkey: u32,
    offset_va: u64,
    offset_rkey: u32,
}
const _: () = assert!(std::mem::size_of::<ReadRingToken>() == READ_RING_TOKEN_SIZE);
impl ReadRingToken {
    /// Encodes the token into its 32-byte little-endian wire form.
    ///
    /// Layout: `[0]` version, `[1..4]` always zero, `[4..12]` ring VA,
    /// `[12..16]` ring capacity, `[16..20]` ring rkey, `[20..28]` offset VA,
    /// `[28..32]` offset rkey.
    fn to_bytes(&self) -> [u8; READ_RING_TOKEN_SIZE] {
        let mut wire = [0u8; READ_RING_TOKEN_SIZE];
        wire[0] = self.version;
        // Copy packed fields out by value before encoding them.
        let ring_va = self.ring_va.to_le_bytes();
        let ring_capacity = self.ring_capacity.to_le_bytes();
        let ring_rkey = self.ring_rkey.to_le_bytes();
        let offset_va = self.offset_va.to_le_bytes();
        let offset_rkey = self.offset_rkey.to_le_bytes();
        let parts: [&[u8]; 5] = [&ring_va, &ring_capacity, &ring_rkey, &offset_va, &offset_rkey];
        // Bytes 1..4 (reserved) are left zero; the payload starts at byte 4.
        let mut cursor = 4;
        for part in parts {
            wire[cursor..cursor + part.len()].copy_from_slice(part);
            cursor += part.len();
        }
        wire
    }
    /// Decodes a token from its wire form; the reserved bytes are kept
    /// verbatim. No validation happens here.
    fn from_bytes(buf: &[u8; READ_RING_TOKEN_SIZE]) -> Self {
        let u32_at = |at: usize| u32::from_le_bytes([buf[at], buf[at + 1], buf[at + 2], buf[at + 3]]);
        let u64_at = |at: usize| {
            let mut raw = [0u8; 8];
            raw.copy_from_slice(&buf[at..at + 8]);
            u64::from_le_bytes(raw)
        };
        Self {
            version: buf[0],
            _reserved: [buf[1], buf[2], buf[3]],
            ring_va: u64_at(4),
            ring_capacity: u32_at(12),
            ring_rkey: u32_at(16),
            offset_va: u64_at(20),
            offset_rkey: u32_at(28),
        }
    }
}
/// Tuning parameters for [`ReadRingTransport`].
#[derive(Debug, Clone)]
pub struct ReadRingConfig {
    /// Size in bytes of the local send ring and of the receive ring that is
    /// advertised to the peer.
    pub ring_capacity: usize,
    /// Largest payload a single `send_copy` call will accept. The transport
    /// also sizes its doorbell receives as `ring_capacity / max_message_size`.
    pub max_message_size: usize,
    /// Presumably a timeout for the token exchange (inferred from the name).
    /// NOTE(review): not read anywhere in this file — confirm whether it is
    /// honoured elsewhere.
    pub token_timeout: Duration,
    /// Requested QP inline-data limit; raised to at least the token size when
    /// the QP is created.
    pub max_inline_data: u32,
    /// Bytes of remote-ring headroom that must remain beyond the message
    /// being sent before `send_copy` will post it.
    pub min_free_threshold: usize,
}
impl ReadRingConfig {
    /// Preset for datagram-sized traffic: a 64 KiB ring carrying messages of
    /// at most 1500 bytes.
    pub fn datagram() -> Self {
        let ring_capacity = 64 * 1024;
        let max_message_size = 1500;
        Self {
            ring_capacity,
            max_message_size,
            token_timeout: Duration::from_millis(5_000),
            max_inline_data: 0,
            min_free_threshold: 128,
        }
    }
}
impl Default for ReadRingConfig {
    /// Equivalent to [`ReadRingConfig::datagram`].
    fn default() -> Self {
        ReadRingConfig::datagram()
    }
}
/// RDMA transport that RDMA-writes (with immediate) each message straight into
/// the peer's receive ring, and learns how much of that ring the peer has
/// consumed by RDMA-reading a head-offset word the peer publishes.
///
/// Field order matters: Rust drops fields in declaration order, so the memory
/// windows are released before the QP, and the QP before the memory regions
/// and the protection domain listed after it.
pub struct ReadRingTransport {
    // Polling state handed to `poll_send_cq` / `poll_recv_cq`.
    send_cq_state: CqPollState,
    recv_cq_state: CqPollState,
    // Set once `disconnect` has issued the CM disconnect; `Drop` checks it.
    disconnected: bool,
    // Set on a CM `Disconnected` event, a CM/CQ error, or a failed completion.
    peer_disconnected: bool,
    // Receive handles given to callers: `buf_idx` -> (ring offset, length,
    // arrival sequence). `None` marks a free handle.
    virtual_idx_map: Box<[Option<(usize, usize, usize)>]>,
    // Where the search for the next free receive handle starts.
    next_virt_idx: usize,
    // `ring_capacity / max_message_size`: number of doorbell receives and of
    // arrival slots in `slot_lengths` / `recv_tracker`.
    max_outstanding: usize,
    // Count of write-with-imm arrivals (data and padding) seen so far.
    recv_arrival_seq: usize,
    // Completions that did not fit in the caller's `out` slice.
    recv_stash: VecDeque<RecvCompletion>,
    // Peer's receive-ring head as of the last completed RDMA read.
    cached_remote_head: usize,
    // True while the RDMA read of the peer's head word is outstanding.
    read_in_flight: bool,
    // Address and rkey of the peer's head-offset word (from its token).
    remote_offset_va: u64,
    remote_offset_rkey: u32,
    // Bytes to release from `recv_ring` for each arrival slot.
    slot_lengths: Box<[usize]>,
    // Next arrival slot whose bytes will be released from `recv_ring`.
    head_slot_idx: usize,
    // Write WRs (data and padding) posted but not yet completed.
    send_in_flight: usize,
    // Memory windows kept alive for the peer's remote access.
    _offset_mw: Option<MemoryWindow>,
    _recv_mw: Option<MemoryWindow>,
    qp: AsyncQp,
    // Local staging ring for outgoing payloads.
    send_ring: RingBuffer,
    // Ring the peer writes into.
    recv_ring: RingBuffer,
    // 4-byte local target of the RDMA read of the peer's head word.
    read_buf: OwnedMemoryRegion,
    // Region whose first 4 bytes hold our published head (see `offset_buffer`).
    _offset_mr: OwnedMemoryRegion,
    // Peer's receive ring: base address, rkey and capacity (from its token).
    remote_addr: u64,
    remote_rkey: u32,
    remote_capacity: usize,
    // Offset in the peer's ring where the next write lands.
    remote_write_tail: usize,
    // 4-byte buffers backing the doorbell receives consumed by write-with-imm.
    doorbell_bufs: Box<[OwnedMemoryRegion]>,
    // Round-robin index of the next doorbell buffer to repost.
    doorbell_repost_idx: usize,
    // Releases arrival slots and reports how many became contiguous at the head.
    recv_tracker: CompletionTracker,
    config: ReadRingConfig,
    _pd: Arc<ProtectionDomain>,
    // Async readiness wrapper over the CM event channel fd.
    cm_async_fd: AsyncFd<RawFd>,
    cm_id: CmId,
    event_channel: EventChannel,
}
impl ReadRingTransport {
    /// Returns the first 4 bytes of `_offset_mr` viewed as an [`AtomicU32`].
    ///
    /// `advance_recv_head` stores our receive-ring head here; the peer fetches
    /// it through the offset memory window with an RDMA read.
    fn offset_buffer(&self) -> &AtomicU32 {
        let ptr = self._offset_mr.addr() as *const AtomicU32;
        // The region is registered from a `Vec<u8>`, for which Rust only
        // guarantees 1-byte alignment. Catch a misaligned buffer in debug
        // builds instead of silently forming a misaligned atomic reference.
        debug_assert!(
            (ptr as usize) % std::mem::align_of::<AtomicU32>() == 0,
            "offset MR is not aligned for AtomicU32"
        );
        // SAFETY: `_offset_mr` is owned by `self`, so the memory outlives the
        // returned borrow. It is registered with 64 bytes (at least 4 are
        // needed), and alignment is checked above in debug builds. The CPU
        // only accesses it through this atomic; the NIC reads it remotely.
        unsafe { &*ptr }
    }
}
/// Registers a token-sized buffer and posts it as a receive on `qp` with
/// work-request id `u64::MAX`.
///
/// Returns the buffer's memory region; the caller keeps it alive and reads
/// the peer's token out of it once the receive completes.
fn post_read_ring_token_recv(
    qp: &AsyncQp,
    pd: &Arc<ProtectionDomain>,
) -> crate::Result<OwnedMemoryRegion> {
    let landing = pd.reg_mr_owned(vec![0u8; READ_RING_TOKEN_SIZE], AccessFlags::LOCAL_WRITE)?;
    let len = READ_RING_TOKEN_SIZE as u32;
    let mut wr = RecvWr::new(u64::MAX).sg(Sge::new(landing.addr(), len, landing.lkey()));
    qp.post_recv_wr(&mut wr)?;
    Ok(landing)
}
/// Sends `our_token` to the peer and waits for the peer's token to land in
/// `token_recv_mr` (the receive posted by `post_read_ring_token_recv`).
///
/// The send completion is not awaited here; callers drain the send CQ
/// afterwards.
///
/// # Errors
/// - `WorkCompletion` if the token receive completes with an error status.
/// - `InvalidArg` if the peer's token has an unsupported version, a zero
///   capacity, a capacity above 65536, or a capacity different from ours.
///   The send path derives the remote wrap-around padding from a local ring
///   of our own capacity, so unequal ring sizes would make writes overrun or
///   overwrite the peer's ring.
async fn complete_read_ring_token_exchange(
    qp: &AsyncQp,
    pd: &Arc<ProtectionDomain>,
    our_token: &ReadRingToken,
    token_recv_mr: &OwnedMemoryRegion,
) -> crate::Result<ReadRingToken> {
    let token_bytes = our_token.to_bytes();
    let token_send_mr = pd.reg_mr_owned(token_bytes.to_vec(), AccessFlags::LOCAL_WRITE)?;
    let send_sge = Sge::new(
        token_send_mr.addr(),
        READ_RING_TOKEN_SIZE as u32,
        token_send_mr.lkey(),
    );
    // Sent inline, so the payload is copied at post time and the temporary
    // send MR may be dropped when this function returns.
    let mut send_wr = SendWr::new(u64::MAX - 2, WrOpcode::Send)
        .flags(SendFlags::SIGNALED | SendFlags::INLINE)
        .sg(send_sge);
    qp.post_send_wr(&mut send_wr)?;
    // Take exactly one receive completion: the token receive is the only one
    // this step owns, and polling more could swallow a completion meant for
    // the data path. A zero-count poll means nothing arrived yet, so retry
    // rather than parse an unwritten buffer.
    let mut wc_buf = [WorkCompletion::default(); 1];
    let wc = loop {
        if qp.recv_cq().poll(&mut wc_buf).await? > 0 {
            break wc_buf[0];
        }
    };
    if !wc.is_success() {
        return Err(crate::Error::WorkCompletion {
            status: wc.status_raw(),
            vendor_err: wc.vendor_err(),
        });
    }
    let recv_buf: &[u8; READ_RING_TOKEN_SIZE] = token_recv_mr
        .as_slice()
        .try_into()
        .expect("token recv MR is exactly READ_RING_TOKEN_SIZE");
    let peer_token = ReadRingToken::from_bytes(recv_buf);
    // Packed fields are copied into locals before use.
    let peer_ver = peer_token.version;
    if peer_ver != READ_RING_TOKEN_VERSION {
        return Err(crate::Error::InvalidArg(format!(
            "unsupported read ring token version: {peer_ver}",
        )));
    }
    let peer_cap = peer_token.ring_capacity;
    if peer_cap == 0 {
        return Err(crate::Error::InvalidArg("peer ring capacity is 0".into()));
    }
    // Ring offsets travel in the upper 16 bits of the immediate.
    if peer_cap as usize > 65536 {
        return Err(crate::Error::InvalidArg(format!(
            "peer ring capacity too large: {peer_cap}",
        )));
    }
    let our_cap = our_token.ring_capacity;
    if peer_cap != our_cap {
        return Err(crate::Error::InvalidArg(format!(
            "peer ring capacity {peer_cap} does not match local ring capacity {our_cap}",
        )));
    }
    Ok(peer_token)
}
/// Allocates a type-2 memory window and posts a signaled bind WR (id
/// `u64::MAX - 11`). The bind covers the first `offset_size` bytes of
/// `offset_mr` and grants remote read access only.
///
/// Returns the window together with the rkey that was bound. The bind's
/// completion is left on the send CQ for the caller to drain.
fn bind_offset_mw(
    qp: &AsyncQp,
    pd: &Arc<ProtectionDomain>,
    offset_mr: &OwnedMemoryRegion,
    offset_size: usize,
) -> crate::Result<(MemoryWindow, u32)> {
    let window = MemoryWindow::alloc(pd, crate::mw::MwType::Type2)?;
    let rkey = window.rkey();
    let length = offset_size as u64;
    let access = rdma_io_sys::ibverbs::IBV_ACCESS_REMOTE_READ;
    let mut wr = SendWr::new(u64::MAX - 11, WrOpcode::BindMw)
        .flags(SendFlags::SIGNALED)
        .bind_mw(
            window.as_raw(),
            rkey,
            offset_mr.as_raw(),
            offset_mr.addr(),
            length,
            access,
        );
    qp.post_send_wr(&mut wr)?;
    Ok((window, rkey))
}
impl ReadRingTransport {
pub async fn connect(addr: &SocketAddr, config: ReadRingConfig) -> crate::Result<Self> {
if crate::device::any_device_is_iwarp() {
return Err(crate::Error::InvalidArg(
"ring transport requires InfiniBand/RoCE (iWARP detected)".into(),
));
}
let async_cm = AsyncCmId::new(PortSpace::Tcp)?;
async_cm.resolve_addr(None, addr, 2000).await?;
async_cm.resolve_route(2000).await?;
let ctx = async_cm
.verbs_context()
.ok_or(crate::Error::InvalidArg("no verbs context".into()))?;
let pd = async_cm.alloc_pd()?;
let max_outstanding = config.ring_capacity / config.max_message_size;
let send_cq_depth = (max_outstanding + 3) as i32;
let recv_cq_depth = (max_outstanding + 2) as i32;
let send_cq = AsyncCq::create_tokio(ctx.clone(), send_cq_depth)?;
let recv_cq = AsyncCq::create_tokio(ctx, recv_cq_depth)?;
let qp_attr = QpInitAttr {
qp_type: QpType::Rc,
max_send_wr: send_cq_depth as u32,
max_recv_wr: recv_cq_depth as u32,
max_send_sge: 1,
max_recv_sge: 1,
max_inline_data: config.max_inline_data.max(READ_RING_TOKEN_SIZE as u32),
sq_sig_all: true,
};
let cmqp =
async_cm.create_qp_with_cq(&pd, &qp_attr, Some(send_cq.cq()), Some(recv_cq.cq()))?;
let send_mr = pd.reg_mr_owned(vec![0u8; config.ring_capacity], AccessFlags::LOCAL_WRITE)?;
let recv_mr = pd.reg_mr_owned(
vec![0u8; config.ring_capacity],
AccessFlags::LOCAL_WRITE | AccessFlags::REMOTE_WRITE | AccessFlags::MW_BIND,
)?;
let offset_mr = pd.reg_mr_owned(
vec![0u8; 64],
AccessFlags::LOCAL_WRITE | AccessFlags::REMOTE_READ | AccessFlags::MW_BIND,
)?;
let read_buf = pd.reg_mr_owned(vec![0u8; 4], AccessFlags::LOCAL_WRITE)?;
let doorbell_bufs: Box<[OwnedMemoryRegion]> = (0..max_outstanding)
.map(|_| pd.reg_mr_owned(vec![0u8; 4], AccessFlags::LOCAL_WRITE))
.collect::<crate::Result<Vec<_>>>()?
.into_boxed_slice();
let qp = AsyncQp::new(cmqp, send_cq, recv_cq);
let token_recv_mr = post_read_ring_token_recv(&qp, &pd)?;
async_cm.connect(&ConnParam::default()).await?;
let (event_channel, cm_id) = async_cm.into_parts();
let cm_async_fd = AsyncFd::new(event_channel.fd()).map_err(crate::Error::Verbs)?;
let (recv_mw, mw1_rkey) = bind_recv_mw(&qp, &pd, &recv_mr, config.ring_capacity)?;
let (offset_mw, mw2_rkey) = bind_offset_mw(&qp, &pd, &offset_mr, 64)?;
let our_token = ReadRingToken {
version: READ_RING_TOKEN_VERSION,
_reserved: [0; 3],
ring_va: recv_mr.addr(),
ring_capacity: config.ring_capacity as u32,
ring_rkey: mw1_rkey,
offset_va: offset_mr.addr(),
offset_rkey: mw2_rkey,
};
let peer_token =
complete_read_ring_token_exchange(&qp, &pd, &our_token, &token_recv_mr).await?;
drain_send_cq(&qp)?;
for (i, mr) in doorbell_bufs.iter().enumerate() {
let sge = Sge::new(mr.addr(), 4, mr.lkey());
let mut wr = RecvWr::new(i as u64).sg(sge);
qp.post_recv_wr(&mut wr)?;
}
Ok(Self::from_parts(
qp,
cm_async_fd,
cm_id,
event_channel,
pd,
send_mr,
recv_mr,
recv_mw,
offset_mw,
offset_mr,
read_buf,
doorbell_bufs,
peer_token.ring_va,
peer_token.ring_rkey,
peer_token.ring_capacity as usize,
peer_token.offset_va,
peer_token.offset_rkey,
max_outstanding,
config,
))
}
pub async fn accept(listener: &AsyncCmListener, config: ReadRingConfig) -> crate::Result<Self> {
if crate::device::any_device_is_iwarp() {
return Err(crate::Error::InvalidArg(
"ring transport requires InfiniBand/RoCE (iWARP detected)".into(),
));
}
let conn_id = listener.get_request().await?;
let ctx = conn_id
.verbs_context()
.ok_or(crate::Error::InvalidArg("no verbs context".into()))?;
let pd = conn_id.alloc_pd()?;
let max_outstanding = config.ring_capacity / config.max_message_size;
let send_cq_depth = (max_outstanding + 3) as i32;
let recv_cq_depth = (max_outstanding + 2) as i32;
let send_cq = AsyncCq::create_tokio(ctx.clone(), send_cq_depth)?;
let recv_cq = AsyncCq::create_tokio(ctx, recv_cq_depth)?;
let qp_attr = QpInitAttr {
qp_type: QpType::Rc,
max_send_wr: send_cq_depth as u32,
max_recv_wr: recv_cq_depth as u32,
max_send_sge: 1,
max_recv_sge: 1,
max_inline_data: config.max_inline_data.max(READ_RING_TOKEN_SIZE as u32),
sq_sig_all: true,
};
let cmqp =
conn_id.create_qp_with_cq(&pd, &qp_attr, Some(send_cq.cq()), Some(recv_cq.cq()))?;
let send_mr = pd.reg_mr_owned(vec![0u8; config.ring_capacity], AccessFlags::LOCAL_WRITE)?;
let recv_mr = pd.reg_mr_owned(
vec![0u8; config.ring_capacity],
AccessFlags::LOCAL_WRITE | AccessFlags::REMOTE_WRITE | AccessFlags::MW_BIND,
)?;
let offset_mr = pd.reg_mr_owned(
vec![0u8; 64],
AccessFlags::LOCAL_WRITE | AccessFlags::REMOTE_READ | AccessFlags::MW_BIND,
)?;
let read_buf = pd.reg_mr_owned(vec![0u8; 4], AccessFlags::LOCAL_WRITE)?;
let doorbell_bufs: Box<[OwnedMemoryRegion]> = (0..max_outstanding)
.map(|_| pd.reg_mr_owned(vec![0u8; 4], AccessFlags::LOCAL_WRITE))
.collect::<crate::Result<Vec<_>>>()?
.into_boxed_slice();
let qp = AsyncQp::new(cmqp, send_cq, recv_cq);
let token_recv_mr = post_read_ring_token_recv(&qp, &pd)?;
let async_cm = listener
.complete_accept(conn_id, &ConnParam::default())
.await?;
let (event_channel, cm_id) = async_cm.into_parts();
let cm_async_fd = AsyncFd::new(event_channel.fd()).map_err(crate::Error::Verbs)?;
let (recv_mw, mw1_rkey) = bind_recv_mw(&qp, &pd, &recv_mr, config.ring_capacity)?;
let (offset_mw, mw2_rkey) = bind_offset_mw(&qp, &pd, &offset_mr, 64)?;
let our_token = ReadRingToken {
version: READ_RING_TOKEN_VERSION,
_reserved: [0; 3],
ring_va: recv_mr.addr(),
ring_capacity: config.ring_capacity as u32,
ring_rkey: mw1_rkey,
offset_va: offset_mr.addr(),
offset_rkey: mw2_rkey,
};
let peer_token =
complete_read_ring_token_exchange(&qp, &pd, &our_token, &token_recv_mr).await?;
drain_send_cq(&qp)?;
for (i, mr) in doorbell_bufs.iter().enumerate() {
let sge = Sge::new(mr.addr(), 4, mr.lkey());
let mut wr = RecvWr::new(i as u64).sg(sge);
qp.post_recv_wr(&mut wr)?;
}
Ok(Self::from_parts(
qp,
cm_async_fd,
cm_id,
event_channel,
pd,
send_mr,
recv_mr,
recv_mw,
offset_mw,
offset_mr,
read_buf,
doorbell_bufs,
peer_token.ring_va,
peer_token.ring_rkey,
peer_token.ring_capacity as usize,
peer_token.offset_va,
peer_token.offset_rkey,
max_outstanding,
config,
))
}
#[allow(clippy::too_many_arguments)]
fn from_parts(
qp: AsyncQp,
cm_async_fd: AsyncFd<RawFd>,
cm_id: CmId,
event_channel: EventChannel,
pd: Arc<ProtectionDomain>,
send_mr: OwnedMemoryRegion,
recv_mr: OwnedMemoryRegion,
recv_mw: MemoryWindow,
offset_mw: MemoryWindow,
offset_mr: OwnedMemoryRegion,
read_buf: OwnedMemoryRegion,
doorbell_bufs: Box<[OwnedMemoryRegion]>,
remote_addr: u64,
remote_rkey: u32,
remote_capacity: usize,
remote_offset_va: u64,
remote_offset_rkey: u32,
max_outstanding: usize,
config: ReadRingConfig,
) -> Self {
let ring_capacity = config.ring_capacity;
Self {
send_cq_state: CqPollState::default(),
recv_cq_state: CqPollState::default(),
disconnected: false,
peer_disconnected: false,
virtual_idx_map: vec![None; max_outstanding].into_boxed_slice(),
next_virt_idx: 0,
max_outstanding,
recv_arrival_seq: 0,
recv_stash: VecDeque::new(),
cached_remote_head: 0,
read_in_flight: false,
remote_offset_va,
remote_offset_rkey,
slot_lengths: vec![0usize; max_outstanding].into_boxed_slice(),
head_slot_idx: 0,
send_in_flight: 0,
_offset_mw: Some(offset_mw),
_recv_mw: Some(recv_mw),
qp,
send_ring: RingBuffer::new(send_mr, ring_capacity),
recv_ring: RingBuffer::new(recv_mr, ring_capacity),
read_buf,
_offset_mr: offset_mr,
remote_addr,
remote_rkey,
remote_capacity,
remote_write_tail: 0,
doorbell_bufs,
doorbell_repost_idx: 0,
recv_tracker: CompletionTracker::new(max_outstanding),
config,
_pd: pd,
cm_async_fd,
cm_id,
event_channel,
}
}
fn remote_free_space(&self) -> usize {
let used = if self.remote_write_tail >= self.cached_remote_head {
self.remote_write_tail - self.cached_remote_head
} else {
self.remote_capacity - self.cached_remote_head + self.remote_write_tail
};
self.remote_capacity.saturating_sub(used + 1)
}
fn check_cm_event(&mut self) -> bool {
match self.event_channel.try_get_event() {
Ok(ev) => {
let etype = ev.event_type();
ev.ack();
if etype == crate::cm::CmEventType::Disconnected {
self.peer_disconnected = true;
}
self.peer_disconnected
}
Err(crate::Error::WouldBlock) => false,
Err(_) => {
self.peer_disconnected = true;
true
}
}
}
fn repost_doorbell(&mut self) -> crate::Result<()> {
let idx = self.doorbell_repost_idx;
self.doorbell_repost_idx = (idx + 1) % self.doorbell_bufs.len();
let mr = &self.doorbell_bufs[idx];
let sge = Sge::new(mr.addr(), 4, mr.lkey());
let mut wr = RecvWr::new(idx as u64).sg(sge);
self.qp.post_recv_wr(&mut wr)
}
fn post_offset_read(&mut self) -> crate::Result<()> {
let sge = Sge::new(self.read_buf.addr(), 4, self.read_buf.lkey());
let mut wr = SendWr::new(WR_ID_READ_SENTINEL, WrOpcode::RdmaRead)
.flags(SendFlags::SIGNALED)
.sg(sge)
.rdma(self.remote_offset_va, self.remote_offset_rkey);
self.qp.post_send_wr(&mut wr)?;
self.read_in_flight = true;
Ok(())
}
fn update_cached_remote_head(&mut self) {
let head_val = u32::from_ne_bytes(self.read_buf.as_slice()[..4].try_into().unwrap());
self.cached_remote_head = head_val as usize;
}
fn drain_send_cq_for_read(&mut self) {
let mut wc_buf = [WorkCompletion::default(); 8];
if self.qp.send_cq().cq().req_notify(false).is_err() {
self.peer_disconnected = true;
return;
}
let n = match self.qp.send_cq().cq().poll(&mut wc_buf) {
Ok(n) => n,
Err(_) => {
self.peer_disconnected = true;
return;
}
};
for wc in &wc_buf[..n] {
if !wc.is_success() {
self.peer_disconnected = true;
return;
}
let wr_id = wc.wr_id();
if wr_id == WR_ID_READ_SENTINEL {
self.update_cached_remote_head();
self.read_in_flight = false;
} else if wr_id == WR_ID_PADDING_SENTINEL {
self.send_in_flight = self.send_in_flight.saturating_sub(1);
} else {
let data_len = wr_id as usize;
if data_len > 0 && data_len <= self.send_ring.capacity {
self.send_ring.release(data_len);
}
self.send_in_flight = self.send_in_flight.saturating_sub(1);
}
}
}
fn advance_recv_head(&mut self, contiguous: usize) {
for _ in 0..contiguous {
let slot_len = self.slot_lengths[self.head_slot_idx];
self.recv_ring.release(slot_len);
self.head_slot_idx = (self.head_slot_idx + 1) % self.max_outstanding;
}
self.offset_buffer()
.store(self.recv_ring.head as u32, Ordering::Release);
}
}
impl Transport for ReadRingTransport {
fn send_copy(&mut self, data: &[u8]) -> crate::Result<usize> {
if self.peer_disconnected {
return Err(crate::Error::WorkCompletion {
status: rdma_io_sys::ibverbs::IBV_WC_WR_FLUSH_ERR,
vendor_err: 0,
});
}
if data.is_empty() {
return Ok(0);
}
if self.read_in_flight {
self.drain_send_cq_for_read();
}
let free = self.remote_free_space();
let data_len = data
.len()
.min(self.config.max_message_size)
.min(self.remote_capacity)
.min(0xFFFF);
if free < data_len + self.config.min_free_threshold {
if !self.read_in_flight {
self.post_offset_read()?;
}
return Ok(0);
}
let (local_offset, padding) = match self.send_ring.reserve(data_len) {
Some(result) => result,
None => return Ok(0),
};
if padding > 0 {
if free < padding + data_len + self.config.min_free_threshold {
self.send_ring.tail = if local_offset == 0 {
self.send_ring.capacity - padding
} else {
local_offset
};
return Ok(0);
}
let pad_remote_offset = self.remote_write_tail;
let imm = (pad_remote_offset as u32) << 16; let mut pad_wr = SendWr::new(WR_ID_PADDING_SENTINEL, WrOpcode::RdmaWriteWithImm(imm))
.flags(SendFlags::SIGNALED)
.rdma(
self.remote_addr + pad_remote_offset as u64,
self.remote_rkey,
);
self.qp.post_send_wr(&mut pad_wr)?;
self.send_in_flight += 1;
self.remote_write_tail = 0;
}
self.send_ring.mr.as_mut_slice()[local_offset..local_offset + data_len]
.copy_from_slice(&data[..data_len]);
let remote_offset = self.remote_write_tail;
let imm = ((remote_offset as u32) << 16) | (data_len as u32);
let sge = Sge::new(
self.send_ring.mr.addr() + local_offset as u64,
data_len as u32,
self.send_ring.mr.lkey(),
);
let release_len = padding + data_len;
let mut wr = SendWr::new(release_len as u64, WrOpcode::RdmaWriteWithImm(imm))
.flags(SendFlags::SIGNALED)
.sg(sge)
.rdma(self.remote_addr + remote_offset as u64, self.remote_rkey);
self.qp.post_send_wr(&mut wr)?;
self.remote_write_tail = (remote_offset + data_len) % self.remote_capacity;
self.send_in_flight += 1;
if !self.read_in_flight {
let remaining = self.remote_free_space();
if remaining < data_len * 2 + self.config.min_free_threshold
&& self.post_offset_read().is_err()
{
self.peer_disconnected = true;
}
}
Ok(data_len)
}
fn poll_send_completion(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
loop {
let mut wc_buf = [WorkCompletion::default(); 8];
let n = match self
.qp
.poll_send_cq(cx, &mut self.send_cq_state, &mut wc_buf)
{
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Ready(Ok(n)) => n,
};
let mut got_write = false;
for wc in &wc_buf[..n] {
if !wc.is_success() {
self.peer_disconnected = true;
return Poll::Ready(Err(crate::Error::WorkCompletion {
status: wc.status_raw(),
vendor_err: wc.vendor_err(),
}));
}
let wr_id = wc.wr_id();
if wr_id == WR_ID_READ_SENTINEL {
self.update_cached_remote_head();
self.read_in_flight = false;
} else if wr_id == WR_ID_PADDING_SENTINEL {
self.send_in_flight = self.send_in_flight.saturating_sub(1);
got_write = true;
} else {
let data_len = wr_id as usize;
if data_len > 0 && data_len <= self.send_ring.capacity {
self.send_ring.release(data_len);
}
self.send_in_flight = self.send_in_flight.saturating_sub(1);
got_write = true;
}
}
if got_write {
return Poll::Ready(Ok(()));
}
}
}
fn poll_recv(
&mut self,
cx: &mut Context<'_>,
out: &mut [RecvCompletion],
) -> Poll<crate::Result<usize>> {
let mut filled = 0;
while filled < out.len() {
if let Some(rc) = self.recv_stash.pop_front() {
out[filled] = rc;
filled += 1;
} else {
break;
}
}
if filled >= out.len() {
return Poll::Ready(Ok(filled));
}
loop {
let mut wc_buf = [WorkCompletion::default(); 8];
let n = match self
.qp
.poll_recv_cq(cx, &mut self.recv_cq_state, &mut wc_buf)
{
Poll::Pending => {
if filled > 0 {
return Poll::Ready(Ok(filled));
}
return Poll::Pending;
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Ready(Ok(n)) => n,
};
let mut got_data = false;
for wc in &wc_buf[..n] {
if !wc.is_success() {
self.peer_disconnected = true;
return Poll::Ready(Err(crate::Error::WorkCompletion {
status: wc.status_raw(),
vendor_err: wc.vendor_err(),
}));
}
match wc.opcode() {
WcOpcode::RecvRdmaWithImm => {
let imm = wc.imm_data();
let offset = (imm >> 16) as usize;
let length = (imm & 0xFFFF) as usize;
if offset >= self.recv_ring.capacity
|| (length > 0 && offset + length > self.recv_ring.capacity)
{
self.peer_disconnected = true;
return Poll::Ready(Err(crate::Error::InvalidArg(
"recv ring offset/length out of bounds".into(),
)));
}
if length == 0 {
let pad_len = self.recv_ring.capacity - offset;
if self.repost_doorbell().is_err() {
self.peer_disconnected = true;
return Poll::Ready(Err(crate::Error::InvalidArg(
"repost_doorbell failed".into(),
)));
}
let seq = self.recv_arrival_seq;
self.recv_arrival_seq += 1;
let slot = seq % self.max_outstanding;
self.slot_lengths[slot] = pad_len;
let contiguous = self.recv_tracker.release(slot);
if contiguous > 0 {
self.advance_recv_head(contiguous);
}
continue;
}
let mut virt_idx = self.next_virt_idx;
let mut found = false;
for _ in 0..self.max_outstanding {
if self.virtual_idx_map[virt_idx].is_none() {
found = true;
break;
}
virt_idx = (virt_idx + 1) % self.max_outstanding;
}
if !found {
self.peer_disconnected = true;
return Poll::Ready(Err(crate::Error::InvalidArg(
"recv virtual index map exhausted".into(),
)));
}
let seq = self.recv_arrival_seq;
self.recv_arrival_seq += 1;
self.virtual_idx_map[virt_idx] = Some((offset, length, seq));
self.slot_lengths[seq % self.max_outstanding] = length;
self.next_virt_idx = (virt_idx + 1) % self.max_outstanding;
let rc = RecvCompletion {
buf_idx: virt_idx,
byte_len: length,
};
if filled < out.len() {
out[filled] = rc;
filled += 1;
} else {
self.recv_stash.push_back(rc);
}
got_data = true;
}
_ => {
}
}
}
if filled > 0 || got_data {
return Poll::Ready(Ok(filled));
}
}
}
fn recv_buf(&self, buf_idx: usize) -> &[u8] {
let (offset, length, _seq) =
self.virtual_idx_map[buf_idx].expect("recv_buf called with invalid buf_idx");
&self.recv_ring.mr.as_slice()[offset..offset + length]
}
fn repost_recv(&mut self, buf_idx: usize) -> crate::Result<()> {
let (_offset, _length, arrival_seq) = self.virtual_idx_map[buf_idx]
.take()
.expect("repost_recv called with invalid buf_idx");
self.repost_doorbell()?;
let slot = arrival_seq % self.max_outstanding;
let contiguous = self.recv_tracker.release(slot);
if contiguous > 0 {
self.advance_recv_head(contiguous);
}
Ok(())
}
fn poll_disconnect(&mut self, cx: &mut Context<'_>) -> bool {
if self.peer_disconnected {
return true;
}
loop {
match self.cm_async_fd.poll_read_ready(cx) {
Poll::Ready(Ok(mut guard)) => {
guard.clear_ready();
if self.check_cm_event() {
return true;
}
}
Poll::Pending => {
return false;
}
Poll::Ready(Err(_)) => {
self.peer_disconnected = true;
return true;
}
}
}
}
fn disconnect(&mut self) -> crate::Result<()> {
if !self.disconnected {
self.cm_id.disconnect()?;
self.disconnected = true;
}
Ok(())
}
fn local_addr(&self) -> Option<SocketAddr> {
self.cm_id.local_addr()
}
fn peer_addr(&self) -> Option<SocketAddr> {
self.cm_id.peer_addr()
}
}
impl Drop for ReadRingTransport {
    /// Best-effort teardown. It issues the CM disconnect if `disconnect` was
    /// never called, then empties the send and receive CQs. Errors are
    /// ignored.
    fn drop(&mut self) {
        if !self.disconnected {
            let _ = self.cm_id.disconnect();
        }
        let mut scratch = [WorkCompletion::default(); 16];
        // Keep polling while each pass still yields completions; stop on an
        // empty poll or an error.
        while matches!(self.qp.send_cq().cq().poll(&mut scratch), Ok(got) if got > 0) {}
        while matches!(self.qp.recv_cq().cq().poll(&mut scratch), Ok(got) if got > 0) {}
    }
}
impl TransportBuilder for ReadRingConfig {
    type Transport = ReadRingTransport;
    /// Connects using a copy of this configuration.
    async fn connect(&self, addr: &SocketAddr) -> crate::Result<ReadRingTransport> {
        let config = self.clone();
        ReadRingTransport::connect(addr, config).await
    }
    /// Accepts the next request on `listener` using a copy of this
    /// configuration.
    async fn accept(&self, listener: &AsyncCmListener) -> crate::Result<ReadRingTransport> {
        let config = self.clone();
        ReadRingTransport::accept(listener, config).await
    }
}