Skip to main content

rdma_io/
cm.rs

1//! RDMA Connection Manager (rdma_cm).
2//!
3//! Provides TCP-like connection semantics over RDMA. This is the primary
4//! way to use iWARP (siw) and the recommended approach for RoCE.
5
6use std::net::SocketAddr;
7use std::os::unix::io::RawFd;
8use std::sync::Arc;
9
10use rdma_io_sys::ibverbs::*;
11use rdma_io_sys::rdmacm::*;
12
13use crate::Result;
14use crate::cq::CompletionQueue;
15use crate::device::Context;
16use crate::error::{from_ptr, from_ret_errno};
17use crate::pd::ProtectionDomain;
18use crate::qp::QpInitAttr;
19
/// Port space for CM connections.
///
/// Each variant is translated by `as_raw` into the matching `RDMA_PS_*`
/// constant, which is what `CmId::new` passes to `rdma_create_id`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PortSpace {
    /// Maps to `RDMA_PS_TCP`.
    Tcp,
    /// Maps to `RDMA_PS_UDP`.
    Udp,
    /// Maps to `RDMA_PS_IB`.
    Ib,
    /// Maps to `RDMA_PS_IPOIB`.
    Ipoib,
}
28
29impl PortSpace {
30    fn as_raw(self) -> u32 {
31        match self {
32            Self::Tcp => RDMA_PS_TCP,
33            Self::Udp => RDMA_PS_UDP,
34            Self::Ib => RDMA_PS_IB,
35            Self::Ipoib => RDMA_PS_IPOIB,
36        }
37    }
38}
39
/// CM event types.
///
/// Safe mirror of the raw `RDMA_CM_EVENT_*` values; see `from_raw` for the
/// exact mapping. Values that match none of the known constants are kept in
/// [`CmEventType::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmEventType {
    /// `RDMA_CM_EVENT_ADDR_RESOLVED`
    AddrResolved,
    /// `RDMA_CM_EVENT_ADDR_ERROR`
    AddrError,
    /// `RDMA_CM_EVENT_ROUTE_RESOLVED`
    RouteResolved,
    /// `RDMA_CM_EVENT_ROUTE_ERROR`
    RouteError,
    /// `RDMA_CM_EVENT_CONNECT_REQUEST`
    ConnectRequest,
    /// `RDMA_CM_EVENT_CONNECT_RESPONSE`
    ConnectResponse,
    /// `RDMA_CM_EVENT_CONNECT_ERROR`
    ConnectError,
    /// `RDMA_CM_EVENT_UNREACHABLE`
    Unreachable,
    /// `RDMA_CM_EVENT_REJECTED`
    Rejected,
    /// `RDMA_CM_EVENT_ESTABLISHED`
    Established,
    /// `RDMA_CM_EVENT_DISCONNECTED`
    Disconnected,
    /// `RDMA_CM_EVENT_DEVICE_REMOVAL`
    DeviceRemoval,
    /// `RDMA_CM_EVENT_MULTICAST_JOIN`
    MulticastJoin,
    /// `RDMA_CM_EVENT_MULTICAST_ERROR`
    MulticastError,
    /// `RDMA_CM_EVENT_ADDR_CHANGE`
    AddrChange,
    /// `RDMA_CM_EVENT_TIMEWAIT_EXIT`
    TimewaitExit,
    /// Any raw event value not covered by the variants above; carries the
    /// raw value unchanged.
    Unknown(u32),
}
61
62impl CmEventType {
63    fn from_raw(v: u32) -> Self {
64        match v {
65            RDMA_CM_EVENT_ADDR_RESOLVED => Self::AddrResolved,
66            RDMA_CM_EVENT_ADDR_ERROR => Self::AddrError,
67            RDMA_CM_EVENT_ROUTE_RESOLVED => Self::RouteResolved,
68            RDMA_CM_EVENT_ROUTE_ERROR => Self::RouteError,
69            RDMA_CM_EVENT_CONNECT_REQUEST => Self::ConnectRequest,
70            RDMA_CM_EVENT_CONNECT_RESPONSE => Self::ConnectResponse,
71            RDMA_CM_EVENT_CONNECT_ERROR => Self::ConnectError,
72            RDMA_CM_EVENT_UNREACHABLE => Self::Unreachable,
73            RDMA_CM_EVENT_REJECTED => Self::Rejected,
74            RDMA_CM_EVENT_ESTABLISHED => Self::Established,
75            RDMA_CM_EVENT_DISCONNECTED => Self::Disconnected,
76            RDMA_CM_EVENT_DEVICE_REMOVAL => Self::DeviceRemoval,
77            RDMA_CM_EVENT_MULTICAST_JOIN => Self::MulticastJoin,
78            RDMA_CM_EVENT_MULTICAST_ERROR => Self::MulticastError,
79            RDMA_CM_EVENT_ADDR_CHANGE => Self::AddrChange,
80            RDMA_CM_EVENT_TIMEWAIT_EXIT => Self::TimewaitExit,
81            other => Self::Unknown(other),
82        }
83    }
84}
85
/// Connection parameters for `connect` / `accept`.
///
/// Only the fields listed here are configurable; when converted to the raw
/// librdmacm struct every other field keeps its default value.
#[derive(Debug, Clone)]
pub struct ConnParam {
    /// Responder resources (max incoming RDMA read/atomic).
    pub responder_resources: u8,
    /// Initiator depth (max outstanding RDMA read/atomic).
    pub initiator_depth: u8,
    /// Retry count.
    pub retry_count: u8,
    /// RNR retry count (7 = infinite).
    pub rnr_retry_count: u8,
}

impl Default for ConnParam {
    /// One responder resource, an initiator depth of one, and both retry
    /// counters set to 7.
    fn default() -> Self {
        const ONE_OUTSTANDING: u8 = 1;
        const RETRY_LIMIT: u8 = 7;

        ConnParam {
            responder_resources: ONE_OUTSTANDING,
            initiator_depth: ONE_OUTSTANDING,
            retry_count: RETRY_LIMIT,
            rnr_retry_count: RETRY_LIMIT,
        }
    }
}
109
110impl ConnParam {
111    fn to_raw(&self) -> rdma_conn_param {
112        rdma_conn_param {
113            responder_resources: self.responder_resources,
114            initiator_depth: self.initiator_depth,
115            retry_count: self.retry_count,
116            rnr_retry_count: self.rnr_retry_count,
117            ..Default::default()
118        }
119    }
120}
121
/// An rdma_cm event channel.
///
/// Used to receive CM events (address resolved, connected, etc.).
///
/// The wrapped channel is obtained from `rdma_create_event_channel` in
/// [`EventChannel::new`] and released with `rdma_destroy_event_channel` when
/// this value is dropped.
pub struct EventChannel {
    /// Raw channel pointer passed to every librdmacm call made by this type.
    inner: *mut rdma_event_channel,
}

// Safety: The event channel fd is process-global; get_event serialized by caller.
// NOTE(review): the `Sync` impl therefore relies on callers not invoking
// `get_event` / `try_get_event` concurrently — nothing in this type enforces it.
unsafe impl Send for EventChannel {}
unsafe impl Sync for EventChannel {}
132
133impl Drop for EventChannel {
134    fn drop(&mut self) {
135        unsafe { rdma_destroy_event_channel(self.inner) };
136    }
137}
138
139impl EventChannel {
140    /// Create a new event channel.
141    pub fn new() -> Result<Self> {
142        let ch = from_ptr(unsafe { rdma_create_event_channel() })?;
143        Ok(Self { inner: ch })
144    }
145
146    /// Block until the next CM event arrives.
147    pub fn get_event(&self) -> Result<CmEvent> {
148        let mut event: *mut rdma_cm_event = std::ptr::null_mut();
149        from_ret_errno(unsafe { rdma_get_cm_event(self.inner, &mut event) })?;
150        Ok(CmEvent { inner: event })
151    }
152
153    /// Non-blocking get_event. Returns `Error::WouldBlock` if no event is pending.
154    pub fn try_get_event(&self) -> Result<CmEvent> {
155        let mut event: *mut rdma_cm_event = std::ptr::null_mut();
156        let ret = unsafe { rdma_get_cm_event(self.inner, &mut event) };
157        if ret != 0 {
158            let e = std::io::Error::last_os_error();
159            if e.kind() == std::io::ErrorKind::WouldBlock {
160                return Err(crate::Error::WouldBlock);
161            }
162            return Err(crate::Error::Verbs(e));
163        }
164        Ok(CmEvent { inner: event })
165    }
166
167    /// Raw file descriptor for async reactor registration.
168    pub fn fd(&self) -> RawFd {
169        unsafe { (*self.inner).fd }
170    }
171
172    /// Set the event channel fd to non-blocking mode.
173    pub fn set_nonblocking(&self) -> Result<()> {
174        let fd = self.fd();
175        let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
176        if flags < 0 {
177            return Err(crate::Error::Verbs(std::io::Error::last_os_error()));
178        }
179        let ret = unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) };
180        if ret < 0 {
181            return Err(crate::Error::Verbs(std::io::Error::last_os_error()));
182        }
183        Ok(())
184    }
185
186    /// Raw pointer.
187    pub fn as_raw(&self) -> *mut rdma_event_channel {
188        self.inner
189    }
190}
191
/// A CM event received from an [`EventChannel`].
///
/// **Must be acknowledged** via [`ack`](CmEvent::ack) (consumed on ack).
/// An event that is dropped without an explicit `ack` is acknowledged by its
/// `Drop` implementation instead.
pub struct CmEvent {
    /// Raw event pointer as filled in by `rdma_get_cm_event`.
    inner: *mut rdma_cm_event,
}

// Safety: CM events are tied to a channel; single-threaded access pattern.
unsafe impl Send for CmEvent {}
201
202impl CmEvent {
203    /// The event type.
204    pub fn event_type(&self) -> CmEventType {
205        CmEventType::from_raw(unsafe { (*self.inner).event })
206    }
207
208    /// Status code (0 = success for most events).
209    pub fn status(&self) -> i32 {
210        unsafe { (*self.inner).status }
211    }
212
213    /// For `ConnectRequest` events, the CM ID of the new incoming connection.
214    ///
215    /// The returned `CmId` is **not** owned — it must be accepted or rejected.
216    /// The caller takes ownership after `accept`.
217    pub fn cm_id_raw(&self) -> *mut rdma_cm_id {
218        unsafe { (*self.inner).id }
219    }
220
221    /// Acknowledge and consume this event. Must be called for every event.
222    pub fn ack(self) {
223        unsafe { rdma_ack_cm_event(self.inner) };
224        std::mem::forget(self); // prevent double-free in Drop
225    }
226}
227
228impl Drop for CmEvent {
229    fn drop(&mut self) {
230        // If not ack'd, ack now to avoid leaking the event.
231        let ret = unsafe { rdma_ack_cm_event(self.inner) };
232        if ret != 0 {
233            tracing::error!(
234                "rdma_ack_cm_event failed: {}",
235                std::io::Error::last_os_error()
236            );
237        }
238    }
239}
240
/// An RDMA CM identifier — the core handle for connection management.
///
/// Wraps `rdma_cm_id*`. Use for both client (active) and server (passive) sides.
///
/// Dropping a `CmId` calls `rdma_destroy_id` only when the handle is marked
/// as owned (always the case for IDs made with `CmId::new`; chosen by the
/// caller for `CmId::from_raw`).
pub struct CmId {
    /// Raw `rdma_cm_id` pointer passed to every librdmacm call made by this type.
    pub(crate) inner: *mut rdma_cm_id,
    /// Whether this CmId owns the underlying rdma_cm_id (should call rdma_destroy_id).
    owned: bool,
}

// Safety: rdma_cm_id operations are serialized by the caller.
unsafe impl Send for CmId {}
unsafe impl Sync for CmId {}
253
254impl Drop for CmId {
255    fn drop(&mut self) {
256        if self.owned {
257            let ret = unsafe { rdma_destroy_id(self.inner) };
258            if ret != 0 {
259                tracing::error!(
260                    "rdma_destroy_id failed: {}",
261                    std::io::Error::last_os_error()
262                );
263            }
264        }
265    }
266}
267
268impl CmId {
269    /// Create a new CM ID on the given event channel.
270    pub fn new(channel: &EventChannel, port_space: PortSpace) -> Result<Self> {
271        let mut id: *mut rdma_cm_id = std::ptr::null_mut();
272        from_ret_errno(unsafe {
273            rdma_create_id(
274                channel.inner,
275                &mut id,
276                std::ptr::null_mut(),
277                port_space.as_raw(),
278            )
279        })?;
280        Ok(Self {
281            inner: id,
282            owned: true,
283        })
284    }
285
286    /// Wrap a raw `rdma_cm_id` pointer (e.g. from a connect request event).
287    ///
288    /// # Safety
289    /// The caller must ensure the pointer is valid and that ownership semantics
290    /// are correctly handled.
291    pub unsafe fn from_raw(id: *mut rdma_cm_id, owned: bool) -> Self {
292        Self { inner: id, owned }
293    }
294
295    /// Resolve the destination address.
296    pub fn resolve_addr(
297        &self,
298        src: Option<&SocketAddr>,
299        dst: &SocketAddr,
300        timeout_ms: i32,
301    ) -> Result<()> {
302        let (src_ptr, dst_sa) = sockaddr_args(src, dst);
303        from_ret_errno(unsafe {
304            rdma_resolve_addr(self.inner, src_ptr, dst_sa.as_ptr() as *mut _, timeout_ms)
305        })
306    }
307
308    /// Resolve the route to the destination.
309    pub fn resolve_route(&self, timeout_ms: i32) -> Result<()> {
310        from_ret_errno(unsafe { rdma_resolve_route(self.inner, timeout_ms) })
311    }
312
313    /// Bind to a local address and start listening.
314    pub fn listen(&self, addr: &SocketAddr, backlog: i32) -> Result<()> {
315        let sa = to_sockaddr_storage(addr);
316        from_ret_errno(unsafe { rdma_bind_addr(self.inner, sa.as_ptr() as *mut _) })?;
317        from_ret_errno(unsafe { rdma_listen(self.inner, backlog) })
318    }
319
320    /// Create a QP on this CM ID, returning an owned [`CmQueuePair`].
321    ///
322    /// When send_cq/recv_cq are None, rdma_cm creates internal CQs.
323    pub fn create_qp(
324        &self,
325        pd: &Arc<ProtectionDomain>,
326        init_attr: &QpInitAttr,
327    ) -> Result<CmQueuePair> {
328        self.create_qp_with_cq(pd, init_attr, None, None)
329    }
330
331    /// Create a QP on this CM ID with explicit CQs, returning an owned [`CmQueuePair`].
332    ///
333    /// Use this when you need to control CQ creation (e.g., for completion
334    /// channel notification). Pass `None` to let rdma_cm create internal CQs.
335    pub fn create_qp_with_cq(
336        &self,
337        pd: &Arc<ProtectionDomain>,
338        init_attr: &QpInitAttr,
339        send_cq: Option<&Arc<CompletionQueue>>,
340        recv_cq: Option<&Arc<CompletionQueue>>,
341    ) -> Result<CmQueuePair> {
342        let mut raw_attr = ibv_qp_init_attr {
343            send_cq: send_cq.map_or(std::ptr::null_mut(), |cq| cq.inner),
344            recv_cq: recv_cq.map_or(std::ptr::null_mut(), |cq| cq.inner),
345            cap: ibv_qp_cap {
346                max_send_wr: init_attr.max_send_wr,
347                max_recv_wr: init_attr.max_recv_wr,
348                max_send_sge: init_attr.max_send_sge,
349                max_recv_sge: init_attr.max_recv_sge,
350                max_inline_data: init_attr.max_inline_data,
351            },
352            qp_type: init_attr.qp_type.as_raw(),
353            sq_sig_all: i32::from(init_attr.sq_sig_all),
354            ..Default::default()
355        };
356        from_ret_errno(unsafe { rdma_create_qp(self.inner, pd.inner, &mut raw_attr) })?;
357        Ok(CmQueuePair {
358            qp: self.qp_raw(),
359            cm_id_raw: self.inner,
360            _pd: Arc::clone(pd),
361            _send_cq: send_cq.map(Arc::clone),
362            _recv_cq: recv_cq.map(Arc::clone),
363        })
364    }
365
366    /// Connect to a remote peer (client side).
367    pub fn connect(&self, param: &ConnParam) -> Result<()> {
368        let mut raw = param.to_raw();
369        from_ret_errno(unsafe { rdma_connect(self.inner, &mut raw) })
370    }
371
372    /// Accept an incoming connection (server side).
373    pub fn accept(&self, param: &ConnParam) -> Result<()> {
374        let mut raw = param.to_raw();
375        from_ret_errno(unsafe { rdma_accept(self.inner, &mut raw) })
376    }
377
378    /// Disconnect from the remote peer.
379    pub fn disconnect(&self) -> Result<()> {
380        from_ret_errno(unsafe { rdma_disconnect(self.inner) })
381    }
382
383    /// Get the QP number (if a QP has been created on this CM ID).
384    pub fn qp_num(&self) -> Option<u32> {
385        let qp = unsafe { (*self.inner).qp };
386        if qp.is_null() {
387            None
388        } else {
389            Some(unsafe { (*qp).qp_num })
390        }
391    }
392
393    /// Get the raw QP pointer (if a QP has been created on this CM ID).
394    pub fn qp_raw(&self) -> *mut ibv_qp {
395        unsafe { (*self.inner).qp }
396    }
397
398    /// Get the ibverbs context associated with this CM ID (set after addr resolution).
399    ///
400    /// Returns `None` if the address hasn't been resolved yet.
401    pub fn verbs_context(&self) -> Option<Arc<Context>> {
402        let ctx = unsafe { (*self.inner).verbs };
403        if ctx.is_null() {
404            None
405        } else {
406            // rdma_cm owns this context — don't close on drop
407            Some(Arc::new(unsafe { Context::from_raw(ctx, false) }))
408        }
409    }
410
411    /// Allocate a PD from this CM ID's verbs context.
412    ///
413    /// Convenience for `ProtectionDomain::new(cm_id.verbs_context())`.
414    pub fn alloc_pd(&self) -> Result<Arc<ProtectionDomain>> {
415        let ctx = self.verbs_context().ok_or(crate::Error::InvalidArg(
416            "CM ID has no verbs context (resolve_addr first)".into(),
417        ))?;
418        ProtectionDomain::new(ctx)
419    }
420
421    /// Raw pointer.
422    pub fn as_raw(&self) -> *mut rdma_cm_id {
423        self.inner
424    }
425
426    /// Migrate this CM ID to a different event channel.
427    ///
428    /// After migration, events for this CM ID will be reported on the new channel.
429    /// This is typically used after `accept()` to give the accepted connection
430    /// its own event channel, independent of the listener.
431    pub fn migrate(&self, new_channel: &EventChannel) -> Result<()> {
432        from_ret_errno(unsafe { rdma_migrate_id(self.inner, new_channel.as_raw()) })
433    }
434
435    /// Get the peer's socket address (remote end of the connection).
436    ///
437    /// Returns `None` if the address is not available (e.g., not yet connected).
438    pub fn peer_addr(&self) -> Option<SocketAddr> {
439        // rdma_get_peer_addr is an inline function: &id->route.addr.dst_addr
440        let sa = unsafe { &(*self.inner).route.addr.rdma_addr__anon_1.dst_addr };
441        unsafe { sockaddr_to_std(sa as *const _ as *const _) }
442    }
443
444    /// Get the local socket address.
445    ///
446    /// Returns `None` if the address is not available.
447    pub fn local_addr(&self) -> Option<SocketAddr> {
448        // rdma_get_local_addr is an inline function: &id->route.addr.src_addr
449        let sa = unsafe { &(*self.inner).route.addr.rdma_addr__anon_0.src_addr };
450        unsafe { sockaddr_to_std(sa as *const _ as *const _) }
451    }
452}
453
// ---------------------------------------------------------------------------
// CmQueuePair — safe wrapper for CM-managed QPs
// ---------------------------------------------------------------------------

/// A Queue Pair created via `rdma_create_qp` on a [`CmId`].
///
/// Owns the QP lifecycle. [`Drop`] calls `rdma_destroy_qp`.
/// Captures `Arc` references to PD and CQs to prevent premature destruction
/// of resources the QP depends on.
///
/// # Teardown
///
/// Must be dropped **before** the [`CmId`] that created it (since
/// `rdma_destroy_id` requires the QP to be destroyed first). When used
/// inside [`AsyncQp`], field declaration order guarantees this.
///
/// Nothing in this type enforces that ordering: only the raw `rdma_cm_id`
/// pointer is stored, not a borrow of or reference count on the `CmId`.
pub struct CmQueuePair {
    /// Raw QP pointer, read from the CM ID right after `rdma_create_qp` succeeded.
    qp: *mut ibv_qp,
    /// Raw pointer of the CM ID the QP was created on; passed to
    /// `rdma_destroy_qp` on drop.
    cm_id_raw: *mut rdma_cm_id,
    /// Keeps the protection domain alive for as long as the QP exists.
    _pd: Arc<ProtectionDomain>,
    /// Keeps the caller-supplied send CQ alive (`None` if none was supplied).
    _send_cq: Option<Arc<CompletionQueue>>,
    /// Keeps the caller-supplied receive CQ alive (`None` if none was supplied).
    _recv_cq: Option<Arc<CompletionQueue>>,
}

// Safety: ibv_qp is thread-safe (protected by internal locking in libibverbs).
unsafe impl Send for CmQueuePair {}
unsafe impl Sync for CmQueuePair {}
480
481impl Drop for CmQueuePair {
482    fn drop(&mut self) {
483        // Safety: QP was created by rdma_create_qp on this cm_id.
484        // Arc refs to PD/CQs keep them alive until after this returns.
485        unsafe { rdma_destroy_qp(self.cm_id_raw) };
486    }
487}
488
489impl CmQueuePair {
490    /// Raw QP pointer for posting work requests.
491    pub fn as_raw(&self) -> *mut ibv_qp {
492        self.qp
493    }
494
495    /// QP number assigned by the HCA.
496    pub fn qp_num(&self) -> u32 {
497        unsafe { (*self.qp).qp_num }
498    }
499}
500
// --- Socket address helpers ---

/// Address-family tag for IPv4: written to `sin_family` and compared against
/// `sa_family` by the helpers below.
/// NOTE(review): hard-coded rather than taken from the bindings; presumably
/// the Linux `AF_INET` value — confirm if other targets ever matter.
const AF_INET: u16 = 2;
/// Address-family tag for IPv6: written to `sin6_family` and compared against
/// `sa_family` by the helpers below.
/// NOTE(review): hard-coded; presumably the Linux `AF_INET6` value — confirm
/// if other targets ever matter.
const AF_INET6: u16 = 10;
505
506/// Convert a `SocketAddr` to a `sockaddr_storage`-sized buffer.
507fn to_sockaddr_storage(addr: &SocketAddr) -> SockAddrBuf {
508    let mut buf = [0u8; std::mem::size_of::<bnd_linux::libc::posix::socket::sockaddr_storage>()];
509    match addr {
510        SocketAddr::V4(v4) => {
511            let sa = bnd_linux::libc::posix::inet::sockaddr_in {
512                sin_family: AF_INET,
513                sin_port: v4.port().to_be(),
514                sin_addr: bnd_linux::libc::posix::inet::in_addr {
515                    s_addr: u32::from_ne_bytes(v4.ip().octets()),
516                },
517                ..Default::default()
518            };
519            unsafe {
520                std::ptr::copy_nonoverlapping(
521                    &sa as *const _ as *const u8,
522                    buf.as_mut_ptr(),
523                    std::mem::size_of_val(&sa),
524                );
525            }
526        }
527        SocketAddr::V6(v6) => {
528            let sa = bnd_linux::libc::posix::inet::sockaddr_in6 {
529                sin6_family: AF_INET6,
530                sin6_port: v6.port().to_be(),
531                sin6_flowinfo: v6.flowinfo(),
532                sin6_addr: bnd_linux::libc::posix::inet::in6_addr {
533                    __in6_u: bnd_linux::libc::posix::inet::in6_addr___in6_u {
534                        __u6_addr8: v6.ip().octets(),
535                    },
536                },
537                sin6_scope_id: v6.scope_id(),
538            };
539            unsafe {
540                std::ptr::copy_nonoverlapping(
541                    &sa as *const _ as *const u8,
542                    buf.as_mut_ptr(),
543                    std::mem::size_of_val(&sa),
544                );
545            }
546        }
547    }
548    SockAddrBuf(buf)
549}
550
551/// Stack-allocated sockaddr buffer.
552struct SockAddrBuf([u8; std::mem::size_of::<bnd_linux::libc::posix::socket::sockaddr_storage>()]);
553
554impl SockAddrBuf {
555    fn as_ptr(&self) -> *const bnd_linux::libc::posix::socket::sockaddr {
556        self.0.as_ptr().cast()
557    }
558}
559
560fn sockaddr_args(
561    src: Option<&SocketAddr>,
562    dst: &SocketAddr,
563) -> (*mut bnd_linux::libc::posix::socket::sockaddr, SockAddrBuf) {
564    let dst_sa = to_sockaddr_storage(dst);
565    let src_ptr = match src {
566        // For simplicity, pass null for src (let the kernel choose).
567        Some(_) => std::ptr::null_mut(),
568        None => std::ptr::null_mut(),
569    };
570    (src_ptr, dst_sa)
571}
572
573/// Convert a raw `sockaddr*` to a `std::net::SocketAddr`.
574///
575/// # Safety
576/// The pointer must be valid and point to a `sockaddr_in` or `sockaddr_in6`.
577unsafe fn sockaddr_to_std(
578    sa: *const bnd_linux::libc::posix::socket::sockaddr,
579) -> Option<SocketAddr> {
580    unsafe {
581        let family = (*sa).sa_family;
582        if family == AF_INET {
583            let sin = &*(sa as *const bnd_linux::libc::posix::inet::sockaddr_in);
584            let ip = std::net::Ipv4Addr::from(u32::from_be(sin.sin_addr.s_addr));
585            let port = u16::from_be(sin.sin_port);
586            Some(SocketAddr::V4(std::net::SocketAddrV4::new(ip, port)))
587        } else if family == AF_INET6 {
588            let sin6 = &*(sa as *const bnd_linux::libc::posix::inet::sockaddr_in6);
589            let ip = std::net::Ipv6Addr::from(sin6.sin6_addr.__in6_u.__u6_addr8);
590            let port = u16::from_be(sin6.sin6_port);
591            Some(SocketAddr::V6(std::net::SocketAddrV6::new(
592                ip,
593                port,
594                sin6.sin6_flowinfo,
595                sin6.sin6_scope_id,
596            )))
597        } else {
598            None
599        }
600    }
601}