#![deny(
// The following are allowed by default lints according to
// https://doc.rust-lang.org/rustc/lints/listing/allowed-by-default.html
anonymous_parameters,
bare_trait_objects,
// box_pointers, // use box pointer to allocate on heap
// elided_lifetimes_in_paths, // allow anonymous lifetime
missing_copy_implementations,
missing_debug_implementations,
missing_docs, // TODO: add documents
single_use_lifetimes, // TODO: fix lifetime names only used once
trivial_casts, // TODO: remove trivial casts in code
trivial_numeric_casts,
// unreachable_pub, allow clippy::redundant_pub_crate lint instead
// unsafe_code,
unstable_features,
unused_extern_crates,
unused_import_braces,
unused_qualifications,
unused_results,
variant_size_differences,
warnings, // treat all warnings as errors
clippy::all,
clippy::restriction,
clippy::pedantic,
// clippy::nursery, // It's still under development
clippy::cargo,
unreachable_pub,
)]
#![allow(
// Some explicitly allowed Clippy lints, must have clear reason to allow
clippy::blanket_clippy_restriction_lints, // allow clippy::restriction
clippy::implicit_return, // actually omitting the return keyword is idiomatic Rust code
clippy::module_name_repetitions, // repetition of the module name in a struct name is no big deal
clippy::multiple_crate_versions, // multi-version dependency crates is not able to fix
clippy::missing_errors_doc, // TODO: add error docs
clippy::missing_panics_doc, // TODO: add panic docs
clippy::panic_in_result_fn,
clippy::shadow_same, // Not too much bad
clippy::shadow_reuse, // Not too much bad
clippy::exhaustive_enums,
clippy::exhaustive_structs,
clippy::indexing_slicing,
)]
mod agent;
mod completion_queue;
mod context;
mod event_channel;
mod event_listener;
mod gid;
mod id;
mod memory_region;
mod memory_window;
mod mr_allocator;
mod protection_domain;
mod queue_pair;
mod work_request;
use agent::Agent;
use clippy_utilities::Cast;
use context::Context;
use enumflags2::{bitflags, BitFlags};
use event_listener::EventListener;
pub use memory_region::{
local::{LocalMr, LocalMrReadAccess, LocalMrWriteAccess},
remote::{RemoteMr, RemoteMrReadAccess, RemoteMrWriteAccess},
MrAccess,
};
use mr_allocator::MrAllocator;
use protection_domain::ProtectionDomain;
use queue_pair::{QueuePair, QueuePairEndpoint};
use rdma_sys::ibv_access_flags;
use std::{alloc::Layout, fmt::Debug, io, sync::Arc, time::Duration};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream, ToSocketAddrs},
};
use tracing::debug;
#[macro_use]
extern crate lazy_static;
/// High-level memory-region access permission flags.
///
/// Each variant is translated by `RdmaBuilder::set_access` into the
/// corresponding `ibv_access_flags` constant from `rdma_sys`.
#[bitflags]
#[repr(u64)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum AccessFlag {
    /// Maps to `IBV_ACCESS_LOCAL_WRITE`
    LocalWrite,
    /// Maps to `IBV_ACCESS_REMOTE_WRITE`
    RemoteWrite,
    /// Maps to `IBV_ACCESS_REMOTE_READ`
    RemoteRead,
    /// Maps to `IBV_ACCESS_REMOTE_ATOMIC`
    RemoteAtomic,
    /// Maps to `IBV_ACCESS_MW_BIND`
    MwBind,
    /// Maps to `IBV_ACCESS_ZERO_BASED`
    ZeroBased,
    /// Maps to `IBV_ACCESS_ON_DEMAND`
    OnDemand,
    /// Maps to `IBV_ACCESS_HUGETLB`
    HugeTlb,
    /// Maps to `IBV_ACCESS_RELAXED_ORDERING`
    RelaxOrder,
}
/// Builder that configures and creates an `Rdma` instance.
pub struct RdmaBuilder {
    /// Name of the RDMA device to open; `None` is passed through to
    /// `Context::open` — presumably selecting a default device (TODO confirm)
    dev_name: Option<String>,
    /// ibverbs access flags applied when the queue pair is initialized
    access: ibv_access_flags,
    /// Completion queue capacity
    cq_size: u32,
    /// GID table index used for the queue pair
    gid_index: usize,
    /// Physical port number of the device
    port_num: u8,
}
impl RdmaBuilder {
    /// Create a builder initialized with the default configuration.
    #[must_use]
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Build an `Rdma` instance from the current configuration.
    #[inline]
    pub fn build(&self) -> io::Result<Rdma> {
        Rdma::new(
            self.dev_name.as_deref(),
            self.access,
            self.cq_size,
            self.port_num,
            self.gid_index,
        )
    }

    /// Set the name of the device to open.
    #[inline]
    #[must_use]
    pub fn set_dev(self, dev: &str) -> Self {
        Self {
            dev_name: Some(dev.to_owned()),
            ..self
        }
    }

    /// Set the completion queue capacity.
    #[inline]
    #[must_use]
    pub fn set_cq_size(self, cq_size: u32) -> Self {
        Self { cq_size, ..self }
    }

    /// Set the GID table index.
    #[inline]
    #[must_use]
    pub fn set_gid_index(self, gid_index: usize) -> Self {
        Self { gid_index, ..self }
    }

    /// Set the device port number.
    #[inline]
    #[must_use]
    pub fn set_port_num(self, port_num: u8) -> Self {
        Self { port_num, ..self }
    }

    /// Set the memory-region access flags, replacing any previously
    /// configured value.
    #[inline]
    #[must_use]
    pub fn set_access(mut self, flag: BitFlags<AccessFlag>) -> Self {
        // Translation table from the high-level flags to their ibverbs
        // counterparts; OR together the bits for every flag that is set.
        let mapping = [
            (AccessFlag::LocalWrite, ibv_access_flags::IBV_ACCESS_LOCAL_WRITE),
            (AccessFlag::RemoteWrite, ibv_access_flags::IBV_ACCESS_REMOTE_WRITE),
            (AccessFlag::RemoteRead, ibv_access_flags::IBV_ACCESS_REMOTE_READ),
            (AccessFlag::RemoteAtomic, ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC),
            (AccessFlag::MwBind, ibv_access_flags::IBV_ACCESS_MW_BIND),
            (AccessFlag::ZeroBased, ibv_access_flags::IBV_ACCESS_ZERO_BASED),
            (AccessFlag::OnDemand, ibv_access_flags::IBV_ACCESS_ON_DEMAND),
            (AccessFlag::HugeTlb, ibv_access_flags::IBV_ACCESS_HUGETLB),
            (AccessFlag::RelaxOrder, ibv_access_flags::IBV_ACCESS_RELAXED_ORDERING),
        ];
        let mut access = ibv_access_flags(0);
        for &(high, low) in mapping.iter() {
            if flag.contains(high) {
                access |= low;
            }
        }
        self.access = access;
        self
    }
}
impl Debug for RdmaBuilder {
    /// Format the builder configuration for diagnostics.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `access` is a raw ibverbs bitmask and is intentionally omitted;
        // all other configurable fields are shown so debug output reflects
        // the full builder state.
        f.debug_struct("RdmaBuilder")
            .field("dev_name", &self.dev_name)
            .field("cq_size", &self.cq_size)
            .field("gid_index", &self.gid_index)
            .field("port_num", &self.port_num)
            .finish()
    }
}
impl Default for RdmaBuilder {
    /// Default configuration: no explicit device, local write plus full
    /// remote (write/read/atomic) access, a 16-entry completion queue,
    /// GID index 0 and port 1.
    #[inline]
    fn default() -> Self {
        Self {
            dev_name: None,
            access: ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
                | ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
                | ibv_access_flags::IBV_ACCESS_REMOTE_READ
                | ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC,
            cq_size: 16,
            gid_index: 0,
            port_num: 1,
        }
    }
}
/// An established (or in-progress) RDMA connection.
///
/// Created through `Rdma::connect` on the client side or
/// `RdmaListener::accept` on the server side.
#[derive(Debug)]
pub struct Rdma {
    // Device context; kept alive for the lifetime of the connection
    #[allow(dead_code)]
    ctx: Arc<Context>,
    // Protection domain; kept alive for the lifetime of the connection
    #[allow(dead_code)]
    pd: Arc<ProtectionDomain>,
    // Allocator for local memory regions
    allocator: Arc<MrAllocator>,
    // Queue pair used for all RDMA operations
    qp: Arc<QueuePair>,
    // Message agent; `None` until the connection handshake completes
    agent: Option<Arc<Agent>>,
}
impl Rdma {
fn new(
dev_name: Option<&str>,
access: ibv_access_flags,
cq_size: u32,
port_num: u8,
gid_index: usize,
) -> io::Result<Self> {
let ctx = Arc::new(Context::open(dev_name, port_num, gid_index)?);
let ec = ctx.create_event_channel()?;
let cq = Arc::new(ctx.create_completion_queue(cq_size, ec)?);
let event_listener = EventListener::new(cq);
let pd = Arc::new(ctx.create_protection_domain()?);
let allocator = Arc::new(MrAllocator::new(Arc::<ProtectionDomain>::clone(&pd)));
let qp = Arc::new(
pd.create_queue_pair_builder()
.set_event_listener(event_listener)
.set_port_num(port_num)
.set_gid_index(gid_index)
.build()?,
);
qp.modify_to_init(access, port_num)?;
Ok(Self {
ctx,
pd,
qp,
agent: None,
allocator,
})
}
fn endpoint(&self) -> QueuePairEndpoint {
self.qp.endpoint()
}
fn qp_handshake(&mut self, remote: QueuePairEndpoint) -> io::Result<()> {
self.qp.modify_to_rtr(remote, 0, 1, 0x12)?;
debug!("rtr");
self.qp.modify_to_rts(0x12, 6, 6, 0, 1)?;
debug!("rts");
Ok(())
}
#[inline]
pub async fn send(&self, lm: &LocalMr) -> io::Result<()> {
self.agent
.as_ref()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Agent is not ready"))?
.send_data(lm, None)
.await
}
#[inline]
pub async fn send_with_imm(&self, lm: &LocalMr, imm: u32) -> io::Result<()> {
self.agent
.as_ref()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Agent is not ready"))?
.send_data(lm, Some(imm))
.await
}
#[inline]
pub async fn receive(&self) -> io::Result<LocalMr> {
let (lmr, _) = self.receive_with_imm().await?;
Ok(lmr)
}
#[inline]
pub async fn receive_with_imm(&self) -> io::Result<(LocalMr, Option<u32>)> {
self.agent
.as_ref()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Agent is not ready"))?
.receive_data()
.await
}
#[inline]
pub async fn receive_write_imm(&self) -> io::Result<u32> {
self.agent
.as_ref()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Agent is not ready"))?
.receive_imm()
.await
}
#[inline]
pub async fn read<LW, RR>(&self, lm: &mut LW, rm: &RR) -> io::Result<()>
where
LW: LocalMrWriteAccess,
RR: RemoteMrReadAccess,
{
self.qp.read(lm, rm).await
}
#[inline]
pub async fn write<LR, RW>(&self, lm: &LR, rm: &mut RW) -> io::Result<()>
where
LR: LocalMrReadAccess,
RW: RemoteMrWriteAccess,
{
self.qp.write(lm, rm, None).await
}
#[inline]
pub async fn write_with_imm<LR, RW>(&self, lm: &LR, rm: &mut RW, imm: u32) -> io::Result<()>
where
LR: LocalMrReadAccess,
RW: RemoteMrWriteAccess,
{
self.qp.write(lm, rm, Some(imm)).await
}
#[inline]
pub async fn connect<A: ToSocketAddrs>(
addr: A,
port_num: u8,
gid_index: usize,
max_message_length: usize,
) -> io::Result<Self> {
let mut rdma = RdmaBuilder::default()
.set_port_num(port_num)
.set_gid_index(gid_index)
.build()?;
let mut stream = TcpStream::connect(addr).await?;
let mut endpoint = bincode::serialize(&rdma.endpoint()).map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("failed to serailize the endpoint, {:?}", e),
)
})?;
stream.write_all(&endpoint).await?;
let _ = stream.read_exact(endpoint.as_mut()).await?;
let remote: QueuePairEndpoint = bincode::deserialize(&endpoint).map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("failed to deserailize the endpoint, {:?}", e),
)
})?;
rdma.qp_handshake(remote)?;
let agent = Arc::new(Agent::new(
Arc::<QueuePair>::clone(&rdma.qp),
Arc::<MrAllocator>::clone(&rdma.allocator),
max_message_length,
)?);
rdma.agent = Some(agent);
tokio::time::sleep(Duration::from_secs(1)).await;
Ok(rdma)
}
#[inline]
pub fn alloc_local_mr(&self, layout: Layout) -> io::Result<LocalMr> {
self.allocator.alloc(&layout)
}
#[inline]
pub async fn request_remote_mr(&self, layout: Layout) -> io::Result<RemoteMr> {
if let Some(ref agent) = self.agent {
agent.request_remote_mr(layout).await
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"Agent is not ready, please wait a while",
))
}
}
#[inline]
pub async fn send_local_mr(&self, mr: LocalMr) -> io::Result<()> {
if let Some(ref agent) = self.agent {
agent.send_local_mr(mr).await
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"Agent is not ready, please wait a while",
))
}
}
#[inline]
pub async fn send_remote_mr(&self, mr: RemoteMr) -> io::Result<()> {
if let Some(ref agent) = self.agent {
agent.send_remote_mr(mr).await
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"Agent is not ready, please wait a while",
))
}
}
#[inline]
pub async fn receive_local_mr(&self) -> io::Result<LocalMr> {
if let Some(ref agent) = self.agent {
agent.receive_local_mr().await
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"Agent is not ready, please wait a while",
))
}
}
#[inline]
pub async fn receive_remote_mr(&self) -> io::Result<RemoteMr> {
if let Some(ref agent) = self.agent {
agent.receive_remote_mr().await
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"Agent is not ready, please wait a while",
))
}
}
}
/// Server-side listener that accepts incoming `Rdma` connections.
///
/// The TCP listener is only used to exchange queue pair endpoint
/// information; data transfer happens over RDMA afterwards.
#[derive(Debug)]
pub struct RdmaListener {
    // TCP listener used for the endpoint-exchange handshake
    tcp_listener: TcpListener,
}
impl RdmaListener {
    /// Bind a TCP listener on `addr`, used to exchange queue pair
    /// endpoint information with connecting peers.
    #[inline]
    pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
        let tcp_listener = TcpListener::bind(addr).await?;
        Ok(Self { tcp_listener })
    }

    /// Accept one incoming connection: exchange queue pair endpoints over
    /// TCP, perform the queue pair handshake and start the agent.
    #[inline]
    pub async fn accept(
        &self,
        port_num: u8,
        gid_index: usize,
        max_message_length: usize,
    ) -> io::Result<Rdma> {
        let (mut stream, _) = self.tcp_listener.accept().await?;
        let mut rdma = RdmaBuilder::default()
            .set_port_num(port_num)
            .set_gid_index(gid_index)
            .build()?;
        // Both sides serialize the same endpoint type, so the size of the
        // local endpoint also tells how many bytes the peer will send.
        let endpoint_size = bincode::serialized_size(&rdma.endpoint()).map_err(|e| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("Endpoint serialization failed, {:?}", e),
            )
        })?;
        let mut remote = vec![0_u8; endpoint_size.cast()];
        let _ = stream.read_exact(remote.as_mut()).await?;
        let remote: QueuePairEndpoint = bincode::deserialize(&remote).map_err(|e| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("failed to deserialize remote endpoint, {:?}", e),
            )
        })?;
        // Fixed: this error previously claimed a *deserialize* of the
        // *remote* endpoint failed, but it is a serialize of the local one.
        let local = bincode::serialize(&rdma.endpoint()).map_err(|e| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("failed to serialize local endpoint, {:?}", e),
            )
        })?;
        stream.write_all(&local).await?;
        rdma.qp_handshake(remote)?;
        debug!("handshake done");
        let agent = Arc::new(Agent::new(
            Arc::<QueuePair>::clone(&rdma.qp),
            Arc::<MrAllocator>::clone(&rdma.allocator),
            max_message_length,
        )?);
        rdma.agent = Some(agent);
        // NOTE(review): fixed delay, apparently to let the peer finish its
        // setup before traffic starts — TODO replace with an explicit
        // readiness message.
        tokio::time::sleep(Duration::from_secs(1)).await;
        Ok(rdma)
    }
}