#![allow(clippy::while_let_on_iterator)]
use core::time;
use std::thread;
use sideway::ibverbs::completion::GenericCompletionQueue;
use sideway::ibverbs::queue_pair::{GenericQueuePair, SendOperationFlags};
use sideway::ibverbs::{
address::{AddressHandleAttribute, GidType},
device,
device_context::Mtu,
queue_pair::{
PostSendGuard, QueuePair, QueuePairAttribute, QueuePairState, SetScatterGatherEntry, WorkRequestFlags,
},
AccessFlags,
};
use rstest::rstest;
#[rstest]
#[case(true)]
#[case(false)]
fn main(#[case] use_qp_ex: bool) -> Result<(), Box<dyn std::error::Error>> {
let device_list = device::DeviceList::new()?;
for device in &device_list {
let ctx = device.open().unwrap();
let pd = ctx.alloc_pd().unwrap();
let mut remote_val: u64 = 42;
let mut local_buf: u64 = 0;
let mr = unsafe {
pd.reg_mr(
&mut remote_val as *mut u64 as _,
std::mem::size_of::<u64>(),
AccessFlags::LocalWrite
| AccessFlags::RemoteWrite
| AccessFlags::RemoteRead
| AccessFlags::RemoteAtomic,
)
.unwrap()
};
let local_mr = unsafe {
pd.reg_mr(
&mut local_buf as *mut u64 as _,
std::mem::size_of::<u64>(),
AccessFlags::LocalWrite,
)
.unwrap()
};
let mut cq_builder = ctx.create_cq_builder();
cq_builder.setup_cqe(128);
let sq = GenericCompletionQueue::from(cq_builder.build_ex().unwrap());
let rq = GenericCompletionQueue::from(cq_builder.build_ex().unwrap());
let mut builder = pd.create_qp_builder();
builder
.setup_max_inline_data(0)
.setup_send_cq(sq.clone())
.setup_recv_cq(rq.clone())
.setup_send_ops_flags(SendOperationFlags::AtomicCompareAndSwap | SendOperationFlags::AtomicFetchAndAdd);
let mut qp: GenericQueuePair = if use_qp_ex {
builder.build_ex().unwrap().into()
} else {
builder.build().unwrap().into()
};
let mut attr = QueuePairAttribute::new();
attr.setup_state(QueuePairState::Init)
.setup_pkey_index(0)
.setup_port(1)
.setup_access_flags(
AccessFlags::LocalWrite
| AccessFlags::RemoteWrite
| AccessFlags::RemoteRead
| AccessFlags::RemoteAtomic,
);
qp.modify(&attr).unwrap();
assert_eq!(QueuePairState::Init, qp.state());
let mut attr = QueuePairAttribute::new();
attr.setup_state(QueuePairState::ReadyToReceive)
.setup_path_mtu(Mtu::Mtu1024)
.setup_dest_qp_num(qp.qp_number())
.setup_rq_psn(1)
.setup_max_dest_read_atomic(1)
.setup_min_rnr_timer(0);
let mut ah_attr = AddressHandleAttribute::new();
let gid_entries = ctx.query_gid_table().unwrap();
let gid = gid_entries
.iter()
.find(|&&gid| !gid.gid().is_unicast_link_local() || gid.gid_type() == GidType::RoceV1)
.unwrap();
ah_attr
.setup_dest_lid(1)
.setup_port(1)
.setup_service_level(1)
.setup_grh_src_gid_index(gid.gid_index().try_into().unwrap())
.setup_grh_dest_gid(&gid.gid())
.setup_grh_hop_limit(64);
attr.setup_address_vector(&ah_attr);
qp.modify(&attr).unwrap();
assert_eq!(QueuePairState::ReadyToReceive, qp.state());
let mut attr = QueuePairAttribute::new();
attr.setup_state(QueuePairState::ReadyToSend)
.setup_sq_psn(1)
.setup_timeout(12)
.setup_retry_cnt(7)
.setup_rnr_retry(7)
.setup_max_read_atomic(1);
qp.modify(&attr).unwrap();
assert_eq!(QueuePairState::ReadyToSend, qp.state());
{
let mut guard = qp.start_post_send();
let wr_handle = guard
.construct_wr(1, WorkRequestFlags::Signaled)
.setup_atomic_compare_swap(mr.rkey(), &remote_val as *const u64 as u64, 42, 100);
unsafe {
wr_handle.setup_sge(local_mr.lkey(), &local_buf as *const u64 as u64, 8);
}
guard.post().unwrap();
let mut wc_found = false;
for _ in 0..100 {
if let Ok(mut poller) = sq.start_poll() {
if let Some(wc) = poller.next() {
assert_eq!(wc.status(), 0, "CAS failed with status: {}", wc.status());
wc_found = true;
break;
}
}
thread::sleep(time::Duration::from_millis(10));
}
assert!(wc_found, "Timed out waiting for CAS completion");
assert_eq!(remote_val, 100);
assert_eq!(local_buf, 42);
}
{
let mut guard = qp.start_post_send();
let wr_handle = guard
.construct_wr(2, WorkRequestFlags::Signaled)
.setup_atomic_fetch_add(mr.rkey(), &remote_val as *const u64 as u64, 7);
unsafe {
wr_handle.setup_sge(local_mr.lkey(), &local_buf as *const u64 as u64, 8);
}
guard.post().unwrap();
let mut wc_found = false;
for _ in 0..100 {
if let Ok(mut poller) = sq.start_poll() {
if let Some(wc) = poller.next() {
assert_eq!(wc.status(), 0, "FAA failed with status: {}", wc.status());
wc_found = true;
break;
}
}
thread::sleep(time::Duration::from_millis(10));
}
assert!(wc_found, "Timed out waiting for FAA completion");
assert_eq!(remote_val, 107);
assert_eq!(local_buf, 100);
}
}
Ok(())
}