use crate::{
completion_queue::{WCError, WorkCompletion, WorkRequestId},
context::{check_dev_cap, Context},
cq_event_listener::{CQEventListener, LmrInners},
error_utilities::{log_last_os_err, log_ret_last_os_err, log_ret_last_os_err_with_note},
gid::Gid,
impl_from_buidler_error_for_another, impl_into_io_error,
memory_region::{
local::{LocalMrReadAccess, LocalMrWriteAccess, RwLocalMrInner},
remote::{RemoteMrReadAccess, RemoteMrWriteAccess},
},
protection_domain::ProtectionDomain,
work_request::{RecvWr, SendWr},
DEFAULT_ACCESS,
};
use clippy_utilities::Cast;
use derive_builder::Builder;
use futures::{ready, Future, FutureExt};
use getset::{Getters, MutGetters, Setters};
use parking_lot::RwLock;
use rdma_sys::{
ibv_access_flags, ibv_ah_attr, ibv_cq, ibv_destroy_qp, ibv_global_route, ibv_modify_qp,
ibv_mtu, ibv_post_recv, ibv_post_send, ibv_qp, ibv_qp_attr, ibv_qp_attr_mask, ibv_qp_init_attr,
ibv_qp_state, ibv_query_qp, ibv_recv_wr, ibv_send_wr, ibv_srq,
};
use serde::{Deserialize, Serialize};
use std::{
fmt::Debug,
io,
pin::Pin,
ptr::{self, NonNull},
sync::Arc,
task::Poll,
time::Duration,
};
use tokio::{
sync::mpsc,
time::{sleep, Sleep},
};
use tracing::debug;
// Default queue-pair limits and attribute values. These are compile-time
// constants, so `const` (not `static`) is the idiomatic declaration; all
// are referenced by the `derive_builder` default expressions below.

/// Default max number of outstanding send work requests.
pub(crate) const MAX_SEND_WR: u32 = 10;
/// Default max number of outstanding receive work requests.
pub(crate) const MAX_RECV_WR: u32 = 10;
/// Default max scatter/gather elements per send work request.
pub(crate) const MAX_SEND_SGE: u32 = 5;
/// Default max scatter/gather elements per receive work request.
pub(crate) const MAX_RECV_SGE: u32 = 5;
/// Default max inline data size (0 = inline sends disabled).
pub(crate) const MAX_INLINE_DATA: u32 = 0;
/// Default `sq_sig_all` (0 = only explicitly-signaled sends generate CQEs).
pub(crate) const SQ_SIG_ALL: i32 = 0_i32;
/// Default HCA port number.
pub(crate) const DEFAULT_PORT_NUM: u8 = 1;
/// Default GID table index.
pub(crate) const DEFAULT_GID_INDEX: usize = 1;
/// Default partition key index.
pub(crate) const DEFAULT_PKEY_INDEX: u16 = 0;
/// Default GRH flow label.
pub(crate) const DEFAULT_FLOW_LABEL: u32 = 0;
/// Default GRH hop limit (max).
pub(crate) const DEFAULT_HOP_LIMIT: u8 = 0xff;
/// Default GRH traffic class.
pub(crate) const DEFAULT_TRAFFIC_CLASS: u8 = 0;
/// Default service level.
pub(crate) const DEFAULT_SERVICE_LEVEL: u8 = 0;
/// Default source path bits.
pub(crate) const DEFAULT_SRC_PATH_BITS: u8 = 0;
/// Default static rate (0 = no rate limit).
pub(crate) const DEFAULT_STATIC_RATE: u8 = 0;
/// Default `is_global` flag (1 = use the GRH / global routing).
pub(crate) const DEFAULT_IS_GLOBAL: u8 = 1;
/// Default receive-queue packet sequence number.
pub(crate) const DEFAULT_RQ_PSN: u32 = 0;
/// Default max outstanding incoming RDMA read/atomic operations.
pub(crate) const DEFAULT_MAX_DEST_RD_ATOMIC: u8 = 1;
/// Default minimum RNR NAK timer value.
pub(crate) const DEFAULT_MIN_RNR_TIMER: u8 = 0x12;
/// Default path MTU.
pub(crate) const DEFAULT_MTU: MTU = MTU::MTU512;
/// Default local ACK timeout exponent.
pub(crate) const DEFAULT_TIMEOUT: u8 = 0x12;
/// Default transport retry count.
pub(crate) const DEFAULT_RETRY_CNT: u8 = 6;
/// Default RNR retry count.
pub(crate) const DEFAULT_RNR_RETRY: u8 = 6;
/// Default send-queue packet sequence number.
pub(crate) const DEFAULT_SQ_PSN: u32 = 0;
/// Default max outstanding outgoing RDMA read/atomic operations.
pub(crate) const DEFAULT_MAX_RD_ATOMIC: u8 = 1;
/// Capacity limits of a queue pair's send and receive queues.
///
/// Mirrors the `cap` member of `ibv_qp_init_attr`; construct via
/// `QueuePairCapBuilder`, which supplies the `MAX_*` defaults above.
#[derive(Debug, Clone, Copy, Getters, Setters, Builder)]
#[builder(derive(Debug, Copy))]
#[getset(set, get = "pub")]
pub(crate) struct QueuePairCap {
    /// Max outstanding send work requests.
    #[builder(default = "MAX_SEND_WR")]
    max_send_wr: u32,
    /// Max outstanding receive work requests.
    #[builder(default = "MAX_RECV_WR")]
    max_recv_wr: u32,
    /// Max scatter/gather entries per send work request.
    #[builder(default = "MAX_SEND_SGE")]
    max_send_sge: u32,
    /// Max scatter/gather entries per receive work request.
    #[builder(default = "MAX_RECV_SGE")]
    max_recv_sge: u32,
    /// Max inline data carried in a send work request.
    #[builder(default = "MAX_INLINE_DATA")]
    max_inline_data: u32,
}
impl QueuePairCap {
    /// Validate the requested capacities against the device's reported
    /// limits (`max_sge` for SGE counts, `max_qp_wr` for WR counts).
    ///
    /// Returns the error produced by `check_dev_cap` for the first field
    /// that exceeds the corresponding device capability.
    pub(crate) fn check_dev_qp_cap(&self, ctx: &Context) -> io::Result<()> {
        let dev_attr = ctx.dev_attr();
        check_dev_cap(&self.max_recv_sge, &dev_attr.max_sge.cast(), "max_recv_sge")?;
        check_dev_cap(&self.max_send_sge, &dev_attr.max_sge.cast(), "max_send_sge")?;
        check_dev_cap(&self.max_recv_wr, &dev_attr.max_qp_wr.cast(), "max_recv_wr")?;
        check_dev_cap(&self.max_send_wr, &dev_attr.max_qp_wr.cast(), "max_send_wr")?;
        Ok(())
    }
}
// Let cap-builder errors propagate through the init-attr builder with `?`,
// and convert into `io::Error` at API boundaries.
impl_from_buidler_error_for_another!(QueuePairCapBuilderError, QueuePairInitAttrBuilderError);
impl_into_io_error!(QueuePairCapBuilderError);
/// Typed counterpart of `ibv_qp_init_attr`, plus the access flags and pkey
/// index used later when transitioning the QP to INIT.
#[derive(Debug, Clone, Copy, MutGetters, Getters, Setters, Builder)]
#[builder(derive(Debug, Copy))]
#[getset(set, get = "pub")]
pub(crate) struct QueuePairInitAttr {
    /// Opaque user context pointer stored in the QP (null when `None`).
    #[builder(default = "None")]
    qp_context: Option<*mut libc::c_void>,
    /// Completion queue for send work completions.
    send_cq: *mut ibv_cq,
    /// Completion queue for receive work completions.
    recv_cq: *mut ibv_cq,
    /// Optional shared receive queue (null when `None`).
    #[builder(default = "None")]
    srq: Option<*mut ibv_srq>,
    /// Queue capacities; built from a nested `QueuePairCapBuilder` so
    /// callers can tweak individual limits via `qp_cap()`.
    #[builder(
        field(type = "QueuePairCapBuilder", build = "self.qp_cap.build()?"),
        setter(custom)
    )]
    #[getset(get_mut = "pub")]
    qp_cap: QueuePairCap,
    /// QP transport type; defaults to reliable connection (RC).
    #[builder(default = "rdma_sys::ibv_qp_type::IBV_QPT_RC")]
    qp_type: u32,
    /// Whether every send WR generates a completion (see `SQ_SIG_ALL`).
    #[builder(default = "SQ_SIG_ALL")]
    sq_sig_all: i32,
    /// Access flags applied when moving the QP to INIT.
    #[builder(default = "*DEFAULT_ACCESS")]
    access: ibv_access_flags,
    /// Partition key index applied when moving the QP to INIT.
    #[builder(default = "DEFAULT_PKEY_INDEX")]
    pkey_index: u16,
}
impl_into_io_error!(QueuePairInitAttrBuilderError);
impl QueuePairInitAttrBuilder {
    /// Access the nested capacity builder (custom setter for `qp_cap`).
    pub(crate) fn qp_cap(&mut self) -> &mut QueuePairCapBuilder {
        &mut self.qp_cap
    }
}
impl From<QueuePairInitAttr> for ibv_qp_init_attr {
#[inline]
fn from(s: QueuePairInitAttr) -> Self {
let mut qp_init_attr = unsafe { std::mem::zeroed::<ibv_qp_init_attr>() };
qp_init_attr.qp_context = s
.qp_context
.unwrap_or(ptr::null_mut::<libc::c_void>().cast());
qp_init_attr.send_cq = s.send_cq;
qp_init_attr.recv_cq = s.recv_cq;
qp_init_attr.srq = s.srq.unwrap_or(ptr::null_mut::<ibv_srq>().cast());
qp_init_attr.cap.max_send_wr = s.qp_cap.max_send_wr;
qp_init_attr.cap.max_recv_wr = s.qp_cap.max_recv_wr;
qp_init_attr.cap.max_send_sge = s.qp_cap.max_send_sge;
qp_init_attr.cap.max_recv_sge = s.qp_cap.max_recv_sge;
qp_init_attr.cap.max_inline_data = s.qp_cap.max_inline_data;
qp_init_attr.qp_type = s.qp_type;
qp_init_attr.sq_sig_all = s.sq_sig_all;
qp_init_attr
}
}
/// The information a remote peer needs to address this queue pair:
/// QP number, local identifier and GID. Serializable so it can be
/// exchanged out of band during connection setup.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Serialize, Deserialize, Getters, Setters, Builder)]
#[getset(set, get = "pub")]
pub struct QueuePairEndpoint {
    /// Queue pair number.
    qp_num: u32,
    /// Local identifier of the port.
    lid: u16,
    /// Global identifier of the port.
    gid: Gid,
}
impl_into_io_error!(QueuePairEndpointBuilderError);
/// Typed counterpart of `ibv_ah_attr`: how to reach the remote peer.
#[derive(Debug, Clone, Copy, Getters, Setters, Builder)]
#[builder(derive(Debug, Copy))]
#[getset(set, get = "pub")]
pub(crate) struct AddressHandler {
    /// Global route header; built from a nested builder (see `grh()`).
    #[builder(
        field(type = "GlobalRouteHeaderBuilder", build = "self.grh.build()?"),
        setter(custom)
    )]
    grh: GlobalRouteHeader,
    /// LID of the remote port.
    dest_lid: u16,
    /// Service level (QoS).
    #[builder(default = "DEFAULT_SERVICE_LEVEL")]
    service_level: u8,
    /// Source path bits used with LMC routing.
    #[builder(default = "DEFAULT_SRC_PATH_BITS")]
    src_path_bits: u8,
    /// Static rate limit (0 = unlimited).
    #[builder(default = "DEFAULT_STATIC_RATE")]
    static_rate: u8,
    /// Non-zero when the GRH/global routing is used.
    #[builder(default = "DEFAULT_IS_GLOBAL")]
    is_global: u8,
    /// Local physical port number.
    #[builder(default = "DEFAULT_PORT_NUM")]
    port_num: u8,
}
impl AddressHandlerBuilder {
    /// Access the nested GRH builder (custom setter for `grh`).
    pub(crate) fn grh(&mut self) -> &mut GlobalRouteHeaderBuilder {
        &mut self.grh
    }
    /// The port number set so far, if any (builder field is optional).
    pub(crate) fn get_port_num(&self) -> Option<u8> {
        self.port_num
    }
}
impl_into_io_error!(SQAttrBuilderError);
impl From<AddressHandler> for ibv_ah_attr {
    /// Lower the typed address handler into the raw `ibv_ah_attr` form.
    #[inline]
    fn from(handler: AddressHandler) -> Self {
        // Zeroed base so unlisted fields stay 0.
        let mut raw = unsafe { std::mem::zeroed::<ibv_ah_attr>() };
        // Global routing information.
        raw.grh = handler.grh.into();
        raw.is_global = handler.is_global;
        // Subnet-local (LID) routing.
        raw.dlid = handler.dest_lid;
        raw.src_path_bits = handler.src_path_bits;
        raw.port_num = handler.port_num;
        // Quality-of-service knobs.
        raw.sl = handler.service_level;
        raw.static_rate = handler.static_rate;
        raw
    }
}
// Let address-handler builder errors propagate through the RQ-attr builder.
impl_from_buidler_error_for_another!(AddressHandlerBuilderError, RQAttrBuilderError);
/// Typed counterpart of `ibv_global_route`: GRH fields for global routing.
#[derive(Debug, Clone, Copy, Getters, Setters, Builder)]
#[builder(derive(Debug, Copy))]
#[getset(set, get = "pub")]
pub(crate) struct GlobalRouteHeader {
    /// Destination GID.
    dgid: Gid,
    /// Flow label.
    #[builder(default = "DEFAULT_FLOW_LABEL")]
    flow_label: u32,
    /// Index of the source GID in the local GID table.
    #[builder(default = "DEFAULT_GID_INDEX.cast()")]
    sgid_index: u8,
    /// Hop limit (TTL equivalent).
    #[builder(default = "DEFAULT_HOP_LIMIT")]
    hop_limit: u8,
    /// Traffic class.
    #[builder(default = "DEFAULT_TRAFFIC_CLASS")]
    traffic_class: u8,
}
impl From<GlobalRouteHeader> for ibv_global_route {
    /// Lower the typed GRH description into the raw `ibv_global_route`.
    #[inline]
    fn from(header: GlobalRouteHeader) -> Self {
        // Zeroed base; every field we care about is assigned below.
        let mut raw = unsafe { std::mem::zeroed::<ibv_global_route>() };
        raw.dgid = header.dgid.into();
        raw.sgid_index = header.sgid_index;
        raw.flow_label = header.flow_label;
        raw.traffic_class = header.traffic_class;
        raw.hop_limit = header.hop_limit;
        raw
    }
}
// Let GRH builder errors propagate through the address-handler builder.
impl_from_buidler_error_for_another!(GlobalRouteHeaderBuilderError, AddressHandlerBuilderError);
/// Path MTU sizes supported by the transport.
#[derive(Debug, Clone, Copy)]
pub enum MTU {
    /// 256 bytes.
    MTU256,
    /// 512 bytes.
    MTU512,
    /// 1024 bytes.
    MTU1024,
    /// 2048 bytes.
    MTU2048,
    /// 4096 bytes.
    MTU4096,
}
impl From<MTU> for u32 {
    /// Convert to the corresponding `ibv_mtu` constant.
    #[inline]
    fn from(mtu: MTU) -> Self {
        match mtu {
            MTU::MTU256 => ibv_mtu::IBV_MTU_256,
            MTU::MTU512 => ibv_mtu::IBV_MTU_512,
            MTU::MTU1024 => ibv_mtu::IBV_MTU_1024,
            MTU::MTU2048 => ibv_mtu::IBV_MTU_2048,
            MTU::MTU4096 => ibv_mtu::IBV_MTU_4096,
        }
    }
}
/// Attributes required to move a QP to Ready-To-Receive (RTR):
/// path information plus receive-side sequencing and RNR parameters.
#[derive(Debug, Clone, Copy, Getters, Setters, Builder)]
#[builder(derive(Debug, Copy))]
#[getset(set, get = "pub")]
pub(crate) struct RQAttr {
    /// Path MTU to negotiate.
    #[builder(default = "DEFAULT_MTU")]
    mtu: MTU,
    /// Queue pair number of the remote peer.
    dest_qp_number: u32,
    /// How to address the remote peer; built via nested builder.
    #[builder(
        field(
            type = "AddressHandlerBuilder",
            build = "self.address_handler.build()?"
        ),
        setter(custom)
    )]
    address_handler: AddressHandler,
    /// Starting receive-queue packet sequence number.
    #[builder(default = "DEFAULT_RQ_PSN")]
    rq_psn: u32,
    /// Max outstanding incoming RDMA read/atomic operations.
    #[builder(default = "DEFAULT_MAX_DEST_RD_ATOMIC")]
    max_dest_rd_atomic: u8,
    /// Minimum RNR NAK timer value.
    #[builder(default = "DEFAULT_MIN_RNR_TIMER")]
    min_rnr_timer: u8,
}
impl_into_io_error!(RQAttrBuilderError);
impl RQAttrBuilder {
    /// Access the nested address-handler builder (custom setter).
    pub(crate) fn address_handler(&mut self) -> &mut AddressHandlerBuilder {
        &mut self.address_handler
    }
    /// Point this builder at `remote`: copy its QPN, LID and GID into the
    /// destination fields of the address handler.
    pub(crate) fn reset_remote_info(&mut self, remote: &QueuePairEndpoint) {
        let _ = self
            .dest_qp_number(*remote.qp_num())
            .address_handler()
            .dest_lid(*remote.lid())
            .grh()
            .dgid(*remote.gid());
    }
    /// Port number set so far, or the default if unset.
    pub(crate) fn get_port_num(&self) -> u8 {
        self.address_handler
            .get_port_num()
            .unwrap_or(DEFAULT_PORT_NUM)
    }
    /// Source GID index set so far, or the default if unset.
    pub(crate) fn get_sgid_index(&self) -> u8 {
        self.address_handler
            .grh
            .sgid_index
            .unwrap_or_else(|| DEFAULT_GID_INDEX.cast())
    }
}
/// Attributes required to move a QP to Ready-To-Send (RTS):
/// timeout/retry policy and send-side sequencing.
#[derive(Debug, Clone, Copy, Getters, Setters, Builder)]
#[builder(derive(Debug, Copy))]
#[getset(set, get = "pub")]
pub(crate) struct SQAttr {
    /// Local ACK timeout exponent.
    #[builder(default = "DEFAULT_TIMEOUT")]
    timeout: u8,
    /// Transport retry count.
    #[builder(default = "DEFAULT_RETRY_CNT")]
    retry_cnt: u8,
    /// RNR retry count.
    #[builder(default = "DEFAULT_RNR_RETRY")]
    rnr_retry: u8,
    /// Starting send-queue packet sequence number.
    #[builder(default = "DEFAULT_SQ_PSN")]
    sq_psn: u32,
    /// Max outstanding outgoing RDMA read/atomic operations.
    #[builder(default = "DEFAULT_MAX_RD_ATOMIC")]
    max_rd_atomic: u8,
}
/// Typed view of the `ibv_qp_state` values a QP can be in.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueuePairState {
    /// Reset state.
    Reset,
    /// Initialized.
    Init,
    /// Ready to receive.
    ReadyToRecv,
    /// Ready to send.
    ReadyToSend,
    /// Send queue drain.
    SQDrain,
    /// Send queue error.
    SQErr,
    /// Error state.
    Err,
    /// State unknown / not reported.
    Unknown,
}
impl From<u32> for QueuePairState {
    /// Map a raw `ibv_qp_state` value to the typed state enum.
    ///
    /// `IBV_QPS_SQD` is now matched explicitly; any value that does not
    /// correspond to a known state constant (including `IBV_QPS_UNKNOWN`
    /// and out-of-range values) maps to `Unknown`. Previously the final
    /// catch-all branch misclassified every unrecognized value as
    /// `SQDrain`.
    #[inline]
    fn from(num: u32) -> Self {
        if num == ibv_qp_state::IBV_QPS_RTS {
            Self::ReadyToSend
        } else if num == ibv_qp_state::IBV_QPS_RTR {
            Self::ReadyToRecv
        } else if num == ibv_qp_state::IBV_QPS_INIT {
            Self::Init
        } else if num == ibv_qp_state::IBV_QPS_ERR {
            Self::Err
        } else if num == ibv_qp_state::IBV_QPS_RESET {
            Self::Reset
        } else if num == ibv_qp_state::IBV_QPS_SQE {
            Self::SQErr
        } else if num == ibv_qp_state::IBV_QPS_SQD {
            Self::SQDrain
        } else {
            // IBV_QPS_UNKNOWN or a value outside the defined range.
            Self::Unknown
        }
    }
}
/// An RDMA queue pair: owns an `ibv_qp` together with the resources it
/// depends on.
#[derive(Debug, Getters, Setters, Builder)]
#[getset(set, get = "pub")]
pub(crate) struct QueuePair {
    /// Protection domain the QP belongs to (kept alive by this `Arc`).
    pd: Arc<ProtectionDomain>,
    /// Listener used to register work requests and await completions.
    cq_event_listener: Arc<CQEventListener>,
    /// Non-null pointer to the underlying `ibv_qp`.
    inner_qp: NonNull<ibv_qp>,
    /// The last state this wrapper transitioned the QP into.
    cur_state: Arc<RwLock<QueuePairState>>,
}
impl_into_io_error!(QueuePairBuilderError);
impl QueuePair {
    /// Raw pointer to the wrapped `ibv_qp`.
    pub(crate) fn as_ptr(&self) -> *mut ibv_qp {
        self.inner_qp.as_ptr()
    }

    /// Query the attributes selected by `mask` through `ibv_query_qp`.
    ///
    /// Returns the raw attribute structs filled in by the driver, or the
    /// last OS error if the verbs call returns non-zero.
    pub(crate) fn query_attrs(
        &self,
        mask: ibv_qp_attr_mask,
    ) -> io::Result<(ibv_qp_attr, ibv_qp_init_attr)> {
        // Zeroed out-params for the FFI call to fill in.
        let mut qp_attr = unsafe { std::mem::zeroed::<ibv_qp_attr>() };
        let mut qp_init_attr = unsafe { std::mem::zeroed::<ibv_qp_init_attr>() };
        let errno = unsafe {
            ibv_query_qp(
                self.as_ptr(),
                &mut qp_attr,
                mask.0.cast(),
                &mut qp_init_attr,
            )
        };
        if errno != 0_i32 {
            return Err(log_ret_last_os_err_with_note("ibv_query_qp failed"));
        }
        Ok((qp_attr, qp_init_attr))
    }

    /// Query the QP's current state from the driver (not `cur_state`).
    pub(crate) fn query_state(&self) -> io::Result<QueuePairState> {
        let mask = ibv_qp_attr_mask::IBV_QP_STATE;
        let res = self.query_attrs(mask)?;
        Ok(res.0.qp_state.into())
    }

    /// The endpoint info (QPN, LID, GID) to hand to the remote peer.
    pub(crate) fn endpoint(&self) -> QueuePairEndpoint {
        QueuePairEndpoint {
            // `inner_qp` is NonNull; read the QP number off the live struct.
            qp_num: unsafe { (*self.as_ptr()).qp_num },
            lid: self.pd.ctx.get_lid(),
            gid: *self.pd.ctx.gid(),
        }
    }

    /// Transition the QP RESET -> INIT with the given access flags, port
    /// and pkey index, then record the new state in `cur_state`.
    pub(crate) fn modify_to_init(
        &mut self,
        flag: ibv_access_flags,
        port_num: u8,
        pkey_index: u16,
    ) -> io::Result<()> {
        let mut attr = unsafe { std::mem::zeroed::<ibv_qp_attr>() };
        attr.qp_state = ibv_qp_state::IBV_QPS_INIT;
        attr.pkey_index = pkey_index;
        attr.port_num = port_num;
        attr.qp_access_flags = flag.0;
        // Mask selects exactly the attributes set above.
        let flags = ibv_qp_attr_mask::IBV_QP_PKEY_INDEX
            | ibv_qp_attr_mask::IBV_QP_STATE
            | ibv_qp_attr_mask::IBV_QP_PORT
            | ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS;
        let errno = unsafe { ibv_modify_qp(self.as_ptr(), &mut attr, flags.0.cast()) };
        if errno != 0_i32 {
            return Err(log_ret_last_os_err());
        }
        *self.cur_state.write() = QueuePairState::Init;
        Ok(())
    }

    /// Transition the QP INIT -> RTR using the remote path information in
    /// `rq_attr`, then record the new state.
    pub(crate) fn modify_to_rtr(&self, rq_attr: RQAttr) -> io::Result<()> {
        let mut qp_attr = unsafe { std::mem::zeroed::<ibv_qp_attr>() };
        qp_attr.qp_state = ibv_qp_state::IBV_QPS_RTR;
        qp_attr.path_mtu = rq_attr.mtu.into();
        qp_attr.dest_qp_num = rq_attr.dest_qp_number;
        qp_attr.rq_psn = rq_attr.rq_psn;
        qp_attr.max_dest_rd_atomic = rq_attr.max_dest_rd_atomic;
        qp_attr.min_rnr_timer = rq_attr.min_rnr_timer;
        qp_attr.ah_attr = rq_attr.address_handler.into();
        let flags = ibv_qp_attr_mask::IBV_QP_STATE
            | ibv_qp_attr_mask::IBV_QP_AV
            | ibv_qp_attr_mask::IBV_QP_PATH_MTU
            | ibv_qp_attr_mask::IBV_QP_DEST_QPN
            | ibv_qp_attr_mask::IBV_QP_RQ_PSN
            | ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC
            | ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER;
        let errno = unsafe { ibv_modify_qp(self.as_ptr(), &mut qp_attr, flags.0.cast()) };
        if errno != 0_i32 {
            return Err(log_ret_last_os_err());
        }
        *self.cur_state.write() = QueuePairState::ReadyToRecv;
        Ok(())
    }

    /// Transition the QP RTR -> RTS using the send-side parameters in
    /// `sq_attr`, then record the new state.
    pub(crate) fn modify_to_rts(&self, sq_attr: SQAttr) -> io::Result<()> {
        let mut attr = unsafe { std::mem::zeroed::<ibv_qp_attr>() };
        attr.qp_state = ibv_qp_state::IBV_QPS_RTS;
        attr.timeout = sq_attr.timeout;
        attr.retry_cnt = sq_attr.retry_cnt;
        attr.rnr_retry = sq_attr.rnr_retry;
        attr.sq_psn = sq_attr.sq_psn;
        attr.max_rd_atomic = sq_attr.max_rd_atomic;
        let flags = ibv_qp_attr_mask::IBV_QP_STATE
            | ibv_qp_attr_mask::IBV_QP_TIMEOUT
            | ibv_qp_attr_mask::IBV_QP_RETRY_CNT
            | ibv_qp_attr_mask::IBV_QP_RNR_RETRY
            | ibv_qp_attr_mask::IBV_QP_SQ_PSN
            | ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC;
        let errno = unsafe { ibv_modify_qp(self.as_ptr(), &mut attr, flags.0.cast()) };
        if errno != 0_i32 {
            return Err(log_ret_last_os_err());
        }
        *self.cur_state.write() = QueuePairState::ReadyToSend;
        Ok(())
    }

    /// Post a send work request over the local memory regions `lms`.
    ///
    /// Re-arms CQ notifications before posting so the completion for
    /// `wr_id` will wake the listener. `imm` carries optional immediate
    /// data. Fails with the last OS error if `ibv_post_send` rejects it.
    fn submit_send<LR>(&self, lms: &[&LR], wr_id: WorkRequestId, imm: Option<u32>) -> io::Result<()>
    where
        LR: LocalMrReadAccess,
    {
        let mut bad_wr = std::ptr::null_mut::<ibv_send_wr>();
        let mut send_attr = SendWr::new_send(lms, wr_id, imm);
        self.cq_event_listener.cq.req_notify(false)?;
        for lm in lms {
            debug!(
                "post_send addr {}, len {}, lkey {} wrid: {}",
                lm.addr(),
                lm.length(),
                unsafe { lm.lkey_unchecked() },
                send_attr.as_ref().wr_id,
            );
        }
        let errno = unsafe { ibv_post_send(self.as_ptr(), send_attr.as_mut(), &mut bad_wr) };
        if errno != 0_i32 {
            return Err(log_ret_last_os_err());
        }
        Ok(())
    }

    /// Post a receive work request into the writable regions `lms`.
    ///
    /// Same notification/error pattern as `submit_send`.
    fn submit_receive<LW>(&self, lms: &[&mut LW], wr_id: WorkRequestId) -> io::Result<()>
    where
        LW: LocalMrWriteAccess,
    {
        let mut recv_attr = RecvWr::new_recv(lms, wr_id);
        let mut bad_wr = std::ptr::null_mut::<ibv_recv_wr>();
        self.cq_event_listener.cq.req_notify(false)?;
        for lm in lms {
            debug!(
                "post_recv addr {}, len {}, lkey {} wrid: {}",
                lm.addr(),
                lm.length(),
                unsafe { lm.lkey_unchecked() },
                recv_attr.as_ref().wr_id,
            );
        }
        let errno = unsafe { ibv_post_recv(self.as_ptr(), recv_attr.as_mut(), &mut bad_wr) };
        if errno != 0_i32 {
            return Err(log_ret_last_os_err());
        }
        Ok(())
    }

    /// Post an RDMA READ from remote region `rm` into local regions `lms`.
    fn submit_read<LW, RR>(&self, lms: &[&mut LW], rm: &RR, wr_id: WorkRequestId) -> io::Result<()>
    where
        LW: LocalMrWriteAccess,
        RR: RemoteMrReadAccess,
    {
        let mut bad_wr = std::ptr::null_mut::<ibv_send_wr>();
        let mut send_attr = SendWr::new_read(lms, wr_id, rm);
        self.cq_event_listener.cq.req_notify(false)?;
        for lm in lms {
            debug!(
                "post_send addr {}, len {}, lkey {} wrid: {}",
                lm.addr(),
                lm.length(),
                unsafe { lm.lkey_unchecked() },
                send_attr.as_ref().wr_id,
            );
        }
        let errno = unsafe { ibv_post_send(self.as_ptr(), send_attr.as_mut(), &mut bad_wr) };
        if errno != 0_i32 {
            return Err(log_ret_last_os_err());
        }
        Ok(())
    }

    /// Post an RDMA WRITE of local regions `lms` into remote region `rm`,
    /// optionally carrying immediate data.
    fn submit_write<LR, RW>(
        &self,
        lms: &[&LR],
        rm: &mut RW,
        wr_id: WorkRequestId,
        imm: Option<u32>,
    ) -> io::Result<()>
    where
        LR: LocalMrReadAccess,
        RW: RemoteMrWriteAccess,
    {
        let mut bad_wr = std::ptr::null_mut::<ibv_send_wr>();
        let mut send_attr = SendWr::new_write(lms, wr_id, rm, imm);
        self.cq_event_listener.cq.req_notify(false)?;
        for lm in lms {
            debug!(
                "post_send addr {}, len {}, lkey_unchecked {} wrid: {}",
                lm.addr(),
                lm.length(),
                unsafe { lm.lkey_unchecked() },
                send_attr.as_ref().wr_id,
            );
        }
        let errno = unsafe { ibv_post_send(self.as_ptr(), send_attr.as_mut(), &mut bad_wr) };
        if errno != 0_i32 {
            return Err(log_ret_last_os_err());
        }
        Ok(())
    }

    /// Post an atomic compare-and-swap on remote region `rm`: if it holds
    /// `old_value`, replace with `new_value`; the prior value lands in `buf`.
    fn submit_cas<LR, RW>(
        &self,
        old_value: u64,
        new_value: u64,
        buf: &LR,
        rm: &mut RW,
        wr_id: WorkRequestId,
    ) -> io::Result<()>
    where
        LR: LocalMrReadAccess,
        RW: RemoteMrWriteAccess,
    {
        let mut cas_wr = SendWr::new_cas(old_value, new_value, buf, rm, wr_id);
        let mut bad_wr = std::ptr::null_mut::<ibv_send_wr>();
        self.cq_event_listener.cq.req_notify(false)?;
        let errno = unsafe { ibv_post_send(self.as_ptr(), cas_wr.as_mut(), &mut bad_wr) };
        if errno != 0_i32 {
            return Err(log_ret_last_os_err());
        }
        Ok(())
    }

    /// Build a send operation future over `lms`; submission and completion
    /// handling happen when the returned `QueuePairOps` is polled.
    pub(crate) fn send_sge<'a, LR>(
        self: &Arc<Self>,
        lms: &'a [&'a LR],
        imm: Option<u32>,
    ) -> QueuePairOps<QPSend<'a, LR>>
    where
        LR: LocalMrReadAccess,
    {
        let send = QPSend::new(lms, imm);
        QueuePairOps::new(Arc::<Self>::clone(self), send, get_lmr_inners(lms))
    }

    /// Send `lms` and await the completion directly, bypassing the
    /// `QueuePairOps` state machine (no resubmission on transient failure).
    #[cfg(feature = "raw")]
    pub(crate) async fn send_sge_raw<'a, LR>(
        self: &Arc<Self>,
        lms: &'a [&'a LR],
        imm: Option<u32>,
    ) -> io::Result<()>
    where
        LR: LocalMrReadAccess,
    {
        // Register the MRs for read access and obtain the completion channel.
        let (wr_id, mut resp_rx) = self
            .cq_event_listener
            .register_for_read(&get_lmr_inners(lms))?;
        let len: usize = lms.iter().map(|lm| lm.length()).sum();
        self.submit_send(lms, wr_id, imm)?;
        resp_rx
            .recv()
            .await
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "agent is dropped"))?
            .result()
            .map(|sz| debug!("post size: {sz}, mr len: {len}"))
            .map_err(Into::into)
    }

    /// Build a receive operation future into `lms`; completion yields the
    /// received byte count and optional immediate data.
    pub(crate) fn receive_sge<'a, LW>(
        self: &Arc<Self>,
        lms: &'a [&'a mut LW],
    ) -> QueuePairOps<QPRecv<'a, LW>>
    where
        LW: LocalMrWriteAccess,
    {
        let recv = QPRecv::new(lms);
        QueuePairOps::new(Arc::<Self>::clone(self), recv, get_mut_lmr_inners(lms))
    }

    /// Receive into `lms` and await the completion directly; returns the
    /// immediate data (if any) carried by the incoming message.
    #[cfg(feature = "raw")]
    pub(crate) async fn receive_sge_raw<'a, LW>(
        self: &Arc<Self>,
        lms: &'a [&'a mut LW],
    ) -> io::Result<Option<u32>>
    where
        LW: LocalMrWriteAccess,
    {
        let (wr_id, mut resp_rx) = self
            .cq_event_listener
            .register_for_write(&get_mut_lmr_inners(lms))?;
        let len: usize = lms.iter().map(|lm| lm.length()).sum();
        self.submit_receive(lms, wr_id)?;
        resp_rx
            .recv()
            .await
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "cq_event_listener is dropped"))?
            .result_with_imm()
            .map(|(sz, imm)| {
                debug!("post size: {sz}, mr len: {len}");
                imm
            })
            .map_err(Into::into)
    }

    /// Like `receive_sge_raw`, but invokes `func` right after the receive
    /// WR is posted (before awaiting the completion) — e.g. to signal the
    /// peer that the buffer is ready.
    #[cfg(feature = "exp")]
    pub(crate) async fn receive_sge_fn<'a, LW, F>(
        self: &Arc<Self>,
        lms: &'a [&'a mut LW],
        func: F,
    ) -> io::Result<Option<u32>>
    where
        LW: LocalMrWriteAccess,
        F: FnOnce(),
    {
        let (wr_id, mut resp_rx) = self
            .cq_event_listener
            .register_for_write(&get_mut_lmr_inners(lms))?;
        let len: usize = lms.iter().map(|lm| lm.length()).sum();
        self.submit_receive(lms, wr_id)?;
        func();
        resp_rx
            .recv()
            .await
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "cq_event_listener is dropped"))?
            .result_with_imm()
            .map(|(sz, imm)| {
                debug!("receivede size {sz}, lms len {len}");
                imm
            })
            .map_err(Into::into)
    }

    /// RDMA READ from `rm` into `lms` and await completion.
    pub(crate) async fn read_sge<LW, RR>(&self, lms: &[&mut LW], rm: &RR) -> io::Result<()>
    where
        LW: LocalMrWriteAccess,
        RR: RemoteMrReadAccess,
    {
        let (wr_id, mut resp_rx) = self
            .cq_event_listener
            .register_for_write(&get_mut_lmr_inners(lms))?;
        let len: usize = lms.iter().map(|lm| lm.length()).sum();
        self.submit_read(lms, rm, wr_id)?;
        resp_rx
            .recv()
            .await
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "agent is dropped"))?
            .result()
            .map(|sz| debug!("post size: {sz}, mr len: {len}"))
            .map_err(Into::into)
    }

    /// RDMA WRITE `lms` into `rm` (optional immediate data) and await
    /// completion.
    pub(crate) async fn write_sge<LR, RW>(
        &self,
        lms: &[&LR],
        rm: &mut RW,
        imm: Option<u32>,
    ) -> io::Result<()>
    where
        LR: LocalMrReadAccess,
        RW: RemoteMrWriteAccess,
    {
        let (wr_id, mut resp_rx) = self
            .cq_event_listener
            .register_for_read(&get_lmr_inners(lms))?;
        let len: usize = lms.iter().map(|lm| lm.length()).sum();
        self.submit_write(lms, rm, wr_id, imm)?;
        resp_rx
            .recv()
            .await
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "agent is dropped"))?
            .result()
            .map(|sz| debug!("post size: {sz}, mr len: {len}"))
            .map_err(Into::into)
    }

    /// Atomic compare-and-swap on `rm` and await completion; the completed
    /// size is asserted to be 8 bytes (the width of the swapped word).
    pub(crate) async fn atomic_cas<LR, RW>(
        &self,
        old_value: u64,
        new_value: u64,
        buf: &LR,
        rm: &mut RW,
    ) -> io::Result<()>
    where
        LR: LocalMrReadAccess,
        RW: RemoteMrWriteAccess,
    {
        let (wr_id, mut resp_rx) = self
            .cq_event_listener
            .register_for_read(&get_lmr_inners(&[buf]))?;
        self.submit_cas(old_value, new_value, buf, rm, wr_id)?;
        resp_rx
            .recv()
            .await
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "agent is dropped"))?
            .result()
            .map(|sz| assert_eq!(sz, 8))
            .map_err(Into::into)
    }

    /// Convenience wrapper: send a single local memory region.
    pub(crate) async fn send<LR>(self: &Arc<Self>, lm: &LR) -> io::Result<()>
    where
        LR: LocalMrReadAccess,
    {
        self.send_sge(&[lm], None).await
    }

    /// Convenience wrapper: RDMA READ into a single local region.
    pub(crate) async fn read<LW, RR>(&self, lm: &mut LW, rm: &RR) -> io::Result<()>
    where
        LW: LocalMrWriteAccess,
        RR: RemoteMrReadAccess,
    {
        self.read_sge(&[lm], rm).await
    }

    /// Convenience wrapper: RDMA WRITE from a single local region.
    pub(crate) async fn write<LR, RW>(
        &self,
        lm: &LR,
        rm: &mut RW,
        imm: Option<u32>,
    ) -> io::Result<()>
    where
        LR: LocalMrReadAccess,
        RW: RemoteMrWriteAccess,
    {
        self.write_sge(&[lm], rm, imm).await
    }
}
// SAFETY: NOTE(review) — `QueuePair` holds a raw `ibv_qp` pointer, so Send/
// Sync are asserted manually here; soundness relies on how the rest of the
// crate serializes access to the QP and cannot be verified from this file
// alone — confirm against the crate's threading model.
unsafe impl Sync for QueuePair {}
unsafe impl Send for QueuePair {}
// SAFETY: NOTE(review) — same caveat for the raw CQ/SRQ pointers stored in
// the builder.
unsafe impl Sync for QueuePairInitAttrBuilder {}
unsafe impl Send for QueuePairInitAttrBuilder {}
impl Drop for QueuePair {
    /// Destroy the underlying `ibv_qp`. A failure is logged rather than
    /// propagated, since `drop` cannot return an error.
    fn drop(&mut self) {
        let errno = unsafe { ibv_destroy_qp(self.as_ptr()) };
        if errno != 0_i32 {
            log_last_os_err();
        }
    }
}
/// Collect the shared inner handles of the given local memory regions,
/// cloning each `Arc` so the MRs stay alive while a work request is
/// outstanding.
fn get_lmr_inners<LR>(lms: &[&LR]) -> LmrInners
where
    LR: LocalMrReadAccess,
{
    lms.iter()
        .map(|lm| Arc::<RwLocalMrInner>::clone(lm.get_inner()))
        .collect()
}
fn get_mut_lmr_inners<LR>(lms: &[&mut LR]) -> LmrInners
where
LR: LocalMrReadAccess,
{
let imlms: Vec<&LR> = lms.iter().map(|lm| &**lm).collect();
get_lmr_inners(&imlms)
}
static RESUBMIT_DELAY: Duration = Duration::from_secs(1);
/// A queue-pair operation driven by the `QueuePairOps` future.
pub(crate) trait QueuePairOp {
    /// Value produced on successful completion.
    type Output;
    /// Post the work request for this operation on `qp`.
    fn submit(&self, qp: &QueuePair, wr_id: WorkRequestId) -> io::Result<()>;
    /// Whether a failed submission with error `e` should be retried.
    fn should_resubmit(&self, e: &io::Error) -> bool;
    /// Interpret the work completion for this operation.
    fn result(&self, wc: WorkCompletion) -> Result<Self::Output, WCError>;
}
/// A pending send operation: the regions to send, their total length
/// (for logging) and optional immediate data.
#[derive(Debug)]
pub(crate) struct QPSend<'lm, LR>
where
    LR: LocalMrReadAccess,
{
    /// Local memory regions to transmit.
    lms: &'lm [&'lm LR],
    /// Total byte length of `lms` (computed once at construction).
    len: usize,
    /// Optional immediate data.
    imm: Option<u32>,
}
impl<'lm, LR> QPSend<'lm, LR>
where
    LR: LocalMrReadAccess,
{
    /// Create a send operation over `lms`, precomputing the total byte
    /// length (used for the debug log when the completion arrives).
    ///
    /// The duplicate `where LR: LocalMrReadAccess` clause that used to sit
    /// on this method was redundant — the bound is already declared on the
    /// surrounding `impl`.
    fn new(lms: &'lm [&'lm LR], imm: Option<u32>) -> Self {
        Self {
            len: lms.iter().map(|lm| lm.length()).sum(),
            lms,
            imm,
        }
    }
}
impl<LR> QueuePairOp for QPSend<'_, LR>
where
    LR: LocalMrReadAccess,
{
    type Output = ();
    /// Post the send work request.
    fn submit(&self, qp: &QueuePair, wr_id: WorkRequestId) -> io::Result<()> {
        qp.submit_send(self.lms, wr_id, self.imm)
    }
    /// Retry only when the submission failed for lack of queue space.
    fn should_resubmit(&self, e: &io::Error) -> bool {
        matches!(e.kind(), io::ErrorKind::OutOfMemory)
    }
    /// A successful completion just logs the transferred size.
    fn result(&self, wc: WorkCompletion) -> Result<Self::Output, WCError> {
        wc.result()
            .map(|sz| debug!("post size: {sz}, mr len: {}", self.len))
    }
}
/// A pending receive operation: the writable regions to receive into.
#[derive(Debug)]
pub(crate) struct QPRecv<'lm, LW>
where
    LW: LocalMrWriteAccess,
{
    /// Local memory regions the incoming data lands in.
    lms: &'lm [&'lm mut LW],
}
impl<'lm, LW> QPRecv<'lm, LW>
where
    LW: LocalMrWriteAccess,
{
    /// Create a receive operation over `lms`.
    fn new(lms: &'lm [&'lm mut LW]) -> Self {
        Self { lms }
    }
}
impl<LW> QueuePairOp for QPRecv<'_, LW>
where
    LW: LocalMrWriteAccess,
{
    /// Received byte count plus optional immediate data.
    type Output = (usize, Option<u32>);
    /// Post the receive work request.
    fn submit(&self, qp: &QueuePair, wr_id: WorkRequestId) -> io::Result<()> {
        qp.submit_receive(self.lms, wr_id)
    }
    /// Retry only when the submission failed for lack of queue space.
    fn should_resubmit(&self, e: &io::Error) -> bool {
        matches!(e.kind(), io::ErrorKind::OutOfMemory)
    }
    /// Extract size and immediate data from the completion.
    fn result(&self, wc: WorkCompletion) -> Result<Self::Output, WCError> {
        wc.result_with_imm()
    }
}
/// States of the `QueuePairOps` future.
#[derive(Debug)]
enum QueuePairOpsState {
    /// Not yet registered with the CQ listener; holds the MR handles.
    Init(LmrInners),
    /// Registered; ready to submit the work request.
    Submit(WorkRequestId, Option<mpsc::Receiver<WorkCompletion>>),
    /// Submission failed with a retryable error; waiting out the delay.
    PendingToResubmit(
        Pin<Box<Sleep>>,
        WorkRequestId,
        Option<mpsc::Receiver<WorkCompletion>>,
    ),
    /// Submitted; awaiting the work completion.
    Submitted(mpsc::Receiver<WorkCompletion>),
}
/// Future that registers, submits and completes a queue-pair operation.
#[derive(Debug)]
pub(crate) struct QueuePairOps<Op: QueuePairOp + Unpin> {
    /// The queue pair the operation runs on.
    qp: Arc<QueuePair>,
    /// Current state of the submission state machine.
    state: QueuePairOpsState,
    /// The operation being driven.
    op: Op,
}
impl<Op: QueuePairOp + Unpin> QueuePairOps<Op> {
    /// Wrap `op` so it starts in the `Init` state holding `inners`.
    fn new(qp: Arc<QueuePair>, op: Op, inners: LmrInners) -> Self {
        Self {
            qp,
            state: QueuePairOpsState::Init(inners),
            op,
        }
    }
}
impl<Op: QueuePairOp + Unpin> Future for QueuePairOps<Op> {
    type Output = io::Result<Op::Output>;
    /// Drive the state machine:
    /// `Init` -> register with the CQ listener ->
    /// `Submit` -> post the WR (looping through `PendingToResubmit` on
    /// retryable failures) -> `Submitted` -> await the work completion.
    /// After each transition the future re-polls itself immediately.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let s = self.get_mut();
        match s.state {
            QueuePairOpsState::Init(ref inners) => {
                // Register the MRs and obtain the completion channel.
                let (wr_id, recv) = s.qp.cq_event_listener.register_for_write(inners)?;
                s.state = QueuePairOpsState::Submit(wr_id, Some(recv));
                Pin::new(s).poll(cx)
            }
            QueuePairOpsState::Submit(wr_id, ref mut recv) => {
                if let Err(e) = s.op.submit(&s.qp, wr_id) {
                    if s.op.should_resubmit(&e) {
                        // Back off and try the submission again later.
                        let sleep = Box::pin(sleep(RESUBMIT_DELAY));
                        s.state = QueuePairOpsState::PendingToResubmit(sleep, wr_id, recv.take());
                    } else {
                        tracing::error!("failed to submit the operation");
                        return Poll::Ready(Err(e));
                    }
                } else {
                    // The receiver must still be present here; its absence
                    // would indicate a state-machine bug.
                    match recv.take().ok_or_else(|| {
                        io::Error::new(io::ErrorKind::Other, "Bug in queue pair op poll")
                    }) {
                        Ok(recv) => s.state = QueuePairOpsState::Submitted(recv),
                        Err(e) => return Poll::Ready(Err(e)),
                    }
                }
                Pin::new(s).poll(cx)
            }
            QueuePairOpsState::PendingToResubmit(ref mut sleep, wr_id, ref mut recv) => {
                // Wait for the backoff timer, then go back to Submit.
                ready!(sleep.poll_unpin(cx));
                s.state = QueuePairOpsState::Submit(wr_id, recv.take());
                Pin::new(s).poll(cx)
            }
            QueuePairOpsState::Submitted(ref mut recv) => {
                // Resolve with the operation's interpretation of the WC;
                // a closed channel means the listener went away.
                Poll::Ready(match ready!(recv.poll_recv(cx)) {
                    Some(wc) => s.op.result(wc).map_err(Into::into),
                    None => Err(io::Error::new(
                        io::ErrorKind::Other,
                        "Wc receiver unexpect closed",
                    )),
                })
            }
        }
    }
}
/// Finish the receive- and send-queue attribute builders against `remote`.
///
/// The remote endpoint's QP number, LID and GID are written into the
/// receive-side address handler via `RQAttrBuilder::reset_remote_info`,
/// which already encapsulates exactly this builder chain (previously the
/// chain was duplicated inline here). Both builders are then built.
///
/// # Errors
/// Returns an `io::Error` if either builder is missing a required field.
pub(crate) fn builders_into_attrs(
    mut recv_attr_builder: RQAttrBuilder,
    send_attr_builder: SQAttrBuilder,
    remote: &QueuePairEndpoint,
) -> io::Result<(RQAttr, SQAttr)> {
    recv_attr_builder.reset_remote_info(remote);
    let send_attr = send_attr_builder.build()?;
    let recv_attr = recv_attr_builder.build()?;
    Ok((recv_attr, send_attr))
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod dev_cap_check_tests {
    //! Each test sets one capability field to `u32::MAX` and expects
    //! `check_dev_qp_cap` to fail (the final `unwrap` panics).
    //! NOTE(review): these tests require an RDMA device to be present for
    //! `Context::open` to succeed — confirm CI provides one (or a soft-RoCE
    //! device).
    use super::*;
    #[test]
    #[should_panic]
    fn recv_sge_cap_overrun() {
        let ctx = Context::open(None, 1, 1).unwrap();
        let cap = QueuePairCapBuilder::default()
            .max_recv_sge(u32::MAX)
            .build()
            .unwrap();
        cap.check_dev_qp_cap(&ctx).unwrap();
    }
    #[test]
    #[should_panic]
    fn send_sge_cap_overrun() {
        let ctx = Context::open(None, 1, 1).unwrap();
        let cap = QueuePairCapBuilder::default()
            .max_send_sge(u32::MAX)
            .build()
            .unwrap();
        cap.check_dev_qp_cap(&ctx).unwrap();
    }
    #[test]
    #[should_panic]
    fn recv_wr_cap_overrun() {
        let ctx = Context::open(None, 1, 1).unwrap();
        let cap = QueuePairCapBuilder::default()
            .max_recv_wr(u32::MAX)
            .build()
            .unwrap();
        cap.check_dev_qp_cap(&ctx).unwrap();
    }
    #[test]
    #[should_panic]
    fn send_wr_cap_overrun() {
        let ctx = Context::open(None, 1, 1).unwrap();
        let cap = QueuePairCapBuilder::default()
            .max_send_wr(u32::MAX)
            .build()
            .unwrap();
        cap.check_dev_qp_cap(&ctx).unwrap();
    }
}