#![deny(
// The following are allowed by default lints according to
// https://doc.rust-lang.org/rustc/lints/listing/allowed-by-default.html
anonymous_parameters,
bare_trait_objects,
// box_pointers, // use box pointer to allocate on heap
// elided_lifetimes_in_paths, // allow anonymous lifetime
missing_copy_implementations,
missing_debug_implementations,
missing_docs, // TODO: add documents
single_use_lifetimes, // TODO: fix lifetime names only used once
trivial_casts, // TODO: remove trivial casts in code
trivial_numeric_casts,
// unreachable_pub, allow clippy::redundant_pub_crate lint instead
// unsafe_code,
unstable_features,
unused_extern_crates,
unused_import_braces,
unused_qualifications,
unused_results,
variant_size_differences,
warnings, // treat all warnings as errors
clippy::all,
clippy::restriction,
clippy::pedantic,
// clippy::nursery, // It's still under development
clippy::cargo,
unreachable_pub,
)]
#![allow(
// Some explicitly allowed Clippy lints, must have clear reason to allow
clippy::blanket_clippy_restriction_lints, // allow clippy::restriction
clippy::implicit_return, // actually omitting the return keyword is idiomatic Rust code
clippy::module_name_repetitions, // repetition of module name in a struct name is not a big deal
clippy::multiple_crate_versions, // multi-version dependency crates is not able to fix
clippy::missing_errors_doc, // TODO: add error docs
clippy::missing_panics_doc, // TODO: add panic docs
clippy::panic_in_result_fn,
clippy::shadow_same, // Not too much bad
clippy::shadow_reuse, // Not too much bad
clippy::exhaustive_enums,
clippy::exhaustive_structs,
clippy::indexing_slicing,
clippy::separated_literal_suffix, // conflicts with clippy::unseparated_literal_suffix
clippy::single_char_lifetime_names, // TODO: change lifetime names
)]
mod agent;
mod completion_queue;
mod context;
pub mod device;
mod access;
mod error_utilities;
mod event_channel;
mod event_listener;
mod gid;
mod hashmap_extension;
mod id;
mod lock_utilities;
mod memory_region;
mod memory_window;
mod mr_allocator;
mod protection_domain;
mod queue_pair;
mod rmr_manager;
mod work_request;
use access::flags_into_ibv_access;
pub use access::AccessFlag;
use agent::{Agent, MAX_MSG_LEN};
use clippy_utilities::Cast;
use completion_queue::{DEFAULT_CQ_SIZE, DEFAULT_MAX_CQE};
use context::Context;
use enumflags2::BitFlags;
use error_utilities::log_ret_last_os_err;
use event_listener::EventListener;
pub use memory_region::{
local::{LocalMr, LocalMrReadAccess, LocalMrWriteAccess},
remote::{RemoteMr, RemoteMrReadAccess, RemoteMrWriteAccess},
MrAccess,
};
pub use mr_allocator::MRManageStrategy;
use mr_allocator::MrAllocator;
use protection_domain::ProtectionDomain;
use queue_pair::{
QueuePair, QueuePairEndpoint, MAX_RECV_SGE, MAX_RECV_WR, MAX_SEND_SGE, MAX_SEND_WR,
};
use rdma_sys::ibv_access_flags;
#[cfg(feature = "cm")]
use rdma_sys::{
rdma_addrinfo, rdma_cm_id, rdma_connect, rdma_create_ep, rdma_disconnect, rdma_freeaddrinfo,
rdma_getaddrinfo, rdma_port_space,
};
use rmr_manager::DEFAULT_RMR_TIMEOUT;
#[cfg(feature = "cm")]
use std::ptr::null_mut;
use std::{alloc::Layout, fmt::Debug, io, ptr::NonNull, sync::Arc, time::Duration};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream, ToSocketAddrs},
sync::Mutex,
};
use tracing::debug;
#[macro_use]
extern crate lazy_static;
/// Device related attributes collected by `RdmaBuilder`.
///
/// All three values are forwarded to `Context::open` when the `Rdma` is built.
#[derive(Debug)]
pub struct DeviceInitAttr {
    /// Optional name of the device to open.
    dev_name: Option<String>,
    /// Port number of the device.
    port_num: u8,
    /// GID index of the device.
    gid_index: usize,
}

impl Default for DeviceInitAttr {
    /// Default attributes: no explicit device name, port number `1`, GID index `1`.
    #[inline]
    fn default() -> Self {
        let dev_name = None;
        let port_num = 1;
        let gid_index = 1;
        Self {
            dev_name,
            port_num,
            gid_index,
        }
    }
}
/// Completion queue attributes collected by `RdmaBuilder`.
///
/// Both values are forwarded to `Context::create_completion_queue`.
#[derive(Debug, Clone, Copy)]
pub struct CQInitAttr {
    /// Size of the completion queue.
    cq_size: u32,
    /// Value passed as the `max_cqe` argument when the completion queue is created.
    max_cqe: i32,
}

impl Default for CQInitAttr {
    /// Default attributes built from `DEFAULT_CQ_SIZE` and `DEFAULT_MAX_CQE`.
    #[inline]
    fn default() -> Self {
        let cq_size = DEFAULT_CQ_SIZE;
        let max_cqe = DEFAULT_MAX_CQE;
        Self { cq_size, max_cqe }
    }
}
/// Queue pair attributes collected by `RdmaBuilder`.
#[derive(Debug, Clone, Copy)]
pub struct QPInitAttr {
/// Access flags passed to `modify_to_init` of the queue pair.
access: ibv_access_flags,
/// How the connection is going to be established, see `ConnectionType`.
conn_type: ConnectionType,
/// When `true`, `Rdma::init_agent` does not create an agent.
raw: bool,
/// Forwarded to `set_max_send_wr` of the queue pair builder.
max_send_wr: u32,
/// Forwarded to `set_max_recv_wr` of the queue pair builder.
max_recv_wr: u32,
/// Forwarded to `set_max_send_sge` of the queue pair builder.
max_send_sge: u32,
/// Forwarded to `set_max_recv_sge` of the queue pair builder.
max_recv_sge: u32,
}
// Default access flags shared by the queue pair, memory region and agent attributes:
// local write, remote write, remote read and remote atomic.
// Built lazily because the flags are combined with `|` at runtime.
lazy_static! {
static ref DEFAULT_ACCESS:ibv_access_flags = ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
| ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
| ibv_access_flags::IBV_ACCESS_REMOTE_READ
| ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
}
impl Default for QPInitAttr {
    /// Default queue pair attributes.
    ///
    /// Uses `DEFAULT_ACCESS`, the `RCSocket` connection type, a non-raw queue pair
    /// and the `MAX_*` limits exported by the `queue_pair` module.
    #[inline]
    fn default() -> Self {
        let access = *DEFAULT_ACCESS;
        Self {
            access,
            conn_type: ConnectionType::RCSocket,
            raw: false,
            max_send_wr: MAX_SEND_WR,
            max_recv_wr: MAX_RECV_WR,
            max_send_sge: MAX_SEND_SGE,
            max_recv_sge: MAX_RECV_SGE,
        }
    }
}
/// Memory region attributes collected by `RdmaBuilder`.
///
/// The whole struct is handed to `MrAllocator::new` when the `Rdma` is built.
#[derive(Debug, Clone, Copy)]
pub struct MRInitAttr {
    /// Access flags for memory regions.
    access: ibv_access_flags,
    /// Strategy used to manage memory regions.
    strategy: MRManageStrategy,
}

impl Default for MRInitAttr {
    /// Default attributes: `DEFAULT_ACCESS` and the `Jemalloc` strategy.
    #[inline]
    fn default() -> Self {
        let strategy = MRManageStrategy::Jemalloc;
        let access = *DEFAULT_ACCESS;
        Self { access, strategy }
    }
}
/// Agent attributes collected by `RdmaBuilder`.
///
/// Both values are handed to `Rdma::init_agent` once the connection is established.
#[derive(Debug, Clone, Copy)]
pub struct AgentInitAttr {
    /// Maximum message length passed to the agent.
    max_message_length: usize,
    /// Maximum access flags for remote memory regions passed to the agent.
    max_rmr_access: ibv_access_flags,
}

impl Default for AgentInitAttr {
    /// Default attributes: `MAX_MSG_LEN` and `DEFAULT_ACCESS`.
    #[inline]
    fn default() -> Self {
        let max_message_length = MAX_MSG_LEN;
        let max_rmr_access = *DEFAULT_ACCESS;
        Self {
            max_message_length,
            max_rmr_access,
        }
    }
}
/// Builder that collects all the attributes needed to create an `Rdma`.
///
/// Every attribute group starts from its `Default` implementation and can be
/// adjusted through the `set_*` methods.
#[derive(Default)]
pub struct RdmaBuilder {
/// Device attributes.
dev_attr: DeviceInitAttr,
/// Completion queue attributes.
cq_attr: CQInitAttr,
/// Queue pair attributes.
qp_attr: QPInitAttr,
/// Memory region attributes.
mr_attr: MRInitAttr,
/// Agent attributes.
agent_attr: AgentInitAttr,
}
impl RdmaBuilder {
    /// Creates a builder holding the default value of every attribute.
    #[must_use]
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an `Rdma` from the device, completion queue, queue pair and memory
    /// region attributes.
    ///
    /// No connection is established and no agent is created here; the agent
    /// attributes are only used by `connect`, `cm_connect` and `listen`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `Rdma::new`.
    #[inline]
    pub fn build(&self) -> io::Result<Rdma> {
        Rdma::new(&self.dev_attr, self.cq_attr, self.qp_attr, self.mr_attr)
    }

    /// Builds an `Rdma`, exchanges queue pair endpoints with the peer at `addr`
    /// over TCP, performs the queue pair handshake and initializes the agent.
    ///
    /// # Errors
    ///
    /// Fails when the connection type is not `RCSocket`, or when building,
    /// the TCP exchange, the handshake or the agent initialization fails.
    #[inline]
    pub async fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<Rdma> {
        match self.qp_attr.conn_type {
            ConnectionType::RCSocket => {
                let mut rdma = self.build()?;
                let remote = tcp_connect_helper(addr, &rdma.endpoint()).await?;
                rdma.qp_handshake(remote)?;
                rdma.init_agent(
                    self.agent_attr.max_message_length,
                    self.agent_attr.max_rmr_access,
                )
                .await?;
                Ok(rdma)
            }
            ConnectionType::RCCM => Err(io::Error::new(
                io::ErrorKind::Other,
                "ConnectionType should be XXSocket",
            )),
        }
    }

    /// Builds an `Rdma` and connects it to `node`/`service` through the RDMA
    /// connection manager, then initializes the agent.
    ///
    /// # Errors
    ///
    /// Fails when the connection type is not `RCCM`, or when building,
    /// connecting or the agent initialization fails.
    #[inline]
    #[cfg(feature = "cm")]
    pub async fn cm_connect(self, node: &str, service: &str) -> io::Result<Rdma> {
        match self.qp_attr.conn_type {
            // This entry point requires a CM connection type, so the message must
            // name the CM family rather than the socket one.
            ConnectionType::RCSocket => Err(io::Error::new(
                io::ErrorKind::Other,
                "ConnectionType should be XXCM",
            )),
            ConnectionType::RCCM => {
                let max_message_length = self.agent_attr.max_message_length;
                let max_rmr_access = self.agent_attr.max_rmr_access;
                let mut rdma = self.build()?;
                cm_connect_helper(&mut rdma, node, service)?;
                rdma.init_agent(max_message_length, max_rmr_access).await?;
                Ok(rdma)
            }
        }
    }

    /// Builds an `Rdma`, binds a TCP listener to `addr`, waits for one peer,
    /// exchanges queue pair endpoints with it, performs the queue pair handshake
    /// and initializes the agent.
    ///
    /// The TCP listener is stored in the returned `Rdma` so that `Rdma::listen`
    /// can accept further peers on it.
    ///
    /// # Errors
    ///
    /// Fails when the connection type is not `RCSocket`, or when building,
    /// binding, the TCP exchange, the handshake or the agent initialization fails.
    #[inline]
    pub async fn listen<A: ToSocketAddrs>(self, addr: A) -> io::Result<Rdma> {
        match self.qp_attr.conn_type {
            ConnectionType::RCSocket => {
                let mut rdma = self.build()?;
                let tcp_listener = TcpListener::bind(addr).await?;
                let remote = tcp_listen(&tcp_listener, &rdma.endpoint()).await?;
                rdma.qp_handshake(remote)?;
                debug!("handshake done");
                rdma.init_agent(
                    self.agent_attr.max_message_length,
                    self.agent_attr.max_rmr_access,
                )
                .await?;
                rdma.clone_attr = CloneAttr::default().set_tcp_listener(tcp_listener);
                Ok(rdma)
            }
            ConnectionType::RCCM => Err(io::Error::new(
                io::ErrorKind::Other,
                "ConnectionType should be XXSocket",
            )),
        }
    }

    /// Sets the name of the device to open.
    #[inline]
    #[must_use]
    pub fn set_dev(mut self, dev: &str) -> Self {
        self.dev_attr.dev_name = Some(dev.to_owned());
        self
    }

    /// Sets the size of the completion queue.
    #[inline]
    #[must_use]
    pub fn set_cq_size(mut self, cq_size: u32) -> Self {
        self.cq_attr.cq_size = cq_size;
        self
    }

    /// Sets the GID index of the device.
    #[inline]
    #[must_use]
    pub fn set_gid_index(mut self, gid_index: usize) -> Self {
        self.dev_attr.gid_index = gid_index;
        self
    }

    /// Sets the port number of the device.
    #[inline]
    #[must_use]
    pub fn set_port_num(mut self, port_num: u8) -> Self {
        self.dev_attr.port_num = port_num;
        self
    }

    /// Sets the connection type, which decides which of `connect`/`listen`
    /// or `cm_connect` may be used.
    #[inline]
    #[must_use]
    pub fn set_conn_type(mut self, conn_type: ConnectionType) -> Self {
        self.qp_attr.conn_type = conn_type;
        self
    }

    /// Sets the raw flag; when `true` no agent is created for the `Rdma`.
    #[inline]
    #[must_use]
    pub fn set_raw(mut self, raw: bool) -> Self {
        self.qp_attr.raw = raw;
        self
    }

    /// Sets the `max_send_wr` value of the queue pair.
    #[inline]
    #[must_use]
    pub fn set_qp_max_send_wr(mut self, max_send_wr: u32) -> Self {
        self.qp_attr.max_send_wr = max_send_wr;
        self
    }

    /// Sets the `max_recv_wr` value of the queue pair.
    #[inline]
    #[must_use]
    pub fn set_qp_max_recv_wr(mut self, max_recv_wr: u32) -> Self {
        self.qp_attr.max_recv_wr = max_recv_wr;
        self
    }

    /// Sets the `max_send_sge` value of the queue pair.
    #[inline]
    #[must_use]
    pub fn set_qp_max_send_sge(mut self, max_send_sge: u32) -> Self {
        self.qp_attr.max_send_sge = max_send_sge;
        self
    }

    /// Sets the `max_recv_sge` value of the queue pair.
    #[inline]
    #[must_use]
    pub fn set_qp_max_recv_sge(mut self, max_recv_sge: u32) -> Self {
        self.qp_attr.max_recv_sge = max_recv_sge;
        self
    }

    /// Sets the access flags of the queue pair.
    #[inline]
    #[must_use]
    pub fn set_qp_access(mut self, flags: BitFlags<AccessFlag>) -> Self {
        self.qp_attr.access = flags_into_ibv_access(flags);
        self
    }

    /// Sets the access flags of the memory region attributes.
    #[inline]
    #[must_use]
    pub fn set_mr_access(mut self, flags: BitFlags<AccessFlag>) -> Self {
        self.mr_attr.access = flags_into_ibv_access(flags);
        self
    }

    /// Sets the strategy used to manage memory regions.
    #[inline]
    #[must_use]
    pub fn set_mr_strategy(mut self, strategy: MRManageStrategy) -> Self {
        self.mr_attr.strategy = strategy;
        self
    }

    /// Sets the maximum message length handed to the agent.
    #[inline]
    #[must_use]
    pub fn set_max_message_length(mut self, max_msg_len: usize) -> Self {
        self.agent_attr.max_message_length = max_msg_len;
        self
    }

    /// Sets the maximum access flags for remote memory regions handed to the agent.
    #[inline]
    #[must_use]
    pub fn set_max_rmr_access(mut self, flags: BitFlags<AccessFlag>) -> Self {
        self.agent_attr.max_rmr_access = flags_into_ibv_access(flags);
        self
    }
}
impl Debug for RdmaBuilder {
    /// Formats the builder, showing only the device name and the completion queue size.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            dev_attr, cq_attr, ..
        } = self;
        f.debug_struct("RdmaBuilder")
            .field("dev_name", &dev_attr.dev_name)
            .field("cq_size", &cq_attr.cq_size)
            .finish()
    }
}
/// Connects to `addr` over TCP and swaps queue pair endpoints with the peer.
///
/// The local endpoint `ep` is sent first; the reply is then read into a buffer of
/// the same length and decoded as the remote endpoint.
///
/// # Errors
///
/// Returns an I/O error when connecting, writing or reading fails, and an
/// `InvalidInput` error when the endpoint cannot be encoded or decoded.
async fn tcp_connect_helper<A: ToSocketAddrs>(
    addr: A,
    ep: &QueuePairEndpoint,
) -> io::Result<QueuePairEndpoint> {
    let mut socket = TcpStream::connect(addr).await?;
    let mut buf = match bincode::serialize(ep) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("failed to serailize the endpoint, {:?}", e),
            ));
        }
    };
    socket.write_all(&buf).await?;
    // The same buffer is reused for the reply, so exactly as many bytes as were sent are read.
    let _ = socket.read_exact(buf.as_mut()).await?;
    match bincode::deserialize(&buf) {
        Ok(remote) => Ok(remote),
        Err(e) => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("failed to deserailize the endpoint, {:?}", e),
        )),
    }
}
/// Accepts one TCP connection on `tcp_listener` and swaps queue pair endpoints
/// with the peer.
///
/// The remote endpoint is read first, using the serialized size of the local
/// endpoint `ep` as the expected length, and then the local endpoint is sent back.
///
/// # Errors
///
/// Returns an I/O error when accepting, reading or writing fails, or when an
/// endpoint cannot be encoded or decoded.
async fn tcp_listen(
    tcp_listener: &TcpListener,
    ep: &QueuePairEndpoint,
) -> io::Result<QueuePairEndpoint> {
    let (mut stream, _) = tcp_listener.accept().await?;
    let endpoint_size = bincode::serialized_size(ep).map_err(|e| {
        io::Error::new(
            io::ErrorKind::Other,
            format!("Endpoint serialization failed, {:?}", e),
        )
    })?;
    let mut remote = vec![0_u8; endpoint_size.cast()];
    let _ = stream.read_exact(remote.as_mut()).await?;
    let remote: QueuePairEndpoint = bincode::deserialize(&remote).map_err(|e| {
        io::Error::new(
            io::ErrorKind::Other,
            format!("failed to deserialize remote endpoint, {:?}", e),
        )
    })?;
    // This step encodes the local endpoint, so the error must say so instead of
    // blaming the decoding of the remote one.
    let local = bincode::serialize(ep).map_err(|e| {
        io::Error::new(
            io::ErrorKind::Other,
            format!("failed to serialize local endpoint, {:?}", e),
        )
    })?;
    stream.write_all(&local).await?;
    Ok(remote)
}
/// Connects `rdma` to `node`/`service` through the RDMA connection manager.
///
/// The address is resolved with `rdma_getaddrinfo`, an endpoint is created with
/// `rdma_create_ep`, the resources of `rdma` (queue pair, protection domain,
/// context, completion queue and its event channel) are installed into the
/// `rdma_cm_id`, and finally `rdma_connect` is called.
///
/// # Errors
///
/// Returns `InvalidInput` when `node` or `service` contains an interior NUL
/// byte, and the last OS error when one of the connection manager calls fails.
#[inline]
#[cfg(feature = "cm")]
fn cm_connect_helper(rdma: &mut Rdma, node: &str, service: &str) -> io::Result<()> {
    // `rdma_getaddrinfo` reads NUL-terminated C strings, which a `&str` does not
    // guarantee, so both arguments are converted first.
    let c_node = std::ffi::CString::new(node).map_err(|e| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("node contains a NUL byte, {:?}", e),
        )
    })?;
    let c_service = std::ffi::CString::new(service).map_err(|e| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("service contains a NUL byte, {:?}", e),
        )
    })?;
    // SAFETY: `rdma_addrinfo` is an FFI struct that is used here as an all-zero
    // hints value, with only the port space filled in afterwards.
    let mut hints = unsafe { std::mem::zeroed::<rdma_addrinfo>() };
    let mut info: *mut rdma_addrinfo = null_mut();
    hints.ai_port_space = rdma_port_space::RDMA_PS_TCP.cast();
    // SAFETY: both strings are valid NUL-terminated C strings that outlive the
    // call, `hints` and `info` are valid for the duration of the call.
    let mut ret =
        unsafe { rdma_getaddrinfo(c_node.as_ptr(), c_service.as_ptr(), &hints, &mut info) };
    if ret != 0_i32 {
        return Err(log_ret_last_os_err());
    }
    let mut id: *mut rdma_cm_id = null_mut();
    // SAFETY: `info` was filled in by the successful `rdma_getaddrinfo` above.
    ret = unsafe { rdma_create_ep(&mut id, info, rdma.pd.as_ptr(), null_mut()) };
    // Read the OS error before any further FFI call can overwrite it.
    let create_err = if ret == 0_i32 {
        None
    } else {
        Some(log_ret_last_os_err())
    };
    // The address info is not used after this point, so release it on both the
    // success and the failure path instead of leaking it on success.
    // SAFETY: `info` was allocated by `rdma_getaddrinfo` and is freed exactly once.
    unsafe {
        rdma_freeaddrinfo(info);
    }
    if let Some(err) = create_err {
        return Err(err);
    }
    // SAFETY: `id` was initialized by the successful `rdma_create_ep` above.
    unsafe {
        debug!(
            "cm_id: {:?},{:?},{:?},{:?},{:?},{:?},{:?}",
            (*id).qp,
            (*id).pd,
            (*id).verbs,
            (*id).recv_cq_channel,
            (*id).send_cq_channel,
            (*id).recv_cq,
            (*id).send_cq
        );
        (*id).qp = rdma.qp.as_ptr();
        (*id).pd = rdma.pd.as_ptr();
        (*id).verbs = rdma.ctx.as_ptr();
        // Both directions share the completion queue of `rdma` and its event channel.
        (*id).recv_cq_channel = rdma.qp.event_listener.cq.event_channel().as_ptr();
        (*id).send_cq_channel = rdma.qp.event_listener.cq.event_channel().as_ptr();
        (*id).recv_cq = rdma.qp.event_listener.cq.as_ptr();
        (*id).send_cq = rdma.qp.event_listener.cq.as_ptr();
        debug!(
            "cm_id: {:?},{:?},{:?},{:?},{:?},{:?},{:?}",
            (*id).qp,
            (*id).pd,
            (*id).verbs,
            (*id).recv_cq_channel,
            (*id).send_cq_channel,
            (*id).recv_cq,
            (*id).send_cq
        );
    }
    // SAFETY: `id` is a valid `rdma_cm_id` created above.
    ret = unsafe { rdma_connect(id, null_mut()) };
    if ret != 0_i32 {
        // Read the OS error before `rdma_disconnect` can overwrite it.
        let err = log_ret_last_os_err();
        // SAFETY: `id` is a valid `rdma_cm_id` created above.
        unsafe {
            let _ = rdma_disconnect(id);
        }
        return Err(err);
    }
    Ok(())
}
/// The way a connection is established.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionType {
/// Connection set up by exchanging queue pair endpoints over a TCP socket;
/// required by `connect`, `listen` and `new_connect`.
RCSocket,
/// Connection set up through the RDMA connection manager; required by `cm_connect`.
RCCM,
}
/// Attributes consulted by `Rdma::clone` when a new `Rdma` is derived from an existing one.
///
/// A `None` value means the corresponding value of the parent `Rdma` is used.
#[derive(Debug, Clone, Default)]
pub(crate) struct CloneAttr {
/// TCP listener used by `Rdma::listen` to accept further peers.
pub(crate) tcp_listener: Option<Arc<Mutex<TcpListener>>>,
/// Protection domain for the new `Rdma`.
pub(crate) pd: Option<Arc<ProtectionDomain>>,
/// Port number for the new queue pair.
pub(crate) port_num: Option<u8>,
/// Access flags for the new queue pair.
pub(crate) qp_access: Option<ibv_access_flags>,
/// Maximum access flags for remote memory regions handed to the new agent.
pub(crate) max_rmr_access: Option<ibv_access_flags>,
}
impl CloneAttr {
    /// Stores `tcp_listener`, wrapped so it can be shared and locked.
    fn set_tcp_listener(self, tcp_listener: TcpListener) -> Self {
        Self {
            tcp_listener: Some(Arc::new(Mutex::new(tcp_listener))),
            ..self
        }
    }

    /// Stores the protection domain to use for the new `Rdma`.
    #[allow(dead_code)]
    fn set_pd(self, pd: ProtectionDomain) -> Self {
        Self {
            pd: Some(Arc::new(pd)),
            ..self
        }
    }

    /// Stores the port number to use for the new queue pair.
    fn set_port_num(self, port_num: u8) -> Self {
        Self {
            port_num: Some(port_num),
            ..self
        }
    }

    /// Stores the access flags to use for the new queue pair.
    fn set_qp_access(self, access: ibv_access_flags) -> Self {
        Self {
            qp_access: Some(access),
            ..self
        }
    }

    /// Stores the maximum remote memory region access flags to use for the new agent.
    fn set_max_rmr_access(self, access: ibv_access_flags) -> Self {
        Self {
            max_rmr_access: Some(access),
            ..self
        }
    }
}
/// The main handle of this crate, bundling the resources of one connection.
#[derive(Debug)]
pub struct Rdma {
/// Device context.
ctx: Arc<Context>,
/// Protection domain.
pd: Arc<ProtectionDomain>,
/// Allocator used to allocate local memory regions.
allocator: Arc<MrAllocator>,
/// Queue pair.
qp: Arc<QueuePair>,
/// Agent; `None` until `init_agent` created it, and always `None` when `raw` is set.
agent: Option<Arc<Agent>>,
/// How the connection is established.
conn_type: ConnectionType,
/// When `true`, `init_agent` does not create an agent.
raw: bool,
/// Attributes used when a new `Rdma` is derived from this one.
clone_attr: CloneAttr,
}
impl Rdma {
fn new(
dev_attr: &DeviceInitAttr,
cq_attr: CQInitAttr,
qp_attr: QPInitAttr,
mr_attr: MRInitAttr,
) -> io::Result<Self> {
let ctx = Arc::new(Context::open(
dev_attr.dev_name.as_deref(),
dev_attr.port_num,
dev_attr.gid_index,
)?);
let ec = ctx.create_event_channel()?;
let cq = Arc::new(ctx.create_completion_queue(cq_attr.cq_size, ec, cq_attr.max_cqe)?);
let event_listener = EventListener::new(cq);
let pd = Arc::new(ctx.create_protection_domain()?);
let allocator = Arc::new(MrAllocator::new(
Arc::<ProtectionDomain>::clone(&pd),
mr_attr,
));
let mut qp = pd
.create_queue_pair_builder()
.set_event_listener(event_listener)
.set_port_num(dev_attr.port_num)
.set_gid_index(dev_attr.gid_index)
.set_max_send_wr(qp_attr.max_send_wr)
.set_max_send_sge(qp_attr.max_send_sge)
.set_max_recv_wr(qp_attr.max_recv_wr)
.set_max_recv_sge(qp_attr.max_recv_sge)
.build()?;
qp.modify_to_init(qp_attr.access, dev_attr.port_num)?;
Ok(Self {
ctx,
pd,
qp: Arc::new(qp),
agent: None,
allocator,
conn_type: qp_attr.conn_type,
raw: qp_attr.raw,
clone_attr: CloneAttr::default(),
})
}
fn clone(&self) -> io::Result<Self> {
let qp_access = self.clone_attr.qp_access.map_or_else(
|| {
self.qp.access.map_or_else(
|| {
Err(io::Error::new(
io::ErrorKind::Other,
"parent qp access is none",
))
},
Ok,
)
},
Ok,
)?;
let port_num = self.clone_attr.port_num.unwrap_or(self.qp.port_num);
let pd = self
.clone_attr
.pd
.as_ref()
.map_or_else(|| &self.pd, |pd| pd);
let mut qp_init_attr = self.qp.qp_init_attr.clone();
let inner_qp = NonNull::new(unsafe {
rdma_sys::ibv_create_qp(self.pd.as_ptr(), &mut qp_init_attr.qp_init_attr_inner)
})
.ok_or_else(log_ret_last_os_err)?;
let mut qp = QueuePair {
pd: Arc::clone(pd),
event_listener: Arc::clone(&self.qp.event_listener),
inner_qp,
port_num,
gid_index: self.qp.gid_index,
qp_init_attr,
access: Some(qp_access),
};
qp.modify_to_init(qp_access, self.qp.port_num)?;
Ok(Self {
ctx: Arc::clone(&self.ctx),
pd: Arc::clone(pd),
qp: Arc::new(qp),
agent: None,
allocator: Arc::clone(&self.allocator),
conn_type: self.conn_type,
raw: self.raw,
clone_attr: self.clone_attr.clone(),
})
}
fn endpoint(&self) -> QueuePairEndpoint {
self.qp.endpoint()
}
fn qp_handshake(&mut self, remote: QueuePairEndpoint) -> io::Result<()> {
self.qp.modify_to_rtr(remote, 0, 1, 0x12)?;
debug!("rtr");
self.qp.modify_to_rts(0x12, 6, 6, 0, 1)?;
debug!("rts");
Ok(())
}
async fn init_agent(
&mut self,
max_message_length: usize,
max_rmr_access: ibv_access_flags,
) -> io::Result<()> {
if !self.raw {
let agent = Arc::new(Agent::new(
Arc::<QueuePair>::clone(&self.qp),
Arc::<MrAllocator>::clone(&self.allocator),
max_message_length,
max_rmr_access,
)?);
self.agent = Some(agent);
tokio::time::sleep(Duration::from_secs(1)).await;
}
Ok(())
}
#[inline]
pub async fn listen(&self) -> io::Result<Self> {
match self.conn_type {
ConnectionType::RCSocket => {
let mut rdma = self.clone()?;
let remote = self
.clone_attr
.tcp_listener
.as_ref()
.map_or_else(
|| Err(io::Error::new(io::ErrorKind::Other, "tcp_listener is None")),
|tcp_listener| {
Ok(async {
let tcp_listener = tcp_listener.lock().await;
tcp_listen(&tcp_listener, &rdma.endpoint()).await
})
},
)?
.await?;
rdma.qp_handshake(remote)?;
debug!("handshake done");
#[allow(clippy::unreachable)]
let (max_message_length, max_rmr_access) = self.agent.as_ref().map_or_else(
|| {
unreachable!("agent of parent rdma is None");
},
|agent| (agent.max_msg_len(), agent.max_rmr_access()),
);
let max_rmr_access = self
.clone_attr
.max_rmr_access
.map_or(max_rmr_access, |new_access| new_access);
rdma.init_agent(max_message_length, max_rmr_access).await?;
Ok(rdma)
}
ConnectionType::RCCM => Err(io::Error::new(
io::ErrorKind::Other,
"ConnectionType should be XXSocket",
)),
}
}
#[inline]
pub async fn new_connect<A: ToSocketAddrs>(&self, addr: A) -> io::Result<Self> {
match self.conn_type {
ConnectionType::RCSocket => {
let mut rdma = self.clone()?;
let remote = tcp_connect_helper(addr, &rdma.endpoint()).await?;
rdma.qp_handshake(remote)?;
#[allow(clippy::unreachable)]
let (max_message_length, max_rmr_access) = self.agent.as_ref().map_or_else(
|| {
unreachable!("agent of parent rdma is None");
},
|agent| (agent.max_msg_len(), agent.max_rmr_access()),
);
let max_rmr_access = self
.clone_attr
.max_rmr_access
.map_or(max_rmr_access, |new_access| new_access);
rdma.init_agent(max_message_length, max_rmr_access).await?;
Ok(rdma)
}
ConnectionType::RCCM => Err(io::Error::new(
io::ErrorKind::Other,
"ConnectionType should be XXSocket",
)),
}
}
#[inline]
pub async fn send(&self, lm: &LocalMr) -> io::Result<()> {
self.agent
.as_ref()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Agent is not ready"))?
.send_data(lm, None)
.await
}
#[inline]
#[cfg(feature = "raw")]
pub async fn send_raw(&self, lm: &LocalMr) -> io::Result<()> {
self.qp.send_sge_raw(&[lm], None).await
}
#[inline]
#[cfg(feature = "raw")]
pub async fn send_raw_with_imm(&self, lm: &LocalMr, imm: u32) -> io::Result<()> {
self.qp.send_sge_raw(&[lm], Some(imm)).await
}
#[inline]
pub async fn send_with_imm(&self, lm: &LocalMr, imm: u32) -> io::Result<()> {
self.agent
.as_ref()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Agent is not ready"))?
.send_data(lm, Some(imm))
.await
}
#[inline]
pub async fn receive(&self) -> io::Result<LocalMr> {
let (lmr, _) = self.receive_with_imm().await?;
Ok(lmr)
}
#[inline]
#[cfg(feature = "raw")]
pub async fn receive_raw(&self, layout: Layout) -> io::Result<LocalMr> {
let mut lmr = self.alloc_local_mr(layout)?;
let _imm = self.qp.receive_sge_raw(&[&mut lmr]).await?;
Ok(lmr)
}
#[inline]
#[cfg(feature = "raw")]
pub async fn receive_raw_with_imm(&self, layout: Layout) -> io::Result<(LocalMr, Option<u32>)> {
let mut lmr = self.alloc_local_mr(layout)?;
let imm = self.qp.receive_sge_raw(&[&mut lmr]).await?;
Ok((lmr, imm))
}
#[inline]
pub async fn receive_with_imm(&self) -> io::Result<(LocalMr, Option<u32>)> {
self.agent
.as_ref()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Agent is not ready"))?
.receive_data()
.await
}
#[inline]
pub async fn receive_write_imm(&self) -> io::Result<u32> {
self.agent
.as_ref()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Agent is not ready"))?
.receive_imm()
.await
}
#[inline]
pub async fn read<LW, RR>(&self, lm: &mut LW, rm: &RR) -> io::Result<()>
where
LW: LocalMrWriteAccess,
RR: RemoteMrReadAccess,
{
self.qp.read(lm, rm).await
}
#[inline]
pub async fn write<LR, RW>(&self, lm: &LR, rm: &mut RW) -> io::Result<()>
where
LR: LocalMrReadAccess,
RW: RemoteMrWriteAccess,
{
self.qp.write(lm, rm, None).await
}
#[inline]
pub async fn write_with_imm<LR, RW>(&self, lm: &LR, rm: &mut RW, imm: u32) -> io::Result<()>
where
LR: LocalMrReadAccess,
RW: RemoteMrWriteAccess,
{
self.qp.write(lm, rm, Some(imm)).await
}
#[inline]
pub async fn connect<A: ToSocketAddrs>(
addr: A,
port_num: u8,
gid_index: usize,
max_message_length: usize,
) -> io::Result<Self> {
let mut rdma = RdmaBuilder::default()
.set_port_num(port_num)
.set_gid_index(gid_index)
.build()?;
assert!(
rdma.conn_type == ConnectionType::RCSocket,
"should set connection type to RCSocket"
);
let remote = tcp_connect_helper(addr, &rdma.endpoint()).await?;
rdma.qp_handshake(remote)?;
rdma.init_agent(max_message_length, *DEFAULT_ACCESS).await?;
tokio::time::sleep(Duration::from_secs(1)).await;
Ok(rdma)
}
#[inline]
#[cfg(feature = "cm")]
pub async fn cm_connect(
node: &str,
service: &str,
port_num: u8,
gid_index: usize,
max_message_length: usize,
) -> io::Result<Self> {
let mut rdma = RdmaBuilder::default()
.set_port_num(port_num)
.set_gid_index(gid_index)
.set_raw(true)
.set_conn_type(ConnectionType::RCCM)
.build()?;
assert!(
rdma.conn_type == ConnectionType::RCCM,
"should set connection type to RCSocket"
);
cm_connect_helper(&mut rdma, node, service)?;
rdma.init_agent(max_message_length, *DEFAULT_ACCESS).await?;
tokio::time::sleep(Duration::from_secs(1)).await;
Ok(rdma)
}
#[inline]
pub fn alloc_local_mr(&self, layout: Layout) -> io::Result<LocalMr> {
self.allocator
.alloc_zeroed_default_access(&layout, &self.pd)
}
#[inline]
pub unsafe fn alloc_local_mr_uninit(&self, layout: Layout) -> io::Result<LocalMr> {
self.allocator.alloc_default_access(&layout, &self.pd)
}
#[inline]
pub fn alloc_local_mr_with_access(
&self,
layout: Layout,
access: BitFlags<AccessFlag>,
) -> io::Result<LocalMr> {
self.allocator
.alloc_zeroed(&layout, flags_into_ibv_access(access), &self.pd)
}
#[inline]
pub unsafe fn alloc_local_mr_uninit_with_access(
&self,
layout: Layout,
access: BitFlags<AccessFlag>,
) -> io::Result<LocalMr> {
self.allocator
.alloc(&layout, flags_into_ibv_access(access), &self.pd)
}
#[inline]
pub async fn request_remote_mr(&self, layout: Layout) -> io::Result<RemoteMr> {
self.request_remote_mr_with_timeout(layout, DEFAULT_RMR_TIMEOUT)
.await
}
#[inline]
pub async fn request_remote_mr_with_timeout(
&self,
layout: Layout,
timeout: Duration,
) -> io::Result<RemoteMr> {
if let Some(ref agent) = self.agent {
agent.request_remote_mr_with_timeout(layout, timeout).await
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"Agent is not ready, please wait a while",
))
}
}
#[inline]
pub async fn send_local_mr(&self, mr: LocalMr) -> io::Result<()> {
self.send_local_mr_with_timeout(mr, DEFAULT_RMR_TIMEOUT)
.await
}
#[inline]
pub async fn send_local_mr_with_timeout(
&self,
mr: LocalMr,
timeout: Duration,
) -> io::Result<()> {
if let Some(ref agent) = self.agent {
agent.send_local_mr_with_timeout(mr, timeout).await
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"Agent is not ready, please wait a while",
))
}
}
#[inline]
pub async fn send_remote_mr(&self, mr: RemoteMr) -> io::Result<()> {
if let Some(ref agent) = self.agent {
agent.send_remote_mr(mr).await
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"Agent is not ready, please wait a while",
))
}
}
#[inline]
pub async fn receive_local_mr(&self) -> io::Result<LocalMr> {
if let Some(ref agent) = self.agent {
agent.receive_local_mr().await
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"Agent is not ready, please wait a while",
))
}
}
#[inline]
pub async fn receive_remote_mr(&self) -> io::Result<RemoteMr> {
if let Some(ref agent) = self.agent {
agent.receive_remote_mr().await
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"Agent is not ready, please wait a while",
))
}
}
#[inline]
#[must_use]
pub fn set_new_qp_access(mut self, qp_access: BitFlags<AccessFlag>) -> Self {
self.clone_attr = self
.clone_attr
.set_qp_access(flags_into_ibv_access(qp_access));
self
}
#[inline]
#[must_use]
pub fn set_new_max_rmr_access(mut self, max_rmr_access: BitFlags<AccessFlag>) -> Self {
self.clone_attr = self
.clone_attr
.set_max_rmr_access(flags_into_ibv_access(max_rmr_access));
self
}
#[inline]
#[must_use]
pub fn set_new_port_num(mut self, port_num: u8) -> Self {
self.clone_attr = self.clone_attr.set_port_num(port_num);
self
}
#[inline]
pub fn set_new_pd(mut self) -> io::Result<Self> {
let new_pd = self.ctx.create_protection_domain()?;
self.clone_attr = self.clone_attr.set_pd(new_pd);
Ok(self)
}
}
/// Listener that accepts peers over TCP and creates one `Rdma` per accepted peer.
#[derive(Debug)]
pub struct RdmaListener {
/// TCP listener used to exchange queue pair endpoints with peers.
tcp_listener: TcpListener,
}
impl RdmaListener {
    /// Binds a TCP listener to `addr`.
    ///
    /// # Errors
    ///
    /// Fails when the TCP listener cannot be bound.
    #[inline]
    pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
        let tcp_listener = TcpListener::bind(addr).await?;
        Ok(Self { tcp_listener })
    }

    /// Accepts one peer and returns an `Rdma` connected to it.
    ///
    /// The `Rdma` is built with default attributes except `port_num` and
    /// `gid_index`. The remote endpoint is read from the accepted stream, the
    /// local endpoint is sent back, then the queue pair handshake is performed
    /// and the agent is initialized with `max_message_length` and the default
    /// access flags.
    ///
    /// # Panics
    ///
    /// Panics when the built `Rdma` does not use the `RCSocket` connection type.
    ///
    /// # Errors
    ///
    /// Fails when accepting, building, the endpoint exchange, the handshake or
    /// the agent initialization fails.
    #[inline]
    pub async fn accept(
        &self,
        port_num: u8,
        gid_index: usize,
        max_message_length: usize,
    ) -> io::Result<Rdma> {
        let (mut stream, _) = self.tcp_listener.accept().await?;
        let mut rdma = RdmaBuilder::default()
            .set_port_num(port_num)
            .set_gid_index(gid_index)
            .build()?;
        assert!(
            rdma.conn_type == ConnectionType::RCSocket,
            "should set connection type to RCSocket"
        );
        let endpoint_size = bincode::serialized_size(&rdma.endpoint()).map_err(|e| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("Endpoint serialization failed, {:?}", e),
            )
        })?;
        let mut remote = vec![0_u8; endpoint_size.cast()];
        let _ = stream.read_exact(remote.as_mut()).await?;
        let remote: QueuePairEndpoint = bincode::deserialize(&remote).map_err(|e| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("failed to deserialize remote endpoint, {:?}", e),
            )
        })?;
        // This step encodes the local endpoint, so the error must say so instead
        // of blaming the decoding of the remote one.
        let local = bincode::serialize(&rdma.endpoint()).map_err(|e| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("failed to serialize local endpoint, {:?}", e),
            )
        })?;
        stream.write_all(&local).await?;
        rdma.qp_handshake(remote)?;
        debug!("handshake done");
        rdma.init_agent(max_message_length, *DEFAULT_ACCESS).await?;
        Ok(rdma)
    }
}