use super::context::RdmaContext;
use super::mr::{RdmaMr, SendCq, wait_for_completion};
use ibverbs_sys::{ibv_qp_attr_mask, ibv_qp_state, ibv_send_flags, ibv_wr_opcode};
use nexar::error::{NexarError, Result};
use std::os::raw::c_int;
use std::ptr;
use std::sync::Arc;
/// Addressing information for one side of an RDMA queue pair.
///
/// The value is `Copy` and has a fixed 22-byte wire form (see
/// [`RdmaEndpoint::to_bytes`] / [`RdmaEndpoint::from_bytes`]).
#[derive(Debug, Clone, Copy)]
pub struct RdmaEndpoint {
    /// Queue pair number.
    pub qp_num: u32,
    /// Local identifier (LID).
    pub lid: u16,
    /// Global identifier (GID), 16 raw bytes.
    pub gid: [u8; 16],
}

/// Length of the serialized endpoint: 4 bytes `qp_num` + 2 bytes `lid` + 16 bytes `gid`.
const ENDPOINT_SIZE: usize = 22;

impl RdmaEndpoint {
    /// Serializes the endpoint as `qp_num` (little-endian), `lid` (little-endian),
    /// then the 16 `gid` bytes verbatim.
    pub fn to_bytes(&self) -> [u8; ENDPOINT_SIZE] {
        let mut wire = [0u8; ENDPOINT_SIZE];
        {
            // Carve the buffer into its three fields so each copy is length-checked.
            let (qp_part, tail) = wire.split_at_mut(4);
            let (lid_part, gid_part) = tail.split_at_mut(2);
            qp_part.copy_from_slice(&self.qp_num.to_le_bytes());
            lid_part.copy_from_slice(&self.lid.to_le_bytes());
            gid_part.copy_from_slice(&self.gid);
        }
        wire
    }

    /// Reconstructs an endpoint from the layout produced by [`RdmaEndpoint::to_bytes`].
    pub fn from_bytes(buf: &[u8; ENDPOINT_SIZE]) -> Self {
        let (qp_part, tail) = buf.split_at(4);
        let (lid_part, gid_part) = tail.split_at(2);

        let mut qp_raw = [0u8; 4];
        qp_raw.copy_from_slice(qp_part);
        let mut lid_raw = [0u8; 2];
        lid_raw.copy_from_slice(lid_part);
        let mut gid = [0u8; 16];
        gid.copy_from_slice(gid_part);

        RdmaEndpoint {
            qp_num: u32::from_le_bytes(qp_raw),
            lid: u16::from_le_bytes(lid_raw),
            gid,
        }
    }
}
/// A queue pair bundled with its completion queues and local endpoint, held
/// until `complete` moves the QP through the RTR and RTS states.
///
/// If the value is dropped while `qp` is still non-null, the `Drop` impl
/// destroys the queue pair.
pub struct PreparedRdmaConnection {
/// Raw queue pair handle; null once it has been destroyed or handed over.
pub(super) qp: *mut ibverbs_sys::ibv_qp,
/// Raw completion queue handle forwarded to `RdmaConnection::send_cq`.
pub(super) send_cq: *mut ibverbs_sys::ibv_cq,
/// Raw completion queue handle forwarded to `RdmaConnection::recv_cq`.
pub(super) recv_cq: *mut ibverbs_sys::ibv_cq,
/// Local endpoint descriptor returned by `endpoint()`.
pub(super) local_ep: RdmaEndpoint,
/// Shared RDMA context; cloned into the resulting `RdmaConnection`.
pub(super) ctx: Arc<RdmaContext>,
}
// NOTE(review): these impls assert that the raw `ibv_qp` / `ibv_cq` pointers may
// be moved to and referenced from other threads. Nothing in this file
// establishes that — confirm against the verbs provider's thread-safety rules.
unsafe impl Send for PreparedRdmaConnection {}
unsafe impl Sync for PreparedRdmaConnection {}
impl PreparedRdmaConnection {
pub fn endpoint(&self) -> RdmaEndpoint {
self.local_ep
}
pub fn complete(mut self, remote: RdmaEndpoint) -> Result<RdmaConnection> {
unsafe {
let mut attr: ibverbs_sys::ibv_qp_attr = std::mem::zeroed();
attr.qp_state = ibv_qp_state::IBV_QPS_RTR;
attr.path_mtu = ibverbs_sys::IBV_MTU_4096;
attr.dest_qp_num = remote.qp_num;
attr.rq_psn = 0;
attr.max_dest_rd_atomic = 4;
attr.min_rnr_timer = 12;
attr.ah_attr.is_global = 1;
attr.ah_attr.grh.dgid.raw = remote.gid;
attr.ah_attr.grh.sgid_index = 0;
attr.ah_attr.grh.hop_limit = 64;
attr.ah_attr.grh.traffic_class = 0;
attr.ah_attr.dlid = remote.lid;
attr.ah_attr.sl = 0;
attr.ah_attr.src_path_bits = 0;
attr.ah_attr.port_num = 1;
let mask = ibv_qp_attr_mask::IBV_QP_STATE
| ibv_qp_attr_mask::IBV_QP_AV
| ibv_qp_attr_mask::IBV_QP_PATH_MTU
| ibv_qp_attr_mask::IBV_QP_DEST_QPN
| ibv_qp_attr_mask::IBV_QP_RQ_PSN
| ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC
| ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER;
let rc = ibverbs_sys::ibv_modify_qp(self.qp, &mut attr, mask.0 as c_int);
if rc != 0 {
ibverbs_sys::ibv_destroy_qp(self.qp);
self.qp = ptr::null_mut();
return Err(NexarError::device(format!(
"RDMA: ibv_modify_qp to RTR failed (rc={rc})"
)));
}
let mut attr: ibverbs_sys::ibv_qp_attr = std::mem::zeroed();
attr.qp_state = ibv_qp_state::IBV_QPS_RTS;
attr.sq_psn = 0;
attr.timeout = 14;
attr.retry_cnt = 7;
attr.rnr_retry = 7;
attr.max_rd_atomic = 4;
let mask = ibv_qp_attr_mask::IBV_QP_STATE
| ibv_qp_attr_mask::IBV_QP_TIMEOUT
| ibv_qp_attr_mask::IBV_QP_RETRY_CNT
| ibv_qp_attr_mask::IBV_QP_RNR_RETRY
| ibv_qp_attr_mask::IBV_QP_SQ_PSN
| ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC;
let rc = ibverbs_sys::ibv_modify_qp(self.qp, &mut attr, mask.0 as c_int);
if rc != 0 {
ibverbs_sys::ibv_destroy_qp(self.qp);
self.qp = ptr::null_mut();
return Err(NexarError::device(format!(
"RDMA: ibv_modify_qp to RTS failed (rc={rc})"
)));
}
let qp = self.qp;
self.qp = ptr::null_mut();
Ok(RdmaConnection {
qp,
send_cq: self.send_cq,
recv_cq: self.recv_cq,
ctx: Arc::clone(&self.ctx),
})
}
}
}
impl Drop for PreparedRdmaConnection {
    /// Destroys the queue pair if this value still holds one.
    ///
    /// `complete` nulls `qp` after it has either destroyed the queue pair or
    /// moved it into an `RdmaConnection`, so a null pointer means there is
    /// nothing left to release here.
    fn drop(&mut self) {
        if self.qp.is_null() {
            return;
        }
        // SAFETY: `qp` is non-null, so it has not been destroyed or handed over.
        unsafe {
            ibverbs_sys::ibv_destroy_qp(self.qp);
        }
        self.qp = ptr::null_mut();
    }
}
/// A connected queue pair together with the completion queues used to wait
/// for posted send and receive work requests.
///
/// The `Drop` impl destroys `qp` when it is non-null; the completion queues
/// are not destroyed by this type.
pub struct RdmaConnection {
/// Raw queue pair handle that work requests are posted to.
qp: *mut ibverbs_sys::ibv_qp,
/// Completion queue polled after posting a send.
send_cq: *mut ibverbs_sys::ibv_cq,
/// Completion queue polled after posting a receive.
recv_cq: *mut ibverbs_sys::ibv_cq,
/// Shared RDMA context; supplies `send_async_fd` / `recv_async_fd` to the async waits.
ctx: Arc<RdmaContext>,
}
// NOTE(review): these impls assert that the raw `ibv_qp` / `ibv_cq` pointers may
// be used across threads. Nothing in this file establishes that — confirm
// against the verbs provider's thread-safety rules.
unsafe impl Send for RdmaConnection {}
unsafe impl Sync for RdmaConnection {}
/// Maximum time any send/recv wait (sync or async) is given before it reports a timeout.
const CQ_POLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
impl RdmaConnection {
    /// Posts a signaled SEND covering the whole of `mr`, then awaits
    /// `wait_for_completion` on the send CQ for up to [`CQ_POLL_TIMEOUT`].
    ///
    /// # Errors
    /// Fails if `mr` is larger than a scatter/gather entry can describe, if the
    /// provider rejects the work request, or if the wait reports an error.
    ///
    /// # Panics
    /// Panics if the provider's ops table has no `post_send` entry.
    pub async fn send_async(&mut self, mr: &RdmaMr, wr_id: u64) -> Result<()> {
        self.post_send_wr(mr, wr_id)?;
        wait_for_completion(
            SendCq::new(self.send_cq),
            &self.ctx.send_async_fd,
            CQ_POLL_TIMEOUT,
        )
        .await
    }

    /// Posts a receive covering the whole of `mr`, then awaits
    /// `wait_for_completion` on the receive CQ for up to [`CQ_POLL_TIMEOUT`].
    ///
    /// # Errors
    /// Fails if `mr` is larger than a scatter/gather entry can describe, if the
    /// provider rejects the work request, or if the wait reports an error.
    ///
    /// # Panics
    /// Panics if the provider's ops table has no `post_recv` entry.
    pub async fn recv_async(&mut self, mr: &RdmaMr, wr_id: u64) -> Result<()> {
        self.post_recv_wr(mr, wr_id)?;
        wait_for_completion(
            SendCq::new(self.recv_cq),
            &self.ctx.recv_async_fd,
            CQ_POLL_TIMEOUT,
        )
        .await
    }

    /// Blocking variant of [`RdmaConnection::send_async`]: posts the SEND and
    /// polls the send CQ on the calling thread.
    ///
    /// Returns as soon as any completion is polled from the send CQ; the
    /// completion's `wr_id` is not compared with the one passed in.
    ///
    /// # Errors
    /// Fails if `mr` is too large for one scatter/gather entry, if posting or
    /// polling fails, if the completion carries an error status, or if no
    /// completion arrives within [`CQ_POLL_TIMEOUT`].
    ///
    /// # Panics
    /// Panics if the provider's ops table lacks `post_send` or `poll_cq`.
    pub fn send(&mut self, mr: &RdmaMr, wr_id: u64) -> Result<()> {
        self.post_send_wr(mr, wr_id)?;
        self.wait_send_completion_sync()
    }

    /// Blocking variant of [`RdmaConnection::recv_async`]: posts the receive
    /// and polls the receive CQ on the calling thread.
    ///
    /// Returns as soon as any completion is polled from the receive CQ; the
    /// completion's `wr_id` is not compared with the one passed in.
    ///
    /// # Errors
    /// Fails if `mr` is too large for one scatter/gather entry, if posting or
    /// polling fails, if the completion carries an error status, or if no
    /// completion arrives within [`CQ_POLL_TIMEOUT`].
    ///
    /// # Panics
    /// Panics if the provider's ops table lacks `post_recv` or `poll_cq`.
    pub fn recv(&mut self, mr: &RdmaMr, wr_id: u64) -> Result<()> {
        self.post_recv_wr(mr, wr_id)?;
        self.wait_recv_completion_sync()
    }

    /// Builds the single scatter/gather entry describing all of `mr`.
    ///
    /// The SGE `length` field is 32 bits wide, so a region that does not fit is
    /// rejected instead of being silently truncated to its low 32 bits.
    fn sge_for(mr: &RdmaMr) -> Result<ibverbs_sys::ibv_sge> {
        let length = <u32 as std::convert::TryFrom<_>>::try_from(mr.size).map_err(|_| {
            NexarError::device(format!(
                "RDMA: memory region of {} bytes does not fit the 32-bit SGE length",
                mr.size
            ))
        })?;
        // SAFETY: `ibv_sge` is a plain C struct for which all-zero is a valid
        // value; the remaining statements only read fields/accessors of `mr`.
        unsafe {
            let mut sge: ibverbs_sys::ibv_sge = std::mem::zeroed();
            sge.addr = mr.ptr as u64;
            sge.length = length;
            sge.lkey = mr.lkey();
            Ok(sge)
        }
    }

    /// Posts one signaled SEND work request for `mr` on this queue pair.
    fn post_send_wr(&self, mr: &RdmaMr, wr_id: u64) -> Result<()> {
        let mut sge = Self::sge_for(mr)?;
        // SAFETY: assumes `self.qp` is the live queue pair owned by this
        // connection; `sge` and `wr` outlive the provider call that reads them.
        unsafe {
            let mut wr: ibverbs_sys::ibv_send_wr = std::mem::zeroed();
            wr.wr_id = wr_id;
            wr.sg_list = &mut sge;
            wr.num_sge = 1;
            wr.opcode = ibv_wr_opcode::IBV_WR_SEND;
            wr.send_flags = ibv_send_flags::IBV_SEND_SIGNALED.0;
            let mut bad_wr: *mut ibverbs_sys::ibv_send_wr = ptr::null_mut();
            // The post entry point is reached through the device context's ops table.
            let ctx = (*self.qp).context;
            let ops = &mut (*ctx).ops;
            let rc = ops.post_send.as_mut().expect("post_send missing")(
                self.qp,
                &mut wr as *mut _,
                &mut bad_wr as *mut _,
            );
            if rc != 0 {
                return Err(NexarError::device(format!(
                    "RDMA: post_send failed (rc={rc})"
                )));
            }
        }
        Ok(())
    }

    /// Posts one receive work request for `mr` on this queue pair.
    fn post_recv_wr(&self, mr: &RdmaMr, wr_id: u64) -> Result<()> {
        let mut sge = Self::sge_for(mr)?;
        // SAFETY: assumes `self.qp` is the live queue pair owned by this
        // connection; `sge` and `wr` outlive the provider call that reads them.
        unsafe {
            let mut wr: ibverbs_sys::ibv_recv_wr = std::mem::zeroed();
            wr.wr_id = wr_id;
            wr.sg_list = &mut sge;
            wr.num_sge = 1;
            let mut bad_wr: *mut ibverbs_sys::ibv_recv_wr = ptr::null_mut();
            let ctx = (*self.qp).context;
            let ops = &mut (*ctx).ops;
            let rc = ops.post_recv.as_mut().expect("post_recv missing")(
                self.qp,
                &mut wr as *mut _,
                &mut bad_wr as *mut _,
            );
            if rc != 0 {
                return Err(NexarError::device(format!(
                    "RDMA: post_recv failed (rc={rc})"
                )));
            }
        }
        Ok(())
    }

    /// Polls `cq` for a single work completion, giving up after `timeout`.
    ///
    /// Back-off schedule: the first 1000 empty polls spin, the next 4000 sleep
    /// 10µs each, and every later one sleeps 100µs.
    fn poll_until_complete_sync(
        cq: *mut ibverbs_sys::ibv_cq,
        timeout: std::time::Duration,
    ) -> Result<()> {
        let start = std::time::Instant::now();
        let mut iter = 0u32;
        loop {
            // SAFETY: assumes `cq` is a live completion queue for the duration
            // of the call; `wc` is a local the provider writes one entry into.
            unsafe {
                let mut wc = ibverbs_sys::ibv_wc::default();
                let ctx = (*cq).context;
                let ops = &mut (*ctx).ops;
                let n = ops.poll_cq.as_mut().expect("poll_cq missing")(cq, 1, &mut wc as *mut _);
                if n < 0 {
                    return Err(NexarError::device("RDMA: poll_cq failed"));
                }
                if n > 0 {
                    if let Some((status, vendor_err)) = wc.error() {
                        return Err(NexarError::device(format!(
                            "RDMA: work completion failed (status={status:?}, vendor_err={vendor_err}, wr_id={})",
                            wc.wr_id()
                        )));
                    }
                    return Ok(());
                }
            }
            // The deadline is checked only after an empty poll.
            if start.elapsed() > timeout {
                return Err(NexarError::device(format!(
                    "RDMA: CQ poll timed out after {}ms",
                    timeout.as_millis()
                )));
            }
            if iter < 1000 {
                std::hint::spin_loop();
            } else if iter < 5000 {
                std::thread::sleep(std::time::Duration::from_micros(10));
            } else {
                std::thread::sleep(std::time::Duration::from_micros(100));
            }
            iter = iter.saturating_add(1);
        }
    }

    /// Blocks until the send CQ yields a completion or [`CQ_POLL_TIMEOUT`] elapses.
    fn wait_send_completion_sync(&self) -> Result<()> {
        Self::poll_until_complete_sync(self.send_cq, CQ_POLL_TIMEOUT)
    }

    /// Blocks until the receive CQ yields a completion or [`CQ_POLL_TIMEOUT`] elapses.
    fn wait_recv_completion_sync(&self) -> Result<()> {
        Self::poll_until_complete_sync(self.recv_cq, CQ_POLL_TIMEOUT)
    }
}
impl Drop for RdmaConnection {
    /// Destroys the queue pair when the connection goes away.
    ///
    /// The completion queues are left untouched.
    fn drop(&mut self) {
        if self.qp.is_null() {
            return;
        }
        // SAFETY: `qp` is non-null and is only destroyed here, once.
        unsafe {
            ibverbs_sys::ibv_destroy_qp(self.qp);
        }
    }
}