Skip to main content

rdma_io/
wr.rs

//! Work Request builders and related types.
2
3use rdma_io_sys::ibverbs::*;
4
5/// QP type enum (typed wrapper over `ibv_qp_type`).
6#[derive(Debug, Clone, Copy, PartialEq, Eq)]
7pub enum QpType {
8    Rc,
9    Uc,
10    Ud,
11    XrcSend,
12    XrcRecv,
13    RawPacket,
14    Driver,
15}
16
17impl QpType {
18    /// Convert to the raw `ibv_qp_type` constant.
19    pub fn as_raw(self) -> u32 {
20        match self {
21            Self::Rc => IBV_QPT_RC,
22            Self::Uc => IBV_QPT_UC,
23            Self::Ud => IBV_QPT_UD,
24            Self::XrcSend => IBV_QPT_XRC_SEND,
25            Self::XrcRecv => IBV_QPT_XRC_RECV,
26            Self::RawPacket => IBV_QPT_RAW_PACKET,
27            Self::Driver => IBV_QPT_DRIVER,
28        }
29    }
30
31    /// Convert from a raw `ibv_qp_type` value.
32    pub fn from_raw(v: u32) -> Option<Self> {
33        match v {
34            IBV_QPT_RC => Some(Self::Rc),
35            IBV_QPT_UC => Some(Self::Uc),
36            IBV_QPT_UD => Some(Self::Ud),
37            IBV_QPT_XRC_SEND => Some(Self::XrcSend),
38            IBV_QPT_XRC_RECV => Some(Self::XrcRecv),
39            IBV_QPT_RAW_PACKET => Some(Self::RawPacket),
40            IBV_QPT_DRIVER => Some(Self::Driver),
41            _ => None,
42        }
43    }
44}
45
46/// QP state enum.
47#[derive(Debug, Clone, Copy, PartialEq, Eq)]
48pub enum QpState {
49    Reset,
50    Init,
51    Rtr,
52    Rts,
53    Sqd,
54    Sqe,
55    Err,
56    Unknown,
57}
58
59impl QpState {
60    /// Convert to raw `ibv_qp_state`.
61    pub fn as_raw(self) -> u32 {
62        match self {
63            Self::Reset => IBV_QPS_RESET,
64            Self::Init => IBV_QPS_INIT,
65            Self::Rtr => IBV_QPS_RTR,
66            Self::Rts => IBV_QPS_RTS,
67            Self::Sqd => IBV_QPS_SQD,
68            Self::Sqe => IBV_QPS_SQE,
69            Self::Err => IBV_QPS_ERR,
70            Self::Unknown => IBV_QPS_UNKNOWN,
71        }
72    }
73
74    /// Convert from raw value.
75    pub fn from_raw(v: u32) -> Self {
76        match v {
77            IBV_QPS_RESET => Self::Reset,
78            IBV_QPS_INIT => Self::Init,
79            IBV_QPS_RTR => Self::Rtr,
80            IBV_QPS_RTS => Self::Rts,
81            IBV_QPS_SQD => Self::Sqd,
82            IBV_QPS_SQE => Self::Sqe,
83            IBV_QPS_ERR => Self::Err,
84            _ => Self::Unknown,
85        }
86    }
87}
88
89bitflags::bitflags! {
90    /// Send flags for work requests.
91    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
92    pub struct SendFlags: u32 {
93        const FENCE = IBV_SEND_FENCE;
94        const SIGNALED = IBV_SEND_SIGNALED;
95        const SOLICITED = IBV_SEND_SOLICITED;
96        const INLINE = IBV_SEND_INLINE;
97        const IP_CSUM = IBV_SEND_IP_CSUM;
98    }
99}
100
101/// Scatter-Gather Entry — describes a memory buffer for a WR.
102#[repr(transparent)]
103#[derive(Clone, Copy, Default)]
104pub struct Sge {
105    pub(crate) inner: ibv_sge,
106}
107
108impl std::fmt::Debug for Sge {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        f.debug_struct("Sge")
111            .field("addr", &self.inner.addr)
112            .field("length", &self.inner.length)
113            .field("lkey", &self.inner.lkey)
114            .finish()
115    }
116}
117
118impl Sge {
119    /// Create a new SGE.
120    pub fn new(addr: u64, length: u32, lkey: u32) -> Self {
121        Self {
122            inner: ibv_sge { addr, length, lkey },
123        }
124    }
125}
126
127/// Builder for a receive work request.
128pub struct RecvWr {
129    pub(crate) wr_id: u64,
130    pub(crate) sges: Vec<Sge>,
131}
132
133impl RecvWr {
134    /// Create a new receive WR with the given WR id.
135    pub fn new(wr_id: u64) -> Self {
136        Self {
137            wr_id,
138            sges: Vec::new(),
139        }
140    }
141
142    /// Add a scatter-gather entry.
143    pub fn sg(mut self, sge: Sge) -> Self {
144        self.sges.push(sge);
145        self
146    }
147
148    /// Build the raw `ibv_recv_wr`. The caller must ensure `sges` outlives usage.
149    pub(crate) fn build_raw(&mut self) -> ibv_recv_wr {
150        ibv_recv_wr {
151            wr_id: self.wr_id,
152            next: std::ptr::null_mut(),
153            sg_list: if self.sges.is_empty() {
154                std::ptr::null_mut()
155            } else {
156                self.sges.as_mut_ptr().cast()
157            },
158            num_sge: self.sges.len() as i32,
159        }
160    }
161}
162
163/// Opcode for send work requests.
164#[derive(Debug, Clone, Copy, PartialEq, Eq)]
165pub enum WrOpcode {
166    Send,
167    SendWithImm(u32),
168    RdmaWrite,
169    RdmaWriteWithImm(u32),
170    RdmaRead,
171    AtomicCmpAndSwp,
172    AtomicFetchAndAdd,
173    /// Bind a Memory Window to an MR sub-region (Type 2).
174    BindMw,
175    /// Invalidate a Memory Window's rkey (makes it unusable for remote access).
176    LocalInv,
177}
178
179impl WrOpcode {
180    fn as_raw(self) -> u32 {
181        match self {
182            Self::Send => IBV_WR_SEND,
183            Self::SendWithImm(_) => IBV_WR_SEND_WITH_IMM,
184            Self::RdmaWrite => IBV_WR_RDMA_WRITE,
185            Self::RdmaWriteWithImm(_) => IBV_WR_RDMA_WRITE_WITH_IMM,
186            Self::RdmaRead => IBV_WR_RDMA_READ,
187            Self::AtomicCmpAndSwp => IBV_WR_ATOMIC_CMP_AND_SWP,
188            Self::AtomicFetchAndAdd => IBV_WR_ATOMIC_FETCH_AND_ADD,
189            Self::BindMw => IBV_WR_BIND_MW,
190            Self::LocalInv => IBV_WR_LOCAL_INV,
191        }
192    }
193}
194
195/// Builder for a send work request.
196pub struct SendWr {
197    pub(crate) wr_id: u64,
198    pub(crate) opcode: WrOpcode,
199    pub(crate) send_flags: SendFlags,
200    pub(crate) sges: Vec<Sge>,
201    pub(crate) rdma_remote_addr: u64,
202    pub(crate) rdma_rkey: u32,
203    pub(crate) atomic_compare_add: u64,
204    pub(crate) atomic_swap: u64,
205    // MW bind fields (for BindMw opcode)
206    pub(crate) bind_mw_mw: *mut ibv_mw,
207    pub(crate) bind_mw_rkey: u32,
208    pub(crate) bind_mw_bind_info: ibv_mw_bind_info,
209    // Local invalidation (for LocalInv opcode)
210    pub(crate) invalidate_rkey: u32,
211}
212
213// Safety: The raw pointers (*mut ibv_mw, *mut ibv_mr in bind_info) are RDMA
214// kernel-managed handles, safe to send between threads — same justification
215// as OwnedMemoryRegion which also holds *mut ibv_mr.
216unsafe impl Send for SendWr {}
217
218impl SendWr {
219    /// Create a new send WR.
220    pub fn new(wr_id: u64, opcode: WrOpcode) -> Self {
221        Self {
222            wr_id,
223            opcode,
224            send_flags: SendFlags::empty(),
225            sges: Vec::new(),
226            rdma_remote_addr: 0,
227            rdma_rkey: 0,
228            atomic_compare_add: 0,
229            atomic_swap: 0,
230            bind_mw_mw: std::ptr::null_mut(),
231            bind_mw_rkey: 0,
232            bind_mw_bind_info: ibv_mw_bind_info::default(),
233            invalidate_rkey: 0,
234        }
235    }
236
237    /// Set send flags.
238    pub fn flags(mut self, flags: SendFlags) -> Self {
239        self.send_flags = flags;
240        self
241    }
242
243    /// Add a scatter-gather entry.
244    pub fn sg(mut self, sge: Sge) -> Self {
245        self.sges.push(sge);
246        self
247    }
248
249    /// Set RDMA remote address and rkey (for RDMA read/write ops).
250    pub fn rdma(mut self, remote_addr: u64, rkey: u32) -> Self {
251        self.rdma_remote_addr = remote_addr;
252        self.rdma_rkey = rkey;
253        self
254    }
255
256    /// Set atomic operation parameters (for CAS and FAA).
257    pub fn atomic(mut self, remote_addr: u64, rkey: u32, compare_add: u64, swap: u64) -> Self {
258        self.rdma_remote_addr = remote_addr;
259        self.rdma_rkey = rkey;
260        self.atomic_compare_add = compare_add;
261        self.atomic_swap = swap;
262        self
263    }
264
265    /// Set Memory Window bind parameters (for BindMw opcode).
266    ///
267    /// # Arguments
268    /// * `mw` - Raw MW pointer to bind
269    /// * `rkey` - New rkey to assign to the MW after binding
270    /// * `mr` - Raw MR pointer that the MW will be bound to
271    /// * `addr` - Start address within the MR
272    /// * `length` - Length of the bound region
273    /// * `access` - Access flags for the MW binding
274    pub fn bind_mw(
275        mut self,
276        mw: *mut ibv_mw,
277        rkey: u32,
278        mr: *mut ibv_mr,
279        addr: u64,
280        length: u64,
281        access: u32,
282    ) -> Self {
283        self.bind_mw_mw = mw;
284        self.bind_mw_rkey = rkey;
285        self.bind_mw_bind_info = ibv_mw_bind_info {
286            mr,
287            addr,
288            length,
289            mw_access_flags: access,
290        };
291        self
292    }
293
294    /// Set the rkey to invalidate (for LocalInv opcode).
295    pub fn inv_rkey(mut self, rkey: u32) -> Self {
296        self.invalidate_rkey = rkey;
297        self
298    }
299
300    /// Build the raw `ibv_send_wr`. The caller must ensure `sges` outlives usage.
301    pub(crate) fn build_raw(&mut self) -> ibv_send_wr {
302        let sg_list = if self.sges.is_empty() {
303            std::ptr::null_mut()
304        } else {
305            self.sges.as_mut_ptr().cast()
306        };
307        let mut wr = ibv_send_wr {
308            wr_id: self.wr_id,
309            opcode: self.opcode.as_raw(),
310            send_flags: self.send_flags.bits(),
311            sg_list,
312            num_sge: self.sges.len() as i32,
313            next: std::ptr::null_mut(),
314            ..Default::default()
315        };
316
317        // Set immediate data if applicable.
318        match self.opcode {
319            WrOpcode::SendWithImm(imm) | WrOpcode::RdmaWriteWithImm(imm) => {
320                wr.ibv_send_wr__anon_0.imm_data = imm;
321            }
322            WrOpcode::LocalInv => {
323                wr.ibv_send_wr__anon_0.invalidate_rkey = self.invalidate_rkey;
324            }
325            _ => {}
326        }
327
328        // Set RDMA fields.
329        match self.opcode {
330            WrOpcode::RdmaWrite | WrOpcode::RdmaWriteWithImm(_) | WrOpcode::RdmaRead => {
331                wr.wr.rdma = ibv_send_wr_wr_rdma {
332                    remote_addr: self.rdma_remote_addr,
333                    rkey: self.rdma_rkey,
334                };
335            }
336            WrOpcode::AtomicCmpAndSwp | WrOpcode::AtomicFetchAndAdd => {
337                wr.wr.atomic = ibv_send_wr_wr_atomic {
338                    remote_addr: self.rdma_remote_addr,
339                    compare_add: self.atomic_compare_add,
340                    swap: self.atomic_swap,
341                    rkey: self.rdma_rkey,
342                };
343            }
344            WrOpcode::BindMw => {
345                wr.ibv_send_wr__anon_1.bind_mw = ibv_send_wr__anon_1_bind_mw {
346                    mw: self.bind_mw_mw,
347                    rkey: self.bind_mw_rkey,
348                    bind_info: self.bind_mw_bind_info,
349                };
350            }
351            _ => {}
352        }
353
354        wr
355    }
356}