use bitmask_enum::bitmask;
use rdma_mummy_sys::{
ibv_create_qp, ibv_create_qp_ex, ibv_data_buf, ibv_destroy_qp, ibv_modify_qp, ibv_post_recv, ibv_post_send, ibv_qp,
ibv_qp_attr, ibv_qp_attr_mask, ibv_qp_cap, ibv_qp_create_send_ops_flags, ibv_qp_ex, ibv_qp_init_attr,
ibv_qp_init_attr_ex, ibv_qp_init_attr_mask, ibv_qp_state, ibv_qp_to_qp_ex, ibv_qp_type, ibv_query_qp, ibv_recv_wr,
ibv_rx_hash_conf, ibv_send_flags, ibv_send_wr, ibv_sge, ibv_wr_abort, ibv_wr_complete, ibv_wr_opcode,
ibv_wr_rdma_read, ibv_wr_rdma_write, ibv_wr_rdma_write_imm, ibv_wr_send, ibv_wr_send_imm, ibv_wr_set_inline_data,
ibv_wr_set_inline_data_list, ibv_wr_set_sge, ibv_wr_set_sge_list, ibv_wr_start,
};
use std::sync::{Arc, LazyLock};
use std::{
fmt,
io::{self, IoSlice},
marker::PhantomData,
mem::MaybeUninit,
ptr::{null_mut, NonNull},
};
use super::{
address::{AddressHandleAttribute, Gid},
completion::{CompletionQueue, GenericCompletionQueue},
device_context::Mtu,
protection_domain::ProtectionDomain,
AccessFlags,
};
/// Error returned when creating a queue pair fails.
#[derive(Debug, thiserror::Error)]
#[error("failed to create queue pair")]
#[non_exhaustive]
pub struct CreateQueuePairError(#[from] pub CreateQueuePairErrorKind);

/// The underlying cause of a [`CreateQueuePairError`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
#[non_exhaustive]
pub enum CreateQueuePairErrorKind {
    /// The raw libibverbs call failed; carries the OS error.
    Ibverbs(#[from] io::Error),
}
/// Error returned when querying queue pair attributes fails.
#[derive(Debug, thiserror::Error)]
#[error("failed to query queue pair")]
#[non_exhaustive]
pub struct QueryQueuePairError(#[from] pub QueryQueuePairErrorKind);

/// The underlying cause of a [`QueryQueuePairError`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
#[non_exhaustive]
pub enum QueryQueuePairErrorKind {
    /// The raw `ibv_query_qp` call failed; carries the OS error.
    Ibverbs(#[from] io::Error),
}
/// Error returned when a queue pair state transition (`ibv_modify_qp`) fails.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
#[non_exhaustive]
pub struct ModifyQueuePairError(#[from] pub ModifyQueuePairErrorKind);

/// The underlying cause of a [`ModifyQueuePairError`].
///
/// The `InvalidTransition` / `InvalidAttributeMask` variants are produced by
/// this crate's own validation (see `attr_mask_check`) when the kernel returns
/// `EINVAL`, to pinpoint which masks were wrong.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ModifyQueuePairErrorKind {
    /// The raw libibverbs call failed; carries the OS error.
    #[error("modify queue pair failed")]
    Ibverbs(#[from] io::Error),
    /// The requested state transition is not allowed by the QP state machine.
    #[error("invalid transition from {cur_state:?} to {next_state:?}")]
    InvalidTransition {
        cur_state: QueuePairState,
        next_state: QueuePairState,
        source: io::Error,
    },
    /// The transition is legal but the attribute mask carried superfluous
    /// (`invalid`) or missing (`needed`) bits for it.
    #[error("invalid transition from {cur_state:?} to {next_state:?}, possible invalid masks {invalid:?}, possible needed masks {needed:?}")]
    InvalidAttributeMask {
        cur_state: QueuePairState,
        next_state: QueuePairState,
        invalid: QueuePairAttributeMask,
        needed: QueuePairAttributeMask,
        source: io::Error,
    },
    /// Mapped from `ETIMEDOUT`: route resolution to the destination GID timed out.
    #[error("resolve route timed out, source gid index: {sgid_index}, destination gid: {gid}")]
    ResolveRouteTimedout {
        sgid_index: u8,
        gid: Gid,
        source: io::Error,
    },
    /// Mapped from `ENETUNREACH`: the destination GID is unreachable.
    #[error("network unreachable, source gid index: {sgid_index}, destination gid: {gid}")]
    NetworkUnreachable {
        sgid_index: u8,
        gid: Gid,
        source: io::Error,
    },
}
/// Error returned when posting send work requests fails.
///
/// The specific variants correspond to the errno values `ibv_post_send`
/// returns: `EINVAL`, `EFAULT` and `ENOMEM` respectively.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum PostSendError {
    /// Any other errno from the verbs call.
    #[error("post send failed")]
    Ibverbs(#[from] io::Error),
    /// `EINVAL`: a work request carried an invalid value.
    #[error("invalid value provided in work request")]
    InvalidWorkRequest(#[source] io::Error),
    /// `EFAULT`: the queue pair itself is in an invalid state/value.
    #[error("invalid value provided in queue pair")]
    InvalidQueuePair(#[source] io::Error),
    /// `ENOMEM`: the send queue is full or resources were exhausted.
    #[error("send queue is full or not enough resources to complete this operation")]
    NotEnoughResources(#[source] io::Error),
}

/// Error returned when posting receive work requests fails.
///
/// Mirrors [`PostSendError`] for the `ibv_post_recv` path.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum PostRecvError {
    /// Any other errno from the verbs call.
    #[error("post receive failed")]
    Ibverbs(#[from] io::Error),
    /// `EINVAL`: a work request carried an invalid value.
    #[error("invalid value provided in work request")]
    InvalidWorkRequest(#[source] io::Error),
    /// `EFAULT`: the queue pair itself is in an invalid state/value.
    #[error("invalid value provided in queue pair")]
    InvalidQueuePair(#[source] io::Error),
    /// `ENOMEM`: the receive queue is full or resources were exhausted.
    #[error("receive queue is full or not enough resources to complete this operation")]
    NotEnoughResources(#[source] io::Error),
}
/// Transport type of a queue pair, mirroring `ibv_qp_type`.
#[repr(u32)]
#[derive(Debug, Clone, Copy)]
pub enum QueuePairType {
    /// RC: reliable, connected, in-order delivery.
    ReliableConnection = ibv_qp_type::IBV_QPT_RC,
    /// UC: connected but unacknowledged.
    UnreliableConnection = ibv_qp_type::IBV_QPT_UC,
    /// UD: connectionless datagrams.
    UnreliableDatagram = ibv_qp_type::IBV_QPT_UD,
    /// Raw Ethernet packets.
    RawPacket = ibv_qp_type::IBV_QPT_RAW_PACKET,
    /// XRC send-side queue pair.
    ReliableConnectionExtendedSend = ibv_qp_type::IBV_QPT_XRC_SEND,
    /// XRC receive-side queue pair.
    ReliableConnectionExtendedRecv = ibv_qp_type::IBV_QPT_XRC_RECV,
}
/// State of a queue pair, mirroring `ibv_qp_state`.
///
/// `Ord` is derived so code can iterate states numerically (see the
/// transition-table construction, which counts from `Reset` to `Error`).
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum QueuePairState {
    Reset = ibv_qp_state::IBV_QPS_RESET,
    Init = ibv_qp_state::IBV_QPS_INIT,
    ReadyToReceive = ibv_qp_state::IBV_QPS_RTR,
    ReadyToSend = ibv_qp_state::IBV_QPS_RTS,
    SendQueueDrain = ibv_qp_state::IBV_QPS_SQD,
    SendQueueError = ibv_qp_state::IBV_QPS_SQE,
    Error = ibv_qp_state::IBV_QPS_ERR,
    Unknown = ibv_qp_state::IBV_QPS_UNKNOWN,
}

impl From<u32> for QueuePairState {
    /// Convert a raw `ibv_qp_state` value to the typed state.
    ///
    /// # Panics
    /// Panics on any value outside the known `ibv_qp_state` constants —
    /// such a value can only come from a buggy driver or FFI misuse.
    fn from(state: u32) -> Self {
        match state {
            ibv_qp_state::IBV_QPS_RESET => QueuePairState::Reset,
            ibv_qp_state::IBV_QPS_INIT => QueuePairState::Init,
            ibv_qp_state::IBV_QPS_RTR => QueuePairState::ReadyToReceive,
            ibv_qp_state::IBV_QPS_RTS => QueuePairState::ReadyToSend,
            ibv_qp_state::IBV_QPS_SQD => QueuePairState::SendQueueDrain,
            ibv_qp_state::IBV_QPS_SQE => QueuePairState::SendQueueError,
            ibv_qp_state::IBV_QPS_ERR => QueuePairState::Error,
            ibv_qp_state::IBV_QPS_UNKNOWN => QueuePairState::Unknown,
            _ => panic!("Unknown qp state: {state}"),
        }
    }
}
/// Bitmask of send operations a queue pair should support at creation time,
/// mirroring `ibv_qp_create_send_ops_flags` (used in
/// `ibv_qp_init_attr_ex::send_ops_flags`).
#[bitmask(u64)]
#[bitmask_config(vec_debug)]
pub enum SendOperationFlags {
    Write = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_RDMA_WRITE.0 as _,
    WriteWithImmediate = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_RDMA_WRITE_WITH_IMM.0 as _,
    Send = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_SEND.0 as _,
    SendWithImmediate = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_SEND_WITH_IMM.0 as _,
    Read = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_RDMA_READ.0 as _,
    AtomicCompareAndSwap = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_ATOMIC_CMP_AND_SWP.0 as _,
    AtomicFetchAndAdd = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_ATOMIC_FETCH_AND_ADD.0 as _,
    LocalInvalidate = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_LOCAL_INV.0 as _,
    BindMemoryWindow = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_BIND_MW.0 as _,
    SendWithInvalidate = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_SEND_WITH_INV.0 as _,
    TcpSegmentationOffload = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_TSO.0 as _,
    Flush = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_FLUSH.0 as _,
    AtomicWrite = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_ATOMIC_WRITE.0 as _,
}
/// Opcode of a send work request, mirroring `ibv_wr_opcode`.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkRequestOperationType {
    Send = ibv_wr_opcode::IBV_WR_SEND,
    SendWithImmediate = ibv_wr_opcode::IBV_WR_SEND_WITH_IMM,
    Write = ibv_wr_opcode::IBV_WR_RDMA_WRITE,
    WriteWithImmediate = ibv_wr_opcode::IBV_WR_RDMA_WRITE_WITH_IMM,
    Read = ibv_wr_opcode::IBV_WR_RDMA_READ,
    AtomicCompareAndSwap = ibv_wr_opcode::IBV_WR_ATOMIC_CMP_AND_SWP,
    AtomicFetchAndAdd = ibv_wr_opcode::IBV_WR_ATOMIC_FETCH_AND_ADD,
    LocalInvalidate = ibv_wr_opcode::IBV_WR_LOCAL_INV,
    BindMemoryWindow = ibv_wr_opcode::IBV_WR_BIND_MW,
    SendWithInvalidate = ibv_wr_opcode::IBV_WR_SEND_WITH_INV,
    TcpSegmentationOffload = ibv_wr_opcode::IBV_WR_TSO,
    /// Vendor-specific opcode.
    Driver1 = ibv_wr_opcode::IBV_WR_DRIVER1,
    Flush = ibv_wr_opcode::IBV_WR_FLUSH,
    AtomicWrite = ibv_wr_opcode::IBV_WR_ATOMIC_WRITE,
}

impl From<u32> for WorkRequestOperationType {
    /// Convert a raw `ibv_wr_opcode` value to the typed opcode.
    ///
    /// # Panics
    /// Panics on any value outside the known `ibv_wr_opcode` constants.
    fn from(opcode: u32) -> Self {
        match opcode {
            ibv_wr_opcode::IBV_WR_SEND => WorkRequestOperationType::Send,
            ibv_wr_opcode::IBV_WR_SEND_WITH_IMM => WorkRequestOperationType::SendWithImmediate,
            ibv_wr_opcode::IBV_WR_RDMA_WRITE => WorkRequestOperationType::Write,
            ibv_wr_opcode::IBV_WR_RDMA_WRITE_WITH_IMM => WorkRequestOperationType::WriteWithImmediate,
            ibv_wr_opcode::IBV_WR_RDMA_READ => WorkRequestOperationType::Read,
            ibv_wr_opcode::IBV_WR_ATOMIC_CMP_AND_SWP => WorkRequestOperationType::AtomicCompareAndSwap,
            ibv_wr_opcode::IBV_WR_ATOMIC_FETCH_AND_ADD => WorkRequestOperationType::AtomicFetchAndAdd,
            ibv_wr_opcode::IBV_WR_LOCAL_INV => WorkRequestOperationType::LocalInvalidate,
            ibv_wr_opcode::IBV_WR_BIND_MW => WorkRequestOperationType::BindMemoryWindow,
            ibv_wr_opcode::IBV_WR_SEND_WITH_INV => WorkRequestOperationType::SendWithInvalidate,
            ibv_wr_opcode::IBV_WR_TSO => WorkRequestOperationType::TcpSegmentationOffload,
            ibv_wr_opcode::IBV_WR_DRIVER1 => WorkRequestOperationType::Driver1,
            ibv_wr_opcode::IBV_WR_FLUSH => WorkRequestOperationType::Flush,
            ibv_wr_opcode::IBV_WR_ATOMIC_WRITE => WorkRequestOperationType::AtomicWrite,
            _ => panic!("Unknown work request opcode: {opcode}"),
        }
    }
}
/// Per-work-request flags, mirroring `ibv_send_flags`.
#[bitmask(u32)]
#[bitmask_config(vec_debug)]
pub enum WorkRequestFlags {
    /// Wait for prior RDMA reads/atomics on this QP to complete first.
    Fence = ibv_send_flags::IBV_SEND_FENCE.0,
    /// Generate a completion for this work request.
    Signaled = ibv_send_flags::IBV_SEND_SIGNALED.0,
    /// Raise a solicited event at the remote side.
    Solicited = ibv_send_flags::IBV_SEND_SOLICITED.0,
    /// Copy the payload into the WQE instead of referencing registered memory.
    Inline = ibv_send_flags::IBV_SEND_INLINE.0,
    /// Offload IP checksum computation to the device.
    IpChecksum = ibv_send_flags::IBV_SEND_IP_CSUM.0,
}
/// Common interface over the basic (`ibv_qp`) and extended (`ibv_qp_ex`)
/// queue pair flavors: state transitions, queries, and post-send/recv entry
/// points.
#[allow(private_bounds)]
pub trait QueuePair {
    /// Return the raw `ibv_qp` pointer backing this queue pair.
    ///
    /// # Safety
    /// The pointer is only valid while `self` is alive; callers must not
    /// destroy it or keep it past the wrapper's lifetime.
    unsafe fn qp(&self) -> NonNull<ibv_qp>;

    /// Apply `attr` with `ibv_modify_qp`, transitioning the QP's state.
    ///
    /// On `EINVAL` the attribute mask is re-validated locally against the RC
    /// state-machine table to produce a more precise error; `ETIMEDOUT` and
    /// `ENETUNREACH` are mapped to route-resolution errors that include the
    /// GID information from the address vector.
    fn modify(&mut self, attr: &QueuePairAttribute) -> Result<(), ModifyQueuePairError> {
        // ibv_modify_qp takes a mutable pointer, so work on a local copy.
        let mut qp_attr = ibv_qp_attr { ..attr.attr };
        // Return value is 0 on success or an errno value on failure.
        let ret = unsafe { ibv_modify_qp(self.qp().as_ptr(), &mut qp_attr as *mut _, attr.attr_mask.bits) };
        if ret == 0 {
            Ok(())
        } else {
            match ret {
                libc::EINVAL => {
                    // Validate against the target state if the request changes
                    // state, otherwise against the current state (self-loop).
                    let err = if attr.attr_mask.contains(QueuePairAttributeMask::State) {
                        attr_mask_check(attr.attr_mask, self.state(), attr.attr.qp_state.into())
                    } else {
                        attr_mask_check(attr.attr_mask, self.state(), self.state())
                    };
                    match err {
                        Ok(()) => {
                            // Our own check found nothing wrong; surface the raw EINVAL.
                            Err(ModifyQueuePairErrorKind::Ibverbs(io::Error::from_raw_os_error(libc::EINVAL)).into())
                        },
                        Err(err) => Err(err),
                    }
                },
                libc::ETIMEDOUT => Err(ModifyQueuePairErrorKind::ResolveRouteTimedout {
                    sgid_index: attr.attr.ah_attr.grh.sgid_index,
                    gid: attr.attr.ah_attr.grh.dgid.into(),
                    source: io::Error::from_raw_os_error(libc::ETIMEDOUT),
                }
                .into()),
                libc::ENETUNREACH => Err(ModifyQueuePairErrorKind::NetworkUnreachable {
                    sgid_index: attr.attr.ah_attr.grh.sgid_index,
                    gid: attr.attr.ah_attr.grh.dgid.into(),
                    source: io::Error::from_raw_os_error(libc::ENETUNREACH),
                }
                .into()),
                err => Err(ModifyQueuePairErrorKind::Ibverbs(io::Error::from_raw_os_error(err)).into()),
            }
        }
    }

    /// Query the attributes selected by `mask` via `ibv_query_qp`, returning
    /// both the current attributes and the creation-time init attributes.
    fn query(
        &self, mask: QueuePairAttributeMask,
    ) -> Result<(QueuePairAttribute, QueuePairInitAttribute), QueryQueuePairError> {
        let mut attr = QueuePairAttribute::default();
        let mut init_attr = QueuePairInitAttribute::default();
        // Record which fields were requested so the returned attribute
        // carries the mask alongside the filled-in values.
        attr.attr_mask = mask;
        let result = unsafe {
            ibv_query_qp(
                self.qp().as_ptr(),
                &mut attr.attr as *mut _,
                mask.bits(),
                &mut init_attr.init_attr as *mut _,
            )
        };
        match result {
            0 => Ok((attr, init_attr)),
            err => Err(QueryQueuePairErrorKind::Ibverbs(io::Error::from_raw_os_error(err)).into()),
        }
    }

    /// Current state as cached in the `ibv_qp` struct (a field read, not a
    /// device query).
    fn state(&self) -> QueuePairState {
        unsafe { self.qp().as_ref().state.into() }
    }

    /// The queue pair number assigned at creation.
    fn qp_number(&self) -> u32 {
        unsafe { self.qp().as_ref().qp_num }
    }

    /// The guard type used to build and post a batch of send work requests.
    type Guard<'g>: PostSendGuard
    where
        Self: 'g;

    /// Begin a post-send batch; the returned guard collects work requests and
    /// submits them when posted.
    fn start_post_send(&mut self) -> Self::Guard<'_>;

    /// Begin a post-receive batch; work requests accumulate in the guard and
    /// are handed to `ibv_post_recv` when the guard is posted.
    fn start_post_recv(&mut self) -> PostRecvGuard<'_> {
        PostRecvGuard {
            qp: unsafe { self.qp() },
            wrs: Vec::new(),
            sges: Vec::new(),
            _phantom: PhantomData,
        }
    }
}
/// Sealed half of the post-send guard interface: the raw setup hooks live in a
/// private trait so external code can only drive them through the typestate
/// handles ([`WorkRequestHandle`], [`LocalBufferHandle`]).
mod private_traits {
    use std::io::IoSlice;

    use rdma_mummy_sys::ibv_sge;

    pub trait PostSendGuard {
        fn setup_send(&mut self);
        fn setup_send_imm(&mut self, imm_data: u32);
        fn setup_write(&mut self, rkey: u32, remote_addr: u64);
        fn setup_write_imm(&mut self, rkey: u32, remote_addr: u64, imm_data: u32);
        fn setup_read(&mut self, rkey: u32, remote_addr: u64);
        fn setup_inline_data(&mut self, buf: &[u8]);
        fn setup_inline_data_list(&mut self, bufs: &[IoSlice<'_>]);
        // Safety contract: the (lkey, addr, length) tuple / sge entries must
        // describe registered memory that stays valid until completion.
        unsafe fn setup_sge(&mut self, lkey: u32, addr: u64, length: u32);
        unsafe fn setup_sge_list(&mut self, sg_list: &[ibv_sge]);
    }
}

/// Public face of a post-send batch: construct work requests one by one, then
/// consume the guard with [`post`](PostSendGuard::post) to submit them all.
pub trait PostSendGuard: private_traits::PostSendGuard {
    fn construct_wr(&mut self, wr_id: u64, wr_flags: WorkRequestFlags) -> WorkRequestHandle<'_, Self>;
    fn post(self) -> Result<(), PostSendError>;
}
/// Bitmask selecting which fields of a [`QueuePairAttribute`] are valid,
/// mirroring `ibv_qp_attr_mask` (i32 because `ibv_modify_qp` takes int).
#[bitmask(i32)]
#[bitmask_config(vec_debug)]
pub enum QueuePairAttributeMask {
    State = ibv_qp_attr_mask::IBV_QP_STATE.0 as _,
    CurrentState = ibv_qp_attr_mask::IBV_QP_CUR_STATE.0 as _,
    EnableSendQueueDrainedAsyncNotify = ibv_qp_attr_mask::IBV_QP_EN_SQD_ASYNC_NOTIFY.0 as _,
    AccessFlags = ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS.0 as _,
    PartitionKeyIndex = ibv_qp_attr_mask::IBV_QP_PKEY_INDEX.0 as _,
    Port = ibv_qp_attr_mask::IBV_QP_PORT.0 as _,
    QueueKey = ibv_qp_attr_mask::IBV_QP_QKEY.0 as _,
    AddressVector = ibv_qp_attr_mask::IBV_QP_AV.0 as _,
    PathMtu = ibv_qp_attr_mask::IBV_QP_PATH_MTU.0 as _,
    Timeout = ibv_qp_attr_mask::IBV_QP_TIMEOUT.0 as _,
    RetryCount = ibv_qp_attr_mask::IBV_QP_RETRY_CNT.0 as _,
    ResponderNotReadyRetryCount = ibv_qp_attr_mask::IBV_QP_RNR_RETRY.0 as _,
    ReceiveQueuePacketSequenceNumber = ibv_qp_attr_mask::IBV_QP_RQ_PSN.0 as _,
    MaxReadAtomic = ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC.0 as _,
    AlternatePath = ibv_qp_attr_mask::IBV_QP_ALT_PATH.0 as _,
    MinResponderNotReadyTimer = ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER.0 as _,
    SendQueuePacketSequenceNumber = ibv_qp_attr_mask::IBV_QP_SQ_PSN.0 as _,
    MaxDestinationReadAtomic = ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC.0 as _,
    PathMigrationState = ibv_qp_attr_mask::IBV_QP_PATH_MIG_STATE.0 as _,
    Capabilities = ibv_qp_attr_mask::IBV_QP_CAP.0 as _,
    DestinationQueuePairNumber = ibv_qp_attr_mask::IBV_QP_DEST_QPN.0 as _,
    RateLimit = ibv_qp_attr_mask::IBV_QP_RATE_LIMIT.0 as _,
}
/// One cell of the QP state-transition table: whether the transition is legal
/// and which attribute-mask bits it requires / additionally permits.
#[derive(Debug, Copy, Clone)]
struct QueuePairStateTableEntry {
    // Is this (from, to) transition allowed at all?
    valid: bool,
    // Bits that MUST be present in the attribute mask for this transition.
    required_mask: QueuePairAttributeMask,
    // Bits that MAY also be present; anything else is rejected.
    optional_mask: QueuePairAttributeMask,
}
/// Legal RC queue-pair state transitions, indexed as
/// `[current_state][next_state]` for states `Reset..=Error` (`Unknown` is
/// deliberately outside the table). Built lazily on first use.
///
/// NOTE(review): the masks encode the RC rules from the `ibv_modify_qp` state
/// machine; other QP types have different requirements — confirm before using
/// this table for non-RC queue pairs.
static RC_QP_STATE_TABLE: LazyLock<
    [[QueuePairStateTableEntry; QueuePairState::Error as usize + 1]; QueuePairState::Error as usize + 1],
> = LazyLock::new(|| {
    use QueuePairState::*;
    // Start with every transition invalid and empty masks.
    let mut qp_state_table = [[QueuePairStateTableEntry {
        valid: false,
        required_mask: QueuePairAttributeMask { bits: 0 },
        optional_mask: QueuePairAttributeMask { bits: 0 },
    }; Error as usize + 1]; Error as usize + 1];
    // From any state one can always go to Reset or Error with just the State bit.
    let mut state = Reset;
    while state <= Error {
        qp_state_table[state as usize][Reset as usize] = QueuePairStateTableEntry {
            valid: true,
            required_mask: QueuePairAttributeMask::State,
            optional_mask: QueuePairAttributeMask { bits: 0 },
        };
        qp_state_table[state as usize][Error as usize] = QueuePairStateTableEntry {
            valid: true,
            required_mask: QueuePairAttributeMask::State,
            optional_mask: QueuePairAttributeMask { bits: 0 },
        };
        // Advance via the From<u32> conversion; after Error this yields
        // Unknown and the loop condition terminates.
        state = (state as u32 + 1).into()
    }
    // Reset -> Init: first configuration of the QP.
    qp_state_table[Reset as usize][Init as usize] = QueuePairStateTableEntry {
        valid: true,
        required_mask: QueuePairAttributeMask::State
            | QueuePairAttributeMask::PartitionKeyIndex
            | QueuePairAttributeMask::Port
            | QueuePairAttributeMask::AccessFlags,
        optional_mask: QueuePairAttributeMask { bits: 0 },
    };
    // Init -> Init: re-adjust the same attributes without the State bit.
    qp_state_table[Init as usize][Init as usize] = QueuePairStateTableEntry {
        valid: true,
        required_mask: QueuePairAttributeMask { bits: 0 },
        optional_mask: QueuePairAttributeMask::PartitionKeyIndex
            | QueuePairAttributeMask::Port
            | QueuePairAttributeMask::AccessFlags,
    };
    // Init -> RTR: supply the remote/path information.
    qp_state_table[Init as usize][ReadyToReceive as usize] = QueuePairStateTableEntry {
        valid: true,
        required_mask: QueuePairAttributeMask::State
            | QueuePairAttributeMask::AddressVector
            | QueuePairAttributeMask::PathMtu
            | QueuePairAttributeMask::DestinationQueuePairNumber
            | QueuePairAttributeMask::ReceiveQueuePacketSequenceNumber
            | QueuePairAttributeMask::MaxDestinationReadAtomic
            | QueuePairAttributeMask::MinResponderNotReadyTimer,
        optional_mask: QueuePairAttributeMask::PartitionKeyIndex
            | QueuePairAttributeMask::AccessFlags
            | QueuePairAttributeMask::AlternatePath,
    };
    // RTR -> RTS: supply the send-side timers and limits.
    qp_state_table[ReadyToReceive as usize][ReadyToSend as usize] = QueuePairStateTableEntry {
        valid: true,
        required_mask: QueuePairAttributeMask::State
            | QueuePairAttributeMask::SendQueuePacketSequenceNumber
            | QueuePairAttributeMask::Timeout
            | QueuePairAttributeMask::RetryCount
            | QueuePairAttributeMask::ResponderNotReadyRetryCount
            | QueuePairAttributeMask::MaxReadAtomic,
        optional_mask: QueuePairAttributeMask::CurrentState
            | QueuePairAttributeMask::AccessFlags
            | QueuePairAttributeMask::MinResponderNotReadyTimer
            | QueuePairAttributeMask::AlternatePath
            | QueuePairAttributeMask::PathMigrationState,
    };
    // RTS -> RTS: tune a live QP without a state change.
    qp_state_table[ReadyToSend as usize][ReadyToSend as usize] = QueuePairStateTableEntry {
        valid: true,
        required_mask: QueuePairAttributeMask { bits: 0 },
        optional_mask: QueuePairAttributeMask::CurrentState
            | QueuePairAttributeMask::AccessFlags
            | QueuePairAttributeMask::MinResponderNotReadyTimer
            | QueuePairAttributeMask::AlternatePath
            | QueuePairAttributeMask::PathMigrationState,
    };
    // RTS -> SQD: drain the send queue.
    qp_state_table[ReadyToSend as usize][SendQueueDrain as usize] = QueuePairStateTableEntry {
        valid: true,
        required_mask: QueuePairAttributeMask::State,
        optional_mask: QueuePairAttributeMask::EnableSendQueueDrainedAsyncNotify,
    };
    // SQD -> RTS: resume sending.
    qp_state_table[SendQueueDrain as usize][ReadyToSend as usize] = QueuePairStateTableEntry {
        valid: true,
        required_mask: QueuePairAttributeMask::State,
        optional_mask: QueuePairAttributeMask::CurrentState
            | QueuePairAttributeMask::AccessFlags
            | QueuePairAttributeMask::MinResponderNotReadyTimer
            | QueuePairAttributeMask::AlternatePath
            | QueuePairAttributeMask::PathMigrationState,
    };
    // SQD -> SQD: while drained, most path attributes may be reconfigured.
    qp_state_table[SendQueueDrain as usize][SendQueueDrain as usize] = QueuePairStateTableEntry {
        valid: true,
        required_mask: QueuePairAttributeMask { bits: 0 },
        optional_mask: QueuePairAttributeMask::PartitionKeyIndex
            | QueuePairAttributeMask::Port
            | QueuePairAttributeMask::AccessFlags
            | QueuePairAttributeMask::AddressVector
            | QueuePairAttributeMask::MaxReadAtomic
            | QueuePairAttributeMask::MinResponderNotReadyTimer
            | QueuePairAttributeMask::AlternatePath
            | QueuePairAttributeMask::Timeout
            | QueuePairAttributeMask::RetryCount
            | QueuePairAttributeMask::ResponderNotReadyRetryCount
            | QueuePairAttributeMask::MaxDestinationReadAtomic
            | QueuePairAttributeMask::PathMigrationState,
    };
    qp_state_table
});
/// A queue pair created with `ibv_create_qp`, driven through the classic
/// `ibv_post_send` / `ibv_post_recv` work-request-list interface.
pub struct BasicQueuePair {
    pub(crate) qp: NonNull<ibv_qp>,
    // Held only to keep the protection domain and both completion queues
    // alive for at least as long as this queue pair.
    _pd: Arc<ProtectionDomain>,
    _send_cq: GenericCompletionQueue,
    _recv_cq: GenericCompletionQueue,
}

// SAFETY: the raw pointer is owned exclusively by this wrapper. NOTE(review):
// this assumes libibverbs QP handles tolerate being used/dropped from another
// thread — confirm against the verbs threading rules.
unsafe impl Send for BasicQueuePair {}
unsafe impl Sync for BasicQueuePair {}

impl Drop for BasicQueuePair {
    fn drop(&mut self) {
        // SAFETY: `self.qp` was returned non-null by `ibv_create_qp` and is
        // destroyed exactly once, here.
        let ret = unsafe { ibv_destroy_qp(self.qp.as_ptr()) };
        // Failure to destroy an owned QP would leak device resources; treat
        // it as a bug (note: this panics inside drop).
        assert_eq!(ret, 0);
    }
}
impl QueuePair for BasicQueuePair {
    unsafe fn qp(&self) -> NonNull<ibv_qp> {
        self.qp
    }

    type Guard<'g>
        = BasicPostSendGuard<'g>
    where
        Self: 'g;

    /// Open a post-send batch. The guard starts with empty buffers; they grow
    /// as work requests are constructed and are submitted on `post()`.
    fn start_post_send(&mut self) -> Self::Guard<'_> {
        let guard = BasicPostSendGuard {
            qp: self.qp,
            // Nothing is reserved up front; Vec::new() allocates lazily.
            wrs: Vec::new(),
            sges: Vec::new(),
            inline_buffers: Vec::new(),
            _phantom: PhantomData,
        };
        guard
    }
}
impl fmt::Debug for BasicQueuePair {
    // Only the raw QP pointer is shown; the PD/CQ fields are internal keep-alives.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BasicQueuePair").field("qp", &self.qp).finish()
    }
}
/// A queue pair created with `ibv_create_qp_ex`, driven through the extended
/// `ibv_wr_*` work-request API.
pub struct ExtendedQueuePair {
    pub(crate) qp_ex: NonNull<ibv_qp_ex>,
    // Held only to keep the protection domain and both completion queues
    // alive for at least as long as this queue pair.
    _pd: Arc<ProtectionDomain>,
    _send_cq: GenericCompletionQueue,
    _recv_cq: GenericCompletionQueue,
}

// SAFETY: the raw pointer is owned exclusively by this wrapper. NOTE(review):
// same cross-thread assumption as BasicQueuePair — confirm against the verbs
// threading rules.
unsafe impl Send for ExtendedQueuePair {}
unsafe impl Sync for ExtendedQueuePair {}

impl Drop for ExtendedQueuePair {
    fn drop(&mut self) {
        // SAFETY: `qp()` derives the base ibv_qp from our valid qp_ex; it is
        // destroyed exactly once, here.
        let ret = unsafe { ibv_destroy_qp(self.qp().as_ptr()) };
        // Same policy as BasicQueuePair: destruction failure is a bug.
        assert_eq!(ret, 0)
    }
}
impl QueuePair for ExtendedQueuePair {
    unsafe fn qp(&self) -> NonNull<ibv_qp> {
        // `ibv_qp_ex` embeds the base `ibv_qp` as its `qp_base` field; hand
        // out a pointer to that embedded struct. Non-null because qp_ex is.
        NonNull::new_unchecked(&mut (*self.qp_ex.as_ptr()).qp_base as _)
    }

    type Guard<'g>
        = ExtendedPostSendGuard<'g>
    where
        Self: 'g;

    /// Open a post-send batch: `ibv_wr_start` begins the work-request session
    /// that the returned guard later completes (or aborts).
    fn start_post_send(&mut self) -> Self::Guard<'_> {
        unsafe {
            ibv_wr_start(self.qp().as_ptr() as _);
        }
        ExtendedPostSendGuard {
            qp_ex: self.qp_ex,
            _phantom: PhantomData,
        }
    }
}
impl fmt::Debug for ExtendedQueuePair {
    // Only the raw extended-QP pointer is shown; keep-alive fields are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ExtendedQueuePair").field("qp_ex", &self.qp_ex).finish()
    }
}
/// Builder for [`BasicQueuePair`] / [`ExtendedQueuePair`].
///
/// All settings accumulate in an `ibv_qp_init_attr_ex`; the send and receive
/// completion queues are kept separately so ownership can be moved into the
/// built queue pair.
pub struct QueuePairBuilder {
    init_attr: ibv_qp_init_attr_ex,
    pd: Arc<ProtectionDomain>,
    send_cq: Option<GenericCompletionQueue>,
    recv_cq: Option<GenericCompletionQueue>,
}
impl QueuePairBuilder {
pub fn new(pd: &Arc<ProtectionDomain>) -> QueuePairBuilder {
QueuePairBuilder {
init_attr: ibv_qp_init_attr_ex {
qp_context: null_mut(),
send_cq: null_mut(),
recv_cq: null_mut(),
srq: null_mut(),
cap: ibv_qp_cap {
max_send_wr: 16,
max_recv_wr: 16,
max_send_sge: 1,
max_recv_sge: 1,
max_inline_data: 0,
},
qp_type: QueuePairType::ReliableConnection as _,
sq_sig_all: 0,
comp_mask: ibv_qp_init_attr_mask::IBV_QP_INIT_ATTR_PD.0
| ibv_qp_init_attr_mask::IBV_QP_INIT_ATTR_SEND_OPS_FLAGS.0,
pd: pd.pd.as_ptr(),
xrcd: null_mut(),
create_flags: 0,
max_tso_header: 0,
rwq_ind_tbl: null_mut(),
rx_hash_conf: unsafe { MaybeUninit::<ibv_rx_hash_conf>::zeroed().assume_init() },
source_qpn: 0,
send_ops_flags: (SendOperationFlags::Send
| SendOperationFlags::SendWithImmediate
| SendOperationFlags::Write
| SendOperationFlags::WriteWithImmediate
| SendOperationFlags::Read)
.into(),
},
pd: Arc::clone(pd),
send_cq: None,
recv_cq: None,
}
}
pub fn setup_qp_type(&mut self, qp_type: QueuePairType) -> &mut Self {
self.init_attr.qp_type = qp_type as u32;
self
}
pub fn setup_max_send_wr(&mut self, max_send_wr: u32) -> &mut Self {
self.init_attr.cap.max_send_wr = max_send_wr;
self
}
pub fn setup_max_recv_wr(&mut self, max_recv_wr: u32) -> &mut Self {
self.init_attr.cap.max_recv_wr = max_recv_wr;
self
}
pub fn setup_max_send_sge(&mut self, max_send_sge: u32) -> &mut Self {
self.init_attr.cap.max_send_sge = max_send_sge;
self
}
pub fn setup_max_recv_sge(&mut self, max_recv_sge: u32) -> &mut Self {
self.init_attr.cap.max_recv_sge = max_recv_sge;
self
}
pub fn setup_max_inline_data(&mut self, max_inline_data: u32) -> &mut Self {
self.init_attr.cap.max_inline_data = max_inline_data;
self
}
pub fn setup_send_cq<C>(&mut self, send_cq: C) -> &mut Self
where
C: Into<GenericCompletionQueue>,
{
let cq = send_cq.into();
unsafe {
self.init_attr.send_cq = cq.cq().as_ptr();
}
self.send_cq = Some(cq);
self
}
pub fn setup_recv_cq<C>(&mut self, recv_cq: C) -> &mut Self
where
C: Into<GenericCompletionQueue>,
{
let cq = recv_cq.into();
unsafe {
self.init_attr.recv_cq = cq.cq().as_ptr();
}
self.recv_cq = Some(cq);
self
}
pub fn setup_send_ops_flags(&mut self, send_ops_flags: SendOperationFlags) -> &mut Self {
self.init_attr.send_ops_flags = send_ops_flags.bits;
self
}
pub fn build(&self) -> Result<BasicQueuePair, CreateQueuePairError> {
let send_cq = self
.send_cq
.as_ref()
.cloned()
.expect("send completion queue must be configured before building a QueuePair");
let recv_cq = self
.recv_cq
.as_ref()
.cloned()
.expect("receive completion queue must be configured before building a QueuePair");
let qp = unsafe {
ibv_create_qp(
self.init_attr.pd,
&mut ibv_qp_init_attr {
qp_context: null_mut(),
send_cq: self.init_attr.send_cq,
recv_cq: self.init_attr.recv_cq,
srq: null_mut(),
cap: self.init_attr.cap,
qp_type: QueuePairType::ReliableConnection as _,
sq_sig_all: 0,
},
)
};
Ok(BasicQueuePair {
qp: NonNull::new(qp)
.ok_or::<CreateQueuePairError>(CreateQueuePairErrorKind::Ibverbs(io::Error::last_os_error()).into())?,
_pd: Arc::clone(&self.pd),
_send_cq: send_cq,
_recv_cq: recv_cq,
})
}
pub fn build_ex(&self) -> Result<ExtendedQueuePair, CreateQueuePairError> {
let send_cq = self
.send_cq
.as_ref()
.cloned()
.expect("send completion queue must be configured before building a QueuePair");
let recv_cq = self
.recv_cq
.as_ref()
.cloned()
.expect("receive completion queue must be configured before building a QueuePair");
let mut attr = self.init_attr;
let qp = unsafe { ibv_create_qp_ex((*(attr.pd)).context, &mut attr) };
Ok(ExtendedQueuePair {
qp_ex: NonNull::new(unsafe { ibv_qp_to_qp_ex(qp) })
.ok_or::<CreateQueuePairError>(CreateQueuePairErrorKind::Ibverbs(io::Error::last_os_error()).into())?,
_pd: Arc::clone(&self.pd),
_send_cq: send_cq,
_recv_cq: recv_cq,
})
}
}
/// Typed wrapper around `ibv_qp_attr` plus the mask recording which of its
/// fields have been set (or were requested in a query).
pub struct QueuePairAttribute {
    attr: ibv_qp_attr,
    attr_mask: QueuePairAttributeMask,
}

impl Default for QueuePairAttribute {
    fn default() -> Self {
        Self::new()
    }
}
impl QueuePairAttribute {
    /// An attribute set with every field zeroed and an empty mask.
    pub fn new() -> Self {
        QueuePairAttribute {
            // ibv_qp_attr is plain C data; all-zero is a valid starting point.
            attr: unsafe { MaybeUninit::zeroed().assume_init() },
            attr_mask: QueuePairAttributeMask { bits: 0 },
        }
    }

    /// Wrap an existing raw `ibv_qp_attr` together with its raw mask bits.
    pub fn from(attr: &ibv_qp_attr, attr_mask: i32) -> Self {
        QueuePairAttribute {
            attr: ibv_qp_attr { ..*attr },
            attr_mask: QueuePairAttributeMask { bits: attr_mask },
        }
    }

    // Each setup_* method below writes one field of the raw attribute struct
    // and records the corresponding bit in the mask; the paired getter reads
    // the field back (meaningful only if the bit is set).

    /// Set the target state for the next `modify` call.
    pub fn setup_state(&mut self, state: QueuePairState) -> &mut Self {
        self.attr.qp_state = state as _;
        self.attr_mask |= QueuePairAttributeMask::State;
        self
    }

    pub fn state(&self) -> QueuePairState {
        self.attr.qp_state.into()
    }

    /// Set the partition key index.
    pub fn setup_pkey_index(&mut self, pkey_index: u16) -> &mut Self {
        self.attr.pkey_index = pkey_index;
        self.attr_mask |= QueuePairAttributeMask::PartitionKeyIndex;
        self
    }

    pub fn pkey_index(&self) -> u16 {
        self.attr.pkey_index
    }

    /// Set the physical port number the QP is associated with.
    pub fn setup_port(&mut self, port_num: u8) -> &mut Self {
        self.attr.port_num = port_num;
        self.attr_mask |= QueuePairAttributeMask::Port;
        self
    }

    pub fn port(&self) -> u8 {
        self.attr.port_num
    }

    /// Set the remote access flags (what peers may do to local memory).
    pub fn setup_access_flags(&mut self, access_flags: AccessFlags) -> &mut Self {
        self.attr.qp_access_flags = access_flags.bits as _;
        self.attr_mask |= QueuePairAttributeMask::AccessFlags;
        self
    }

    pub fn access_flags(&self) -> AccessFlags {
        AccessFlags::from(self.attr.qp_access_flags as i32)
    }

    /// Set the path MTU.
    pub fn setup_path_mtu(&mut self, path_mtu: Mtu) -> &mut Self {
        self.attr.path_mtu = path_mtu as _;
        self.attr_mask |= QueuePairAttributeMask::PathMtu;
        self
    }

    pub fn path_mtu(&self) -> Mtu {
        self.attr.path_mtu.into()
    }

    /// Set the remote queue pair number to connect to.
    pub fn setup_dest_qp_num(&mut self, dest_qp_num: u32) -> &mut Self {
        self.attr.dest_qp_num = dest_qp_num;
        self.attr_mask |= QueuePairAttributeMask::DestinationQueuePairNumber;
        self
    }

    pub fn dest_qp_num(&self) -> u32 {
        self.attr.dest_qp_num
    }

    /// Set the starting packet sequence number for the receive queue.
    pub fn setup_rq_psn(&mut self, rq_psn: u32) -> &mut Self {
        self.attr.rq_psn = rq_psn;
        self.attr_mask |= QueuePairAttributeMask::ReceiveQueuePacketSequenceNumber;
        self
    }

    pub fn rq_psn(&self) -> u32 {
        self.attr.rq_psn
    }

    /// Set the starting packet sequence number for the send queue.
    pub fn setup_sq_psn(&mut self, sq_psn: u32) -> &mut Self {
        self.attr.sq_psn = sq_psn;
        self.attr_mask |= QueuePairAttributeMask::SendQueuePacketSequenceNumber;
        self
    }

    pub fn sq_psn(&self) -> u32 {
        self.attr.sq_psn
    }

    /// Set the maximum outstanding outgoing RDMA read/atomic operations.
    pub fn setup_max_read_atomic(&mut self, max_read_atomic: u8) -> &mut Self {
        self.attr.max_rd_atomic = max_read_atomic;
        self.attr_mask |= QueuePairAttributeMask::MaxReadAtomic;
        self
    }

    pub fn max_read_atomic(&self) -> u8 {
        self.attr.max_rd_atomic
    }

    /// Set the maximum incoming RDMA read/atomic operations this QP accepts.
    pub fn setup_max_dest_read_atomic(&mut self, max_dest_read_atomic: u8) -> &mut Self {
        self.attr.max_dest_rd_atomic = max_dest_read_atomic;
        self.attr_mask |= QueuePairAttributeMask::MaxDestinationReadAtomic;
        self
    }

    pub fn max_dest_read_atomic(&self) -> u8 {
        self.attr.max_dest_rd_atomic
    }

    /// Set the minimum RNR NAK timer.
    pub fn setup_min_rnr_timer(&mut self, min_rnr_timer: u8) -> &mut Self {
        self.attr.min_rnr_timer = min_rnr_timer;
        self.attr_mask |= QueuePairAttributeMask::MinResponderNotReadyTimer;
        self
    }

    pub fn min_rnr_timer(&self) -> u8 {
        self.attr.min_rnr_timer
    }

    /// Set the local ACK timeout exponent.
    pub fn setup_timeout(&mut self, timeout: u8) -> &mut Self {
        self.attr.timeout = timeout;
        self.attr_mask |= QueuePairAttributeMask::Timeout;
        self
    }

    pub fn timeout(&self) -> u8 {
        self.attr.timeout
    }

    /// Set the transport retry count.
    pub fn setup_retry_cnt(&mut self, retry_cnt: u8) -> &mut Self {
        self.attr.retry_cnt = retry_cnt;
        self.attr_mask |= QueuePairAttributeMask::RetryCount;
        self
    }

    pub fn retry_cnt(&self) -> u8 {
        self.attr.retry_cnt
    }

    /// Set the RNR retry count.
    pub fn setup_rnr_retry(&mut self, rnr_retry: u8) -> &mut Self {
        self.attr.rnr_retry = rnr_retry;
        self.attr_mask |= QueuePairAttributeMask::ResponderNotReadyRetryCount;
        self
    }

    pub fn rnr_retry(&self) -> u8 {
        self.attr.rnr_retry
    }

    /// Set the primary path address vector (copies the raw `ibv_ah_attr`).
    pub fn setup_address_vector(&mut self, ah_attr: &AddressHandleAttribute) -> &mut Self {
        self.attr.ah_attr = ah_attr.attr;
        self.attr_mask |= QueuePairAttributeMask::AddressVector;
        self
    }
}
/// Read-only view of the creation-time attributes returned by `ibv_query_qp`.
pub struct QueuePairInitAttribute {
    init_attr: ibv_qp_init_attr,
}

impl Default for QueuePairInitAttribute {
    fn default() -> Self {
        Self::new()
    }
}

impl QueuePairInitAttribute {
    /// A zeroed init-attribute struct, ready to be filled by `ibv_query_qp`.
    pub fn new() -> Self {
        QueuePairInitAttribute {
            // ibv_qp_init_attr is plain C data; all-zero is valid.
            init_attr: unsafe { MaybeUninit::zeroed().assume_init() },
        }
    }

    /// Maximum outstanding send work requests.
    pub fn max_send_wr(&self) -> u32 {
        self.init_attr.cap.max_send_wr
    }

    /// Maximum outstanding receive work requests.
    pub fn max_recv_wr(&self) -> u32 {
        self.init_attr.cap.max_recv_wr
    }

    /// Maximum scatter-gather entries per send work request.
    pub fn max_send_sge(&self) -> u32 {
        self.init_attr.cap.max_send_sge
    }

    /// Maximum scatter-gather entries per receive work request.
    pub fn max_recv_sge(&self) -> u32 {
        self.init_attr.cap.max_recv_sge
    }

    /// Maximum inline payload size.
    pub fn max_inline_data(&self) -> u32 {
        self.init_attr.cap.max_inline_data
    }
}
/// Bits that the transition requires but `cur_mask` does not supply.
#[inline]
fn get_needed_mask(cur_mask: QueuePairAttributeMask, required_mask: QueuePairAttributeMask) -> QueuePairAttributeMask {
    // A bit is "needed" when it is set in required_mask but clear in cur_mask.
    required_mask & (cur_mask ^ required_mask)
}

/// Bits supplied in `cur_mask` that the transition neither requires nor permits.
#[inline]
fn get_invalid_mask(
    cur_mask: QueuePairAttributeMask, required_mask: QueuePairAttributeMask, optional_mask: QueuePairAttributeMask,
) -> QueuePairAttributeMask {
    // Anything outside the union of required and optional bits is invalid.
    cur_mask & !(required_mask | optional_mask)
}
/// Validate `attr_mask` against the RC state-machine table for the
/// `cur_state` → `next_state` transition.
///
/// Returns `InvalidTransition` when the transition itself is illegal, and
/// `InvalidAttributeMask` (listing the superfluous and missing bits) when the
/// transition is legal but the mask is wrong.
fn attr_mask_check(
    attr_mask: QueuePairAttributeMask, cur_state: QueuePairState, next_state: QueuePairState,
) -> Result<(), ModifyQueuePairError> {
    let entry = RC_QP_STATE_TABLE[cur_state as usize][next_state as usize];
    if !entry.valid {
        return Err(ModifyQueuePairErrorKind::InvalidTransition {
            cur_state,
            next_state,
            source: io::Error::from_raw_os_error(libc::EINVAL),
        }
        .into());
    }
    let invalid = get_invalid_mask(attr_mask, entry.required_mask, entry.optional_mask);
    let needed = get_needed_mask(attr_mask, entry.required_mask);
    if invalid.bits != 0 || needed.bits != 0 {
        return Err(ModifyQueuePairErrorKind::InvalidAttributeMask {
            cur_state,
            next_state,
            invalid,
            needed,
            source: io::Error::from_raw_os_error(libc::EINVAL),
        }
        .into());
    }
    Ok(())
}
/// Typestate handle for a freshly constructed work request: the caller must
/// next choose an operation (send/write/read), which yields a
/// [`LocalBufferHandle`] for attaching the local payload.
pub struct WorkRequestHandle<'g, G: PostSendGuard + ?Sized> {
    guard: &'g mut G,
}

/// Attach local registered memory (one SGE or a list) to a work request.
pub trait SetScatterGatherEntry {
    /// # Safety
    /// The described memory must be registered and stay valid until the work
    /// request completes.
    unsafe fn setup_sge(self, lkey: u32, addr: u64, length: u32);
    /// # Safety
    /// Same contract as [`setup_sge`](Self::setup_sge) for every entry.
    unsafe fn setup_sge_list(self, sg_list: &[ibv_sge]);
}

/// Attach an inline payload (copied into the work queue entry) to a work request.
pub trait SetInlineData {
    fn setup_inline_data(self, buf: &[u8]);
    fn setup_inline_data_list(self, bufs: &[IoSlice<'_>]);
}

/// Typestate handle for choosing how the local payload is supplied — either
/// inline data or scatter-gather entries.
pub struct LocalBufferHandle<'g, G: PostSendGuard> {
    guard: &'g mut G,
}
// Both impls simply forward to the guard, which owns the actual
// work-request bookkeeping for its QP flavor.
impl<G: PostSendGuard> SetInlineData for LocalBufferHandle<'_, G> {
    fn setup_inline_data(self, buf: &[u8]) {
        self.guard.setup_inline_data(buf);
    }

    fn setup_inline_data_list(self, bufs: &[IoSlice<'_>]) {
        self.guard.setup_inline_data_list(bufs);
    }
}

impl<G: PostSendGuard> SetScatterGatherEntry for LocalBufferHandle<'_, G> {
    unsafe fn setup_sge(self, lkey: u32, addr: u64, length: u32) {
        self.guard.setup_sge(lkey, addr, length);
    }

    unsafe fn setup_sge_list(self, sg_list: &[ibv_sge]) {
        self.guard.setup_sge_list(sg_list);
    }
}
impl<'g, G: PostSendGuard> WorkRequestHandle<'g, G> {
    // Each method selects the opcode (and remote parameters, where relevant)
    // on the guard's current work request, then converts this handle into a
    // LocalBufferHandle so the caller can attach the local payload.

    /// Configure a plain send.
    pub fn setup_send(self) -> LocalBufferHandle<'g, G> {
        self.guard.setup_send();
        LocalBufferHandle { guard: self.guard }
    }

    /// Configure a send carrying 32-bit immediate data.
    pub fn setup_send_imm(self, imm_data: u32) -> LocalBufferHandle<'g, G> {
        self.guard.setup_send_imm(imm_data);
        LocalBufferHandle { guard: self.guard }
    }

    /// Configure an RDMA write to `remote_addr` under `rkey`.
    pub fn setup_write(self, rkey: u32, remote_addr: u64) -> LocalBufferHandle<'g, G> {
        self.guard.setup_write(rkey, remote_addr);
        LocalBufferHandle { guard: self.guard }
    }

    /// Configure an RDMA write with immediate data.
    pub fn setup_write_imm(self, rkey: u32, remote_addr: u64, imm_data: u32) -> LocalBufferHandle<'g, G> {
        self.guard.setup_write_imm(rkey, remote_addr, imm_data);
        LocalBufferHandle { guard: self.guard }
    }

    /// Configure an RDMA read from `remote_addr` under `rkey`.
    pub fn setup_read(self, rkey: u32, remote_addr: u64) -> LocalBufferHandle<'g, G> {
        self.guard.setup_read(rkey, remote_addr);
        LocalBufferHandle { guard: self.guard }
    }
}
/// Post-send guard for [`BasicQueuePair`]: accumulates `ibv_send_wr` /
/// `ibv_sge` arrays (and owned copies of inline payloads) and submits them in
/// a single `ibv_post_send` call.
pub struct BasicPostSendGuard<'g> {
    qp: NonNull<ibv_qp>,
    wrs: Vec<ibv_send_wr>,
    sges: Vec<ibv_sge>,
    // Owned copies of inline payloads so their pointers stay valid until post().
    inline_buffers: Vec<Vec<u8>>,
    // Borrows the queue pair for 'g so it cannot be dropped mid-batch.
    _phantom: PhantomData<&'g ()>,
}
impl PostSendGuard for BasicPostSendGuard<'_> {
    /// Append a zeroed work request with the given id and flags; the opcode
    /// and payload are filled in through the returned handle.
    fn construct_wr(&mut self, wr_id: u64, wr_flags: WorkRequestFlags) -> WorkRequestHandle<'_, Self> {
        self.wrs.push(ibv_send_wr {
            wr_id,
            next: null_mut(),
            sg_list: null_mut(),
            num_sge: 0,
            opcode: 0,
            send_flags: wr_flags.bits,
            ..unsafe { MaybeUninit::zeroed().assume_init() }
        });
        WorkRequestHandle { guard: self }
    }

    /// Link the accumulated work requests into a chain and submit them with a
    /// single `ibv_post_send` call, mapping its errno-style return value onto
    /// [`PostSendError`] variants.
    fn post(mut self) -> Result<(), PostSendError> {
        // Fix: an empty batch must not reach ibv_post_send — as_mut_ptr() on
        // an empty Vec is a dangling (never-allocated) pointer that the C
        // side would dereference.
        if self.wrs.is_empty() {
            return Ok(());
        }
        // Wire up next/sg_list pointers only now: the Vecs no longer
        // reallocate, so the addresses are stable for the duration of the call.
        let mut sge_index = 0;
        for i in 0..self.wrs.len() {
            if i < self.wrs.len() - 1 {
                self.wrs[i].next = &mut self.wrs[i + 1] as *mut _;
            } else {
                self.wrs[i].next = null_mut();
            }
            if self.wrs[i].num_sge > 0 {
                // Each WR consumed a contiguous run of num_sge entries.
                self.wrs[i].sg_list = &mut self.sges[sge_index] as *mut _;
                sge_index += self.wrs[i].num_sge as usize;
            }
        }
        let mut bad_wr: *mut ibv_send_wr = null_mut();
        let ret = unsafe { ibv_post_send(self.qp.as_ptr(), self.wrs.as_mut_ptr(), &mut bad_wr) };
        match ret {
            0 => Ok(()),
            libc::EINVAL => Err(PostSendError::InvalidWorkRequest(io::Error::from_raw_os_error(
                libc::EINVAL,
            ))),
            libc::ENOMEM => Err(PostSendError::NotEnoughResources(io::Error::from_raw_os_error(
                libc::ENOMEM,
            ))),
            libc::EFAULT => Err(PostSendError::InvalidQueuePair(io::Error::from_raw_os_error(
                libc::EFAULT,
            ))),
            err => Err(PostSendError::Ibverbs(io::Error::from_raw_os_error(err))),
        }
    }
}
impl BasicPostSendGuard<'_> {
    /// Take ownership of an inline payload, keep it alive until `post`, and
    /// register a zero-lkey SGE over it for the current work request.
    fn push_inline_buffer(&mut self, data: Vec<u8>) {
        self.inline_buffers.push(data);
        // The Vec<u8>'s heap allocation is stable even if `inline_buffers`
        // itself reallocates, so the recorded address remains valid.
        let buf = self.inline_buffers.last().expect("just pushed");
        self.sges.push(ibv_sge {
            addr: buf.as_ptr() as u64,
            length: buf.len() as u32,
            lkey: 0, // inline payloads are not read through a memory region
        });
        let wr = self.wrs.last_mut().expect("construct_wr must be called first");
        wr.send_flags |= WorkRequestFlags::Inline.bits;
        wr.num_sge += 1;
    }
}

/// Opcode/payload setters for the basic guard: each call mutates the most
/// recently constructed `ibv_send_wr` in place.
impl private_traits::PostSendGuard for BasicPostSendGuard<'_> {
    fn setup_send(&mut self) {
        self.wrs.last_mut().unwrap().opcode = WorkRequestOperationType::Send as _;
    }
    fn setup_send_imm(&mut self, imm_data: u32) {
        let wr = self.wrs.last_mut().unwrap();
        wr.opcode = WorkRequestOperationType::SendWithImmediate as _;
        wr.imm_data_invalidated_rkey_union.imm_data = imm_data;
    }
    fn setup_write(&mut self, rkey: u32, remote_addr: u64) {
        let wr = self.wrs.last_mut().unwrap();
        wr.opcode = WorkRequestOperationType::Write as _;
        wr.wr.rdma.remote_addr = remote_addr;
        wr.wr.rdma.rkey = rkey;
    }
    fn setup_write_imm(&mut self, rkey: u32, remote_addr: u64, imm_data: u32) {
        let wr = self.wrs.last_mut().unwrap();
        wr.opcode = WorkRequestOperationType::WriteWithImmediate as _;
        wr.wr.rdma.remote_addr = remote_addr;
        wr.wr.rdma.rkey = rkey;
        wr.imm_data_invalidated_rkey_union.imm_data = imm_data;
    }
    fn setup_read(&mut self, rkey: u32, remote_addr: u64) {
        let wr = self.wrs.last_mut().unwrap();
        wr.opcode = WorkRequestOperationType::Read as _;
        wr.wr.rdma.remote_addr = remote_addr;
        wr.wr.rdma.rkey = rkey;
    }
    fn setup_inline_data(&mut self, buf: &[u8]) {
        self.push_inline_buffer(buf.to_vec());
    }
    fn setup_inline_data_list(&mut self, bufs: &[IoSlice<'_>]) {
        // Coalesce all slices into one owned buffer, sized up front. The
        // previous implementation copied each slice twice
        // (`slice.to_vec().clone()`) and grew the result incrementally.
        let mut data = Vec::with_capacity(bufs.iter().map(|s| s.len()).sum());
        for slice in bufs {
            data.extend_from_slice(slice);
        }
        self.push_inline_buffer(data);
    }
    /// # Safety
    /// `(addr, length)` must describe memory registered under `lkey` that
    /// stays valid until the send completes; a WR must have been constructed.
    unsafe fn setup_sge(&mut self, lkey: u32, addr: u64, length: u32) {
        self.sges.push(ibv_sge { addr, length, lkey });
        // Reachable only through a handle returned by `construct_wr`, so a
        // WR is guaranteed to exist.
        self.wrs.last_mut().unwrap_unchecked().num_sge = 1;
    }
    /// # Safety
    /// Every entry of `sg_list` must reference registered memory that stays
    /// valid until the send completes; a WR must have been constructed.
    unsafe fn setup_sge_list(&mut self, sg_list: &[ibv_sge]) {
        self.sges.extend_from_slice(sg_list);
        self.wrs.last_mut().unwrap_unchecked().num_sge = sg_list.len() as _;
    }
}
/// Send-batch guard for extended queue pairs (`ibv_qp_ex`), which use the
/// `ibv_wr_*` work-request API instead of `ibv_post_send`.
///
/// NOTE(review): assumes `ibv_wr_start` was invoked when the guard was
/// created — confirm at the construction site. Dropping the guard without
/// posting aborts the batch (see the `Drop` impl).
pub struct ExtendedPostSendGuard<'qp> {
    // Extended queue pair the batch is being built on; `'qp` ties the
    // guard's lifetime to it.
    qp_ex: NonNull<ibv_qp_ex>,
    _phantom: PhantomData<&'qp ()>,
}
impl PostSendGuard for ExtendedPostSendGuard<'_> {
    /// Begin a new work request by storing its id and flags directly on the
    /// `ibv_qp_ex`; the opcode is chosen through the returned handle.
    fn construct_wr(&mut self, wr_id: u64, wr_flags: WorkRequestFlags) -> WorkRequestHandle<'_, Self> {
        // SAFETY: `qp_ex` is valid for the guard's lifetime; these are plain
        // field stores consumed by the subsequent `ibv_wr_*` call.
        unsafe {
            self.qp_ex.as_mut().wr_id = wr_id;
            self.qp_ex.as_mut().wr_flags = wr_flags.bits;
        }
        WorkRequestHandle { guard: self }
    }
    /// Complete the batch with `ibv_wr_complete`.
    ///
    /// The `mem::forget` call MUST stay between the FFI call and the return:
    /// the `Drop` impl calls `ibv_wr_abort`, which would otherwise cancel the
    /// session just completed. `ret` is captured first, so forgetting `self`
    /// loses nothing.
    fn post(self) -> Result<(), PostSendError> {
        let ret: i32 = unsafe { ibv_wr_complete(self.qp_ex.as_ptr()) };
        std::mem::forget(self);
        // Map the errno values the verbs layer reports to dedicated variants.
        match ret {
            0 => Ok(()),
            libc::EINVAL => Err(PostSendError::InvalidWorkRequest(io::Error::from_raw_os_error(
                libc::EINVAL,
            ))),
            libc::ENOMEM => Err(PostSendError::NotEnoughResources(io::Error::from_raw_os_error(
                libc::ENOMEM,
            ))),
            libc::EFAULT => Err(PostSendError::InvalidQueuePair(io::Error::from_raw_os_error(
                libc::EFAULT,
            ))),
            err => Err(PostSendError::Ibverbs(io::Error::from_raw_os_error(err))),
        }
    }
}
/// Opcode/payload setters for the extended guard: each call forwards straight
/// to the corresponding `ibv_wr_*` verb on the open work-request session.
impl private_traits::PostSendGuard for ExtendedPostSendGuard<'_> {
    fn setup_send(&mut self) {
        // SAFETY (applies to every call below): `qp_ex` is valid for the
        // guard's lifetime and a WR was begun by `construct_wr`.
        unsafe { ibv_wr_send(self.qp_ex.as_ptr()) };
    }
    fn setup_send_imm(&mut self, imm_data: u32) {
        unsafe { ibv_wr_send_imm(self.qp_ex.as_ptr(), imm_data) };
    }
    fn setup_write(&mut self, rkey: u32, remote_addr: u64) {
        unsafe { ibv_wr_rdma_write(self.qp_ex.as_ptr(), rkey, remote_addr) };
    }
    fn setup_write_imm(&mut self, rkey: u32, remote_addr: u64, imm_data: u32) {
        unsafe { ibv_wr_rdma_write_imm(self.qp_ex.as_ptr(), rkey, remote_addr, imm_data) };
    }
    fn setup_read(&mut self, rkey: u32, remote_addr: u64) {
        unsafe { ibv_wr_rdma_read(self.qp_ex.as_ptr(), rkey, remote_addr) };
    }
    fn setup_inline_data(&mut self, buf: &[u8]) {
        // NOTE(review): assumes the provider copies inline data during this
        // call (per the ibv_wr API contract), so `buf` need not outlive it —
        // confirm against the rdma-core docs.
        unsafe { ibv_wr_set_inline_data(self.qp_ex.as_ptr(), buf.as_ptr() as _, buf.len()) }
    }
    fn setup_inline_data_list(&mut self, bufs: &[IoSlice<'_>]) {
        // Translate the IoSlices into the `ibv_data_buf` array the C API
        // expects; `buf_list` only needs to live for the duration of the call.
        let mut buf_list = Vec::with_capacity(bufs.len());
        buf_list.extend(bufs.iter().map(|x| ibv_data_buf {
            addr: x.as_ptr() as _,
            length: x.len(),
        }));
        unsafe { ibv_wr_set_inline_data_list(self.qp_ex.as_ptr(), buf_list.len(), buf_list.as_ptr()) };
    }
    /// # Safety
    /// `(addr, length)` must describe memory registered under `lkey` and
    /// valid until the send completes.
    unsafe fn setup_sge(&mut self, lkey: u32, addr: u64, length: u32) {
        ibv_wr_set_sge(self.qp_ex.as_ptr(), lkey, addr, length);
    }
    /// # Safety
    /// Every entry of `sg_list` must reference registered memory valid until
    /// the send completes.
    unsafe fn setup_sge_list(&mut self, sg_list: &[ibv_sge]) {
        ibv_wr_set_sge_list(self.qp_ex.as_ptr(), sg_list.len(), sg_list.as_ptr());
    }
}
impl Drop for ExtendedPostSendGuard<'_> {
    /// Abort the in-flight work-request session when the guard is dropped
    /// without posting. `post` calls `mem::forget(self)` precisely so this
    /// does not run after a successful `ibv_wr_complete`.
    fn drop(&mut self) {
        // SAFETY: `qp_ex` is still valid; only un-posted batches reach here.
        unsafe { ibv_wr_abort(self.qp_ex.as_ptr()) };
    }
}
/// Receive-batch guard: buffers receive work requests and their SGEs, then
/// submits them all in a single `ibv_post_recv` call from `post`.
pub struct PostRecvGuard<'qp> {
    // Underlying queue pair; `'qp` keeps the guard from outliving it.
    qp: NonNull<ibv_qp>,
    // Pending receive WRs; `next` / `sg_list` pointers are fixed up in `post`.
    wrs: Vec<ibv_recv_wr>,
    // Flat SGE storage; each WR consumes a contiguous run of `num_sge` entries.
    sges: Vec<ibv_sge>,
    _phantom: PhantomData<&'qp ()>,
}
impl<'qp> PostRecvGuard<'qp> {
    /// Append a fresh receive work request and hand back a handle for
    /// attaching its scatter/gather entries.
    pub fn construct_wr<'g>(&'g mut self, wr_id: u64) -> RecvWorkRequestHandle<'g, 'qp> {
        self.wrs.push(ibv_recv_wr {
            wr_id,
            next: null_mut(),
            sg_list: null_mut(),
            num_sge: 0,
        });
        RecvWorkRequestHandle { guard: self }
    }

    /// Link the accumulated receive WRs into the singly linked list ibverbs
    /// expects and submit them with a single `ibv_post_recv` call.
    ///
    /// # Errors
    ///
    /// Maps `EINVAL` / `ENOMEM` / `EFAULT` from `ibv_post_recv` to dedicated
    /// variants and wraps any other code in `PostRecvError::Ibverbs`.
    pub fn post(mut self) -> Result<(), PostRecvError> {
        // An empty batch is a no-op. Return early instead of handing
        // `ibv_post_recv` the dangling (non-null) pointer that
        // `as_mut_ptr()` yields for an empty `Vec`.
        if self.wrs.is_empty() {
            return Ok(());
        }
        // Pointers are wired up only now, after all pushes are done: earlier
        // `Vec` growth could have moved elements and invalidated any pointers
        // taken at construction time.
        let mut sge_index = 0;
        for i in 0..self.wrs.len() {
            if i < self.wrs.len() - 1 {
                self.wrs[i].next = &mut self.wrs[i + 1] as *mut _;
            } else {
                self.wrs[i].next = null_mut();
            }
            if self.wrs[i].num_sge > 0 {
                // Each WR consumes a contiguous run of `num_sge` SGEs.
                self.wrs[i].sg_list = &mut self.sges[sge_index] as *mut _;
                sge_index += self.wrs[i].num_sge as usize;
            }
        }
        let mut bad_wr: *mut ibv_recv_wr = null_mut();
        // SAFETY: `self.qp` is valid for the guard's lifetime and the WR
        // chain built above points into live, owned storage.
        let ret = unsafe { ibv_post_recv(self.qp.as_ptr(), self.wrs.as_mut_ptr(), &mut bad_wr) };
        match ret {
            0 => Ok(()),
            libc::EINVAL => Err(PostRecvError::InvalidWorkRequest(io::Error::from_raw_os_error(
                libc::EINVAL,
            ))),
            libc::ENOMEM => Err(PostRecvError::NotEnoughResources(io::Error::from_raw_os_error(
                libc::ENOMEM,
            ))),
            libc::EFAULT => Err(PostRecvError::InvalidQueuePair(io::Error::from_raw_os_error(
                libc::EFAULT,
            ))),
            err => Err(PostRecvError::Ibverbs(io::Error::from_raw_os_error(err))),
        }
    }
}
/// Handle for attaching scatter/gather entries to the most recently
/// constructed receive work request of a [`PostRecvGuard`].
pub struct RecvWorkRequestHandle<'g, 'qp> {
    // The guard whose WR/SGE storage is being filled in.
    guard: &'g mut PostRecvGuard<'qp>,
}
impl SetScatterGatherEntry for RecvWorkRequestHandle<'_, '_> {
    /// # Safety
    ///
    /// `(addr, length)` must describe memory registered under `lkey` and
    /// valid until the receive completes.
    unsafe fn setup_sge(self, lkey: u32, addr: u64, length: u32) {
        // A handle is only obtainable from `construct_wr`, so a WR must
        // exist; the assert documents that and makes the `unwrap_unchecked`
        // below sound.
        assert!(!self.guard.wrs.is_empty());
        self.guard.wrs.last_mut().unwrap_unchecked().num_sge = 1;
        self.guard.sges.push(ibv_sge { addr, length, lkey });
    }
    /// # Safety
    ///
    /// Every entry in `sg_list` must reference registered memory valid until
    /// the receive completes.
    unsafe fn setup_sge_list(self, sg_list: &[ibv_sge]) {
        assert!(!self.guard.wrs.is_empty());
        self.guard.wrs.last_mut().unwrap_unchecked().num_sge = sg_list.len() as _;
        self.guard.sges.extend_from_slice(sg_list);
    }
}
/// A queue pair of either flavor, letting callers hold basic and extended
/// QPs behind one concrete type at runtime.
#[derive(Debug)]
pub enum GenericQueuePair {
    /// Wraps a [`BasicQueuePair`] (classic `ibv_post_send` interface).
    Basic(BasicQueuePair),
    /// Wraps an [`ExtendedQueuePair`] (`ibv_wr_*` interface).
    Extended(ExtendedQueuePair),
}
/// [`QueuePair`] for the generic wrapper: every method dispatches to the
/// active variant.
impl QueuePair for GenericQueuePair {
    /// # Safety
    ///
    /// Same contract as the variant implementations: the returned pointer is
    /// only meaningful while `self` is alive — see the `QueuePair` trait.
    unsafe fn qp(&self) -> NonNull<ibv_qp> {
        match self {
            GenericQueuePair::Basic(qp) => qp.qp(),
            GenericQueuePair::Extended(qp) => qp.qp(),
        }
    }
    fn qp_number(&self) -> u32 {
        match self {
            GenericQueuePair::Basic(qp) => qp.qp_number(),
            GenericQueuePair::Extended(qp) => qp.qp_number(),
        }
    }
    fn modify(&mut self, attr: &QueuePairAttribute) -> Result<(), ModifyQueuePairError> {
        match self {
            GenericQueuePair::Basic(qp) => qp.modify(attr),
            GenericQueuePair::Extended(qp) => qp.modify(attr),
        }
    }
    // Both variants share the same receive path, so `PostRecvGuard` is
    // returned directly.
    fn start_post_recv(&mut self) -> PostRecvGuard<'_> {
        match self {
            GenericQueuePair::Basic(qp) => qp.start_post_recv(),
            GenericQueuePair::Extended(qp) => qp.start_post_recv(),
        }
    }
    // The send guard type differs per variant, so the generic QP exposes the
    // matching guard enum.
    type Guard<'g>
        = GenericPostSendGuard<'g>
    where
        Self: 'g;
    fn start_post_send(&mut self) -> Self::Guard<'_> {
        match self {
            GenericQueuePair::Basic(qp) => GenericPostSendGuard::Basic(qp.start_post_send()),
            GenericQueuePair::Extended(qp) => GenericPostSendGuard::Extended(qp.start_post_send()),
        }
    }
}
/// Send guard matching [`GenericQueuePair`]: wraps whichever guard flavor the
/// underlying queue pair produced.
pub enum GenericPostSendGuard<'g> {
    /// Guard for the classic `ibv_post_send` path.
    Basic(BasicPostSendGuard<'g>),
    /// Guard for the extended `ibv_wr_*` path.
    Extended(ExtendedPostSendGuard<'g>),
}
impl PostSendGuard for GenericPostSendGuard<'_> {
    /// Forward WR construction to the active variant, then rebuild the
    /// handle around `self` so its type is `WorkRequestHandle<'_, Self>`.
    /// The inner guard's handle is dropped immediately — only the side effect
    /// of `construct_wr` on the inner guard matters.
    fn construct_wr(&mut self, wr_id: u64, wr_flags: WorkRequestFlags) -> WorkRequestHandle<'_, Self> {
        match self {
            GenericPostSendGuard::Basic(guard) => {
                guard.construct_wr(wr_id, wr_flags);
                WorkRequestHandle { guard: self }
            },
            GenericPostSendGuard::Extended(guard) => {
                guard.construct_wr(wr_id, wr_flags);
                WorkRequestHandle { guard: self }
            },
        }
    }
    /// Submit the batch via whichever mechanism the active variant uses.
    fn post(self) -> Result<(), PostSendError> {
        match self {
            GenericPostSendGuard::Basic(guard) => guard.post(),
            GenericPostSendGuard::Extended(guard) => guard.post(),
        }
    }
}
/// Opcode/payload setters for the generic guard: every method simply
/// dispatches to the active variant.
impl private_traits::PostSendGuard for GenericPostSendGuard<'_> {
    fn setup_send(&mut self) {
        match self {
            GenericPostSendGuard::Basic(guard) => guard.setup_send(),
            GenericPostSendGuard::Extended(guard) => guard.setup_send(),
        }
    }
    fn setup_send_imm(&mut self, imm_data: u32) {
        match self {
            GenericPostSendGuard::Basic(guard) => guard.setup_send_imm(imm_data),
            GenericPostSendGuard::Extended(guard) => guard.setup_send_imm(imm_data),
        }
    }
    fn setup_write(&mut self, rkey: u32, remote_addr: u64) {
        match self {
            GenericPostSendGuard::Basic(guard) => guard.setup_write(rkey, remote_addr),
            GenericPostSendGuard::Extended(guard) => guard.setup_write(rkey, remote_addr),
        }
    }
    fn setup_write_imm(&mut self, rkey: u32, remote_addr: u64, imm_data: u32) {
        match self {
            GenericPostSendGuard::Basic(guard) => guard.setup_write_imm(rkey, remote_addr, imm_data),
            GenericPostSendGuard::Extended(guard) => guard.setup_write_imm(rkey, remote_addr, imm_data),
        }
    }
    fn setup_read(&mut self, rkey: u32, remote_addr: u64) {
        match self {
            GenericPostSendGuard::Basic(guard) => guard.setup_read(rkey, remote_addr),
            GenericPostSendGuard::Extended(guard) => guard.setup_read(rkey, remote_addr),
        }
    }
    fn setup_inline_data(&mut self, buf: &[u8]) {
        match self {
            GenericPostSendGuard::Basic(guard) => guard.setup_inline_data(buf),
            GenericPostSendGuard::Extended(guard) => guard.setup_inline_data(buf),
        }
    }
    fn setup_inline_data_list(&mut self, bufs: &[IoSlice<'_>]) {
        match self {
            GenericPostSendGuard::Basic(guard) => guard.setup_inline_data_list(bufs),
            GenericPostSendGuard::Extended(guard) => guard.setup_inline_data_list(bufs),
        }
    }
    /// # Safety
    /// Same contract as the variant implementations.
    unsafe fn setup_sge(&mut self, lkey: u32, addr: u64, length: u32) {
        match self {
            GenericPostSendGuard::Basic(guard) => guard.setup_sge(lkey, addr, length),
            GenericPostSendGuard::Extended(guard) => guard.setup_sge(lkey, addr, length),
        }
    }
    /// # Safety
    /// Same contract as the variant implementations.
    unsafe fn setup_sge_list(&mut self, sg_list: &[ibv_sge]) {
        match self {
            GenericPostSendGuard::Basic(guard) => guard.setup_sge_list(sg_list),
            GenericPostSendGuard::Extended(guard) => guard.setup_sge_list(sg_list),
        }
    }
}
/// Allow a basic QP to be stored as a [`GenericQueuePair`].
impl From<BasicQueuePair> for GenericQueuePair {
    fn from(qp: BasicQueuePair) -> Self {
        GenericQueuePair::Basic(qp)
    }
}
/// Allow an extended QP to be stored as a [`GenericQueuePair`].
impl From<ExtendedQueuePair> for GenericQueuePair {
    fn from(qp: ExtendedQueuePair) -> Self {
        GenericQueuePair::Extended(qp)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ibverbs::address::GidType;
    use crate::ibverbs::completion::GenericCompletionQueue;
    use crate::ibverbs::device;

    /// Drives a QP through Reset -> Init -> ReadyToReceive and checks that
    /// `query` reports the attributes that were set. Returns `Ok` (silently
    /// skips) when the host has no RDMA device.
    #[test]
    fn test_query_qp() -> Result<(), Box<dyn std::error::Error>> {
        let device_list = device::DeviceList::new()?;
        match device_list.get(0) {
            Some(device) => {
                let ctx = device.open()?;
                let pd = ctx.alloc_pd()?;
                // Four i32s = 16 bytes backing the registered region.
                let memory = [1, 2, 3, 4];
                let mr_handle = memory.as_ptr() as usize;
                let mr = unsafe {
                    pd.reg_mr(mr_handle, 16, AccessFlags::LocalWrite | AccessFlags::RemoteWrite)
                        .unwrap()
                };
                let cq = GenericCompletionQueue::from(ctx.create_cq_builder().setup_cqe(2).build_ex()?);
                let mut qp = pd
                    .create_qp_builder()
                    .setup_send_cq(cq.clone())
                    .setup_recv_cq(cq.clone())
                    .build()?;
                // Posting a receive while the QP is still in Reset is expected
                // to be rejected (mapped from EINVAL).
                let mut guard = qp.start_post_recv();
                unsafe {
                    let handle = guard.construct_wr(1);
                    handle.setup_sge(mr.lkey(), mr.get_ptr() as u64, 1);
                    match guard.post() {
                        Err(PostRecvError::InvalidWorkRequest(_)) => {},
                        other => panic!("Expected InvalidWorkRequest error, got: {other:?}"),
                    }
                }
                // Reset -> Init.
                let mut attr = QueuePairAttribute::new();
                attr.setup_state(QueuePairState::Init)
                    .setup_pkey_index(0)
                    .setup_port(1)
                    .setup_access_flags(AccessFlags::RemoteWrite);
                qp.modify(&attr)?;
                // Init -> ReadyToReceive, which additionally needs a path MTU,
                // remote QP number / PSN and an address vector.
                let mut attr = QueuePairAttribute::new();
                attr.setup_state(QueuePairState::ReadyToReceive)
                    .setup_path_mtu(Mtu::Mtu1024)
                    .setup_dest_qp_num(1024)
                    .setup_rq_psn(1024)
                    .setup_max_dest_read_atomic(0)
                    .setup_min_rnr_timer(0);
                let mut ah_attr = AddressHandleAttribute::new();
                let gid_entries = ctx.query_gid_table().unwrap();
                // Pick a usable GID: skip unicast link-local entries unless
                // the entry is RoCEv1.
                let gid = gid_entries
                    .iter()
                    .find(|&&gid| !gid.gid().is_unicast_link_local() || gid.gid_type() == GidType::RoceV1)
                    .unwrap();
                ah_attr
                    .setup_dest_lid(1)
                    .setup_port(1)
                    .setup_service_level(1)
                    .setup_grh_src_gid_index(gid.gid_index().try_into().unwrap())
                    .setup_grh_dest_gid(&gid.gid())
                    .setup_grh_hop_limit(64);
                attr.setup_address_vector(&ah_attr);
                qp.modify(&attr)?;
                // Query back a subset of the attributes and compare.
                let mask = QueuePairAttributeMask::AccessFlags
                    | QueuePairAttributeMask::PathMtu
                    | QueuePairAttributeMask::DestinationQueuePairNumber
                    | QueuePairAttributeMask::Port;
                let (attr, init_attr) = qp.query(mask)?;
                assert_eq!(attr.access_flags(), AccessFlags::RemoteWrite);
                assert_eq!(attr.dest_qp_num(), 1024);
                assert_eq!(attr.path_mtu(), Mtu::Mtu1024);
                assert_eq!(attr.port(), 1);
                // Only lower bounds are asserted for the capability values —
                // presumably providers may round requested sizes up.
                assert!(init_attr.max_send_wr() >= 16);
                assert!(init_attr.max_recv_wr() >= 16);
                assert!(init_attr.max_send_sge() >= 1);
                assert!(init_attr.max_recv_sge() >= 1);
                Ok(())
            },
            None => Ok(()),
        }
    }

    /// Exercises the error mapping of `PostRecvGuard::post`: an invalid WR
    /// while the QP is in Reset, then resource exhaustion when more WRs are
    /// posted than the receive queue was sized for. Skips when no RDMA
    /// device is present.
    #[test]
    fn test_post_recv_errors() -> Result<(), Box<dyn std::error::Error>> {
        let device_list = device::DeviceList::new()?;
        match device_list.get(0) {
            Some(device) => {
                let ctx = device.open()?;
                let pd = ctx.alloc_pd()?;
                let memory = [1, 2, 3, 4];
                let mr_handle = memory.as_ptr() as usize;
                let mr = unsafe {
                    pd.reg_mr(mr_handle, 16, AccessFlags::LocalWrite | AccessFlags::RemoteWrite)
                        .unwrap()
                };
                let cq = GenericCompletionQueue::from(ctx.create_cq_builder().setup_cqe(2).build_ex()?);
                // Receive queue deliberately sized for a single WR.
                let mut qp = pd
                    .create_qp_builder()
                    .setup_send_cq(cq.clone())
                    .setup_recv_cq(cq.clone())
                    .setup_max_recv_wr(1)
                    .build()?;
                // Posting in Reset state -> InvalidWorkRequest.
                let mut guard = qp.start_post_recv();
                unsafe {
                    let handle = guard.construct_wr(1);
                    handle.setup_sge(mr.lkey(), mr.get_ptr() as u64, 1);
                    match guard.post() {
                        Err(PostRecvError::InvalidWorkRequest(_)) => {},
                        other => panic!("Expected InvalidWorkRequest error, got: {other:?}"),
                    }
                }
                // Reset -> Init.
                let mut attr = QueuePairAttribute::new();
                attr.setup_state(QueuePairState::Init)
                    .setup_pkey_index(0)
                    .setup_port(1)
                    .setup_access_flags(AccessFlags::RemoteWrite);
                qp.modify(&attr)?;
                // Init -> ReadyToReceive (same setup as test_query_qp).
                let mut attr = QueuePairAttribute::new();
                attr.setup_state(QueuePairState::ReadyToReceive)
                    .setup_path_mtu(Mtu::Mtu1024)
                    .setup_dest_qp_num(1024)
                    .setup_rq_psn(1024)
                    .setup_max_dest_read_atomic(0)
                    .setup_min_rnr_timer(0);
                let mut ah_attr = AddressHandleAttribute::new();
                let gid_entries = ctx.query_gid_table().unwrap();
                let gid = gid_entries
                    .iter()
                    .find(|&&gid| !gid.gid().is_unicast_link_local() || gid.gid_type() == GidType::RoceV1)
                    .unwrap();
                ah_attr
                    .setup_dest_lid(1)
                    .setup_port(1)
                    .setup_service_level(1)
                    .setup_grh_src_gid_index(gid.gid_index().try_into().unwrap())
                    .setup_grh_dest_gid(&gid.gid())
                    .setup_grh_hop_limit(64);
                attr.setup_address_vector(&ah_attr);
                qp.modify(&attr)?;
                // Posting 128 WRs into a 1-deep receive queue must exhaust it
                // -> NotEnoughResources.
                let mut guard = qp.start_post_recv();
                for i in 0..128 {
                    unsafe {
                        let handle = guard.construct_wr(i);
                        handle.setup_sge(mr.lkey(), mr.get_ptr() as u64, 1);
                    }
                }
                match guard.post() {
                    Err(PostRecvError::NotEnoughResources(_)) => {},
                    other => panic!("Expected NotEnoughResources error, got: {other:?}"),
                }
                Ok(())
            },
            None => Ok(()),
        }
    }
}