use std::net::SocketAddr;
use std::os::unix::io::RawFd;
use std::sync::Arc;
use rdma_io_sys::ibverbs::*;
use rdma_io_sys::rdmacm::*;
use crate::Result;
use crate::cq::CompletionQueue;
use crate::device::Context;
use crate::error::{from_ptr, from_ret_errno};
use crate::pd::ProtectionDomain;
use crate::qp::QpInitAttr;
/// Port space an RDMA CM identifier is created in.
///
/// Each variant stands for one librdmacm `RDMA_PS_*` constant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PortSpace {
    /// `RDMA_PS_TCP`.
    Tcp,
    /// `RDMA_PS_UDP`.
    Udp,
    /// `RDMA_PS_IB`.
    Ib,
    /// `RDMA_PS_IPOIB`.
    Ipoib,
}
impl PortSpace {
fn as_raw(self) -> u32 {
match self {
Self::Tcp => RDMA_PS_TCP,
Self::Udp => RDMA_PS_UDP,
Self::Ib => RDMA_PS_IB,
Self::Ipoib => RDMA_PS_IPOIB,
}
}
}
/// Kind of a connection-manager event read from an event channel.
///
/// Each named variant corresponds to one librdmacm `RDMA_CM_EVENT_*` code;
/// codes this wrapper does not recognise are kept verbatim in `Unknown`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CmEventType {
    /// `RDMA_CM_EVENT_ADDR_RESOLVED`.
    AddrResolved,
    /// `RDMA_CM_EVENT_ADDR_ERROR`.
    AddrError,
    /// `RDMA_CM_EVENT_ROUTE_RESOLVED`.
    RouteResolved,
    /// `RDMA_CM_EVENT_ROUTE_ERROR`.
    RouteError,
    /// `RDMA_CM_EVENT_CONNECT_REQUEST`.
    ConnectRequest,
    /// `RDMA_CM_EVENT_CONNECT_RESPONSE`.
    ConnectResponse,
    /// `RDMA_CM_EVENT_CONNECT_ERROR`.
    ConnectError,
    /// `RDMA_CM_EVENT_UNREACHABLE`.
    Unreachable,
    /// `RDMA_CM_EVENT_REJECTED`.
    Rejected,
    /// `RDMA_CM_EVENT_ESTABLISHED`.
    Established,
    /// `RDMA_CM_EVENT_DISCONNECTED`.
    Disconnected,
    /// `RDMA_CM_EVENT_DEVICE_REMOVAL`.
    DeviceRemoval,
    /// `RDMA_CM_EVENT_MULTICAST_JOIN`.
    MulticastJoin,
    /// `RDMA_CM_EVENT_MULTICAST_ERROR`.
    MulticastError,
    /// `RDMA_CM_EVENT_ADDR_CHANGE`.
    AddrChange,
    /// `RDMA_CM_EVENT_TIMEWAIT_EXIT`.
    TimewaitExit,
    /// Any other raw event code.
    Unknown(u32),
}
impl CmEventType {
fn from_raw(v: u32) -> Self {
match v {
RDMA_CM_EVENT_ADDR_RESOLVED => Self::AddrResolved,
RDMA_CM_EVENT_ADDR_ERROR => Self::AddrError,
RDMA_CM_EVENT_ROUTE_RESOLVED => Self::RouteResolved,
RDMA_CM_EVENT_ROUTE_ERROR => Self::RouteError,
RDMA_CM_EVENT_CONNECT_REQUEST => Self::ConnectRequest,
RDMA_CM_EVENT_CONNECT_RESPONSE => Self::ConnectResponse,
RDMA_CM_EVENT_CONNECT_ERROR => Self::ConnectError,
RDMA_CM_EVENT_UNREACHABLE => Self::Unreachable,
RDMA_CM_EVENT_REJECTED => Self::Rejected,
RDMA_CM_EVENT_ESTABLISHED => Self::Established,
RDMA_CM_EVENT_DISCONNECTED => Self::Disconnected,
RDMA_CM_EVENT_DEVICE_REMOVAL => Self::DeviceRemoval,
RDMA_CM_EVENT_MULTICAST_JOIN => Self::MulticastJoin,
RDMA_CM_EVENT_MULTICAST_ERROR => Self::MulticastError,
RDMA_CM_EVENT_ADDR_CHANGE => Self::AddrChange,
RDMA_CM_EVENT_TIMEWAIT_EXIT => Self::TimewaitExit,
other => Self::Unknown(other),
}
}
}
/// Connection parameters used by [`CmId::connect`] and [`CmId::accept`].
///
/// Each field is copied into the identically named field of librdmacm's
/// `rdma_conn_param`; see `rdma_connect(3)` for what each one controls.
#[derive(Clone, Debug)]
pub struct ConnParam {
    /// Copied to `rdma_conn_param::responder_resources`.
    pub responder_resources: u8,
    /// Copied to `rdma_conn_param::initiator_depth`.
    pub initiator_depth: u8,
    /// Copied to `rdma_conn_param::retry_count`.
    pub retry_count: u8,
    /// Copied to `rdma_conn_param::rnr_retry_count`.
    pub rnr_retry_count: u8,
}

impl Default for ConnParam {
    /// Uses 1 for `responder_resources` and `initiator_depth`, and 7 for both
    /// retry counters.
    fn default() -> Self {
        ConnParam {
            retry_count: 7,
            rnr_retry_count: 7,
            responder_resources: 1,
            initiator_depth: 1,
        }
    }
}
impl ConnParam {
fn to_raw(&self) -> rdma_conn_param {
rdma_conn_param {
responder_resources: self.responder_resources,
initiator_depth: self.initiator_depth,
retry_count: self.retry_count,
rnr_retry_count: self.rnr_retry_count,
..Default::default()
}
}
}
pub struct EventChannel {
inner: *mut rdma_event_channel,
}
unsafe impl Send for EventChannel {}
unsafe impl Sync for EventChannel {}
impl Drop for EventChannel {
fn drop(&mut self) {
unsafe { rdma_destroy_event_channel(self.inner) };
}
}
impl EventChannel {
pub fn new() -> Result<Self> {
let ch = from_ptr(unsafe { rdma_create_event_channel() })?;
Ok(Self { inner: ch })
}
pub fn get_event(&self) -> Result<CmEvent> {
let mut event: *mut rdma_cm_event = std::ptr::null_mut();
from_ret_errno(unsafe { rdma_get_cm_event(self.inner, &mut event) })?;
Ok(CmEvent { inner: event })
}
pub fn try_get_event(&self) -> Result<CmEvent> {
let mut event: *mut rdma_cm_event = std::ptr::null_mut();
let ret = unsafe { rdma_get_cm_event(self.inner, &mut event) };
if ret != 0 {
let e = std::io::Error::last_os_error();
if e.kind() == std::io::ErrorKind::WouldBlock {
return Err(crate::Error::WouldBlock);
}
return Err(crate::Error::Verbs(e));
}
Ok(CmEvent { inner: event })
}
pub fn fd(&self) -> RawFd {
unsafe { (*self.inner).fd }
}
pub fn set_nonblocking(&self) -> Result<()> {
let fd = self.fd();
let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
if flags < 0 {
return Err(crate::Error::Verbs(std::io::Error::last_os_error()));
}
let ret = unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) };
if ret < 0 {
return Err(crate::Error::Verbs(std::io::Error::last_os_error()));
}
Ok(())
}
pub fn as_raw(&self) -> *mut rdma_event_channel {
self.inner
}
}
/// A connection-manager event obtained from an [`EventChannel`].
///
/// The event is acknowledged with `rdma_ack_cm_event` exactly once: either
/// explicitly via [`CmEvent::ack`] or implicitly when the value is dropped.
/// Any failure of that call is logged through `tracing`.
pub struct CmEvent {
    inner: *mut rdma_cm_event,
}

// SAFETY: NOTE(review) — asserts the event may be moved to, inspected on, and
// acknowledged from another thread. `Sync` is not implemented.
unsafe impl Send for CmEvent {}

impl CmEvent {
    /// Returns the event type decoded from the raw `event` field.
    pub fn event_type(&self) -> CmEventType {
        CmEventType::from_raw(unsafe { (*self.inner).event })
    }

    /// Returns the raw `status` field of the event.
    pub fn status(&self) -> i32 {
        unsafe { (*self.inner).status }
    }

    /// Returns the raw `rdma_cm_id` the event refers to.
    ///
    /// The pointer is not owned by this wrapper; use `CmId::from_raw` to wrap
    /// it with the appropriate ownership.
    pub fn cm_id_raw(&self) -> *mut rdma_cm_id {
        unsafe { (*self.inner).id }
    }

    /// Acknowledges the event, handing it back to librdmacm.
    ///
    /// This is equivalent to dropping the event. A failed
    /// `rdma_ack_cm_event` is logged rather than silently ignored.
    pub fn ack(self) {
        // `Drop` performs the single `rdma_ack_cm_event` call and logs
        // failures; routing through it keeps one acknowledgement path.
        drop(self);
    }
}

impl Drop for CmEvent {
    fn drop(&mut self) {
        // SAFETY: `inner` comes from a successful `rdma_get_cm_event`, and
        // this is the only place it is acknowledged (`ack` consumes `self`
        // and ends up here).
        let ret = unsafe { rdma_ack_cm_event(self.inner) };
        if ret != 0 {
            tracing::error!(
                "rdma_ack_cm_event failed: {}",
                std::io::Error::last_os_error()
            );
        }
    }
}
/// An RDMA CM identifier (`rdma_cm_id`).
///
/// When `owned` is set, the identifier is released with `rdma_destroy_id` on
/// drop. Identifiers wrapped with `CmId::from_raw(.., false)` are left
/// untouched.
pub struct CmId {
    pub(crate) inner: *mut rdma_cm_id,
    owned: bool,
}

// SAFETY: NOTE(review) — asserts the identifier may be moved and shared across
// threads; no locking is done here, so this relies on librdmacm's own
// thread-safety.
unsafe impl Send for CmId {}
unsafe impl Sync for CmId {}

impl Drop for CmId {
    fn drop(&mut self) {
        // Borrowed identifiers belong to someone else.
        if !self.owned {
            return;
        }
        // SAFETY: an owned `inner` is destroyed only here, exactly once.
        if unsafe { rdma_destroy_id(self.inner) } != 0 {
            tracing::error!(
                "rdma_destroy_id failed: {}",
                std::io::Error::last_os_error()
            );
        }
    }
}
impl CmId {
pub fn new(channel: &EventChannel, port_space: PortSpace) -> Result<Self> {
let mut id: *mut rdma_cm_id = std::ptr::null_mut();
from_ret_errno(unsafe {
rdma_create_id(
channel.inner,
&mut id,
std::ptr::null_mut(),
port_space.as_raw(),
)
})?;
Ok(Self {
inner: id,
owned: true,
})
}
pub unsafe fn from_raw(id: *mut rdma_cm_id, owned: bool) -> Self {
Self { inner: id, owned }
}
pub fn resolve_addr(
&self,
src: Option<&SocketAddr>,
dst: &SocketAddr,
timeout_ms: i32,
) -> Result<()> {
let (src_ptr, dst_sa) = sockaddr_args(src, dst);
from_ret_errno(unsafe {
rdma_resolve_addr(self.inner, src_ptr, dst_sa.as_ptr() as *mut _, timeout_ms)
})
}
pub fn resolve_route(&self, timeout_ms: i32) -> Result<()> {
from_ret_errno(unsafe { rdma_resolve_route(self.inner, timeout_ms) })
}
pub fn listen(&self, addr: &SocketAddr, backlog: i32) -> Result<()> {
let sa = to_sockaddr_storage(addr);
from_ret_errno(unsafe { rdma_bind_addr(self.inner, sa.as_ptr() as *mut _) })?;
from_ret_errno(unsafe { rdma_listen(self.inner, backlog) })
}
pub fn create_qp(
&self,
pd: &Arc<ProtectionDomain>,
init_attr: &QpInitAttr,
) -> Result<CmQueuePair> {
self.create_qp_with_cq(pd, init_attr, None, None)
}
pub fn create_qp_with_cq(
&self,
pd: &Arc<ProtectionDomain>,
init_attr: &QpInitAttr,
send_cq: Option<&Arc<CompletionQueue>>,
recv_cq: Option<&Arc<CompletionQueue>>,
) -> Result<CmQueuePair> {
let mut raw_attr = ibv_qp_init_attr {
send_cq: send_cq.map_or(std::ptr::null_mut(), |cq| cq.inner),
recv_cq: recv_cq.map_or(std::ptr::null_mut(), |cq| cq.inner),
cap: ibv_qp_cap {
max_send_wr: init_attr.max_send_wr,
max_recv_wr: init_attr.max_recv_wr,
max_send_sge: init_attr.max_send_sge,
max_recv_sge: init_attr.max_recv_sge,
max_inline_data: init_attr.max_inline_data,
},
qp_type: init_attr.qp_type.as_raw(),
sq_sig_all: i32::from(init_attr.sq_sig_all),
..Default::default()
};
from_ret_errno(unsafe { rdma_create_qp(self.inner, pd.inner, &mut raw_attr) })?;
Ok(CmQueuePair {
qp: self.qp_raw(),
cm_id_raw: self.inner,
_pd: Arc::clone(pd),
_send_cq: send_cq.map(Arc::clone),
_recv_cq: recv_cq.map(Arc::clone),
})
}
pub fn connect(&self, param: &ConnParam) -> Result<()> {
let mut raw = param.to_raw();
from_ret_errno(unsafe { rdma_connect(self.inner, &mut raw) })
}
pub fn accept(&self, param: &ConnParam) -> Result<()> {
let mut raw = param.to_raw();
from_ret_errno(unsafe { rdma_accept(self.inner, &mut raw) })
}
pub fn disconnect(&self) -> Result<()> {
from_ret_errno(unsafe { rdma_disconnect(self.inner) })
}
pub fn qp_num(&self) -> Option<u32> {
let qp = unsafe { (*self.inner).qp };
if qp.is_null() {
None
} else {
Some(unsafe { (*qp).qp_num })
}
}
pub fn qp_raw(&self) -> *mut ibv_qp {
unsafe { (*self.inner).qp }
}
pub fn verbs_context(&self) -> Option<Arc<Context>> {
let ctx = unsafe { (*self.inner).verbs };
if ctx.is_null() {
None
} else {
Some(Arc::new(unsafe { Context::from_raw(ctx, false) }))
}
}
pub fn alloc_pd(&self) -> Result<Arc<ProtectionDomain>> {
let ctx = self.verbs_context().ok_or(crate::Error::InvalidArg(
"CM ID has no verbs context (resolve_addr first)".into(),
))?;
ProtectionDomain::new(ctx)
}
pub fn as_raw(&self) -> *mut rdma_cm_id {
self.inner
}
pub fn migrate(&self, new_channel: &EventChannel) -> Result<()> {
from_ret_errno(unsafe { rdma_migrate_id(self.inner, new_channel.as_raw()) })
}
pub fn peer_addr(&self) -> Option<SocketAddr> {
let sa = unsafe { &(*self.inner).route.addr.rdma_addr__anon_1.dst_addr };
unsafe { sockaddr_to_std(sa as *const _ as *const _) }
}
pub fn local_addr(&self) -> Option<SocketAddr> {
let sa = unsafe { &(*self.inner).route.addr.rdma_addr__anon_0.src_addr };
unsafe { sockaddr_to_std(sa as *const _ as *const _) }
}
}
/// A queue pair created through [`CmId::create_qp`] or
/// [`CmId::create_qp_with_cq`].
///
/// Holds strong references to the protection domain and to any caller-supplied
/// completion queues. `Drop::drop` runs before these fields are released, so
/// they outlive the QP. Only a raw pointer to the owning `rdma_cm_id` is kept,
/// and `Drop` calls `rdma_destroy_qp` through it. Drop this value before the
/// `CmId` it came from.
pub struct CmQueuePair {
    qp: *mut ibv_qp,
    cm_id_raw: *mut rdma_cm_id,
    _pd: Arc<ProtectionDomain>,
    _send_cq: Option<Arc<CompletionQueue>>,
    _recv_cq: Option<Arc<CompletionQueue>>,
}

// SAFETY: NOTE(review) — asserts the QP handle may be moved and shared across
// threads; no locking is done here.
unsafe impl Send for CmQueuePair {}
unsafe impl Sync for CmQueuePair {}

impl Drop for CmQueuePair {
    fn drop(&mut self) {
        // SAFETY: relies on the originating `rdma_cm_id` still being alive
        // (see the type-level docs).
        unsafe { rdma_destroy_qp(self.cm_id_raw) };
    }
}

impl CmQueuePair {
    /// Returns the raw `ibv_qp` pointer; ownership is unchanged.
    pub fn as_raw(&self) -> *mut ibv_qp {
        self.qp
    }

    /// Returns the QP number.
    pub fn qp_num(&self) -> u32 {
        let qp = self.as_raw();
        // SAFETY: `qp` is the QP created alongside this value and lives until
        // `Drop` runs.
        unsafe { (*qp).qp_num }
    }
}
/// `sa_family` tag of an IPv4 socket address (Linux `AF_INET`).
const AF_INET: u16 = 0x02;
/// `sa_family` tag of an IPv6 socket address (Linux `AF_INET6`).
const AF_INET6: u16 = 0x0A;
/// Serialises `addr` into a zero-filled, `sockaddr_storage`-sized buffer laid
/// out as a C `sockaddr_in` (IPv4) or `sockaddr_in6` (IPv6).
///
/// The port, and for IPv4 the address, are stored in network byte order.
/// The IPv6 flow info and scope id are copied exactly as `SocketAddrV6`
/// reports them.
fn to_sockaddr_storage(addr: &SocketAddr) -> SockAddrBuf {
    use bnd_linux::libc::posix::{inet, socket};

    /// Copies the raw bytes of `value` to the front of `dst`.
    fn write_front<T>(dst: &mut [u8], value: &T) {
        // SAFETY: only called below with `sockaddr_in` / `sockaddr_in6`, both
        // smaller than the `sockaddr_storage`-sized `dst`. Source and
        // destination are distinct objects, so they cannot overlap.
        unsafe {
            std::ptr::copy_nonoverlapping(
                (value as *const T).cast::<u8>(),
                dst.as_mut_ptr(),
                std::mem::size_of::<T>(),
            );
        }
    }

    let mut bytes = [0u8; std::mem::size_of::<socket::sockaddr_storage>()];
    match addr {
        SocketAddr::V4(v4) => {
            let sin = inet::sockaddr_in {
                sin_family: AF_INET,
                sin_port: v4.port().to_be(),
                sin_addr: inet::in_addr {
                    // `octets()` is already in network order; keep that byte
                    // layout in memory.
                    s_addr: u32::from_ne_bytes(v4.ip().octets()),
                },
                ..Default::default()
            };
            write_front(&mut bytes, &sin);
        }
        SocketAddr::V6(v6) => {
            let sin6 = inet::sockaddr_in6 {
                sin6_family: AF_INET6,
                sin6_port: v6.port().to_be(),
                sin6_flowinfo: v6.flowinfo(),
                sin6_addr: inet::in6_addr {
                    __in6_u: inet::in6_addr___in6_u {
                        __u6_addr8: v6.ip().octets(),
                    },
                },
                sin6_scope_id: v6.scope_id(),
            };
            write_front(&mut bytes, &sin6);
        }
    }
    SockAddrBuf(bytes)
}
/// Owned buffer the size of a C `sockaddr_storage`, holding a serialised
/// socket address.
///
/// A bare `[u8; N]` is only 1-byte aligned. The pointer from
/// [`SockAddrBuf::as_ptr`], however, is read by C code as a `sockaddr` /
/// `sockaddr_in6`, whose fields need natural alignment. `repr(align(8))`
/// gives the buffer at least the alignment of `sockaddr_storage`.
#[repr(align(8))]
struct SockAddrBuf([u8; std::mem::size_of::<bnd_linux::libc::posix::socket::sockaddr_storage>()]);

// Compile-time guard: fails the build if a target's `sockaddr_storage` ever
// needs stricter alignment than the attribute above provides.
const _: () = assert!(
    std::mem::align_of::<SockAddrBuf>()
        >= std::mem::align_of::<bnd_linux::libc::posix::socket::sockaddr_storage>()
);

impl SockAddrBuf {
    /// Returns a pointer to the start of the buffer, typed as a generic
    /// `sockaddr`; valid for as long as `self` is.
    fn as_ptr(&self) -> *const bnd_linux::libc::posix::socket::sockaddr {
        self.0.as_ptr().cast()
    }
}
/// Converts the destination of an address-resolution request into a
/// `sockaddr` buffer.
///
/// NOTE(review): the returned source pointer is always null, because the
/// source argument is accepted but never converted. The return type has no
/// slot for a source buffer, so a non-null source pointer could not stay
/// valid after this function returns. Callers that need to bind a source
/// address must build and hold their own buffer.
fn sockaddr_args(
    _src: Option<&SocketAddr>,
    dst: &SocketAddr,
) -> (*mut bnd_linux::libc::posix::socket::sockaddr, SockAddrBuf) {
    (std::ptr::null_mut(), to_sockaddr_storage(dst))
}
unsafe fn sockaddr_to_std(
sa: *const bnd_linux::libc::posix::socket::sockaddr,
) -> Option<SocketAddr> {
unsafe {
let family = (*sa).sa_family;
if family == AF_INET {
let sin = &*(sa as *const bnd_linux::libc::posix::inet::sockaddr_in);
let ip = std::net::Ipv4Addr::from(u32::from_be(sin.sin_addr.s_addr));
let port = u16::from_be(sin.sin_port);
Some(SocketAddr::V4(std::net::SocketAddrV4::new(ip, port)))
} else if family == AF_INET6 {
let sin6 = &*(sa as *const bnd_linux::libc::posix::inet::sockaddr_in6);
let ip = std::net::Ipv6Addr::from(sin6.sin6_addr.__in6_u.__u6_addr8);
let port = u16::from_be(sin6.sin6_port);
Some(SocketAddr::V6(std::net::SocketAddrV6::new(
ip,
port,
sin6.sin6_flowinfo,
sin6.sin6_scope_id,
)))
} else {
None
}
}
}