Skip to main content

rdma_io/
async_cm.rs

1//! Async RDMA Connection Manager — non-blocking connect/accept over rdma_cm.
2//!
3//! Provides [`AsyncCmId`] for async client connections and [`AsyncCmListener`]
4//! for async server accept, using the `rdma_event_channel` fd with tokio's
5//! `AsyncFd` for true async I/O (no `spawn_blocking`).
6//!
7//! These are the building blocks for higher-level types like `AsyncRdmaStream`,
8//! but can also be used directly with `AsyncQp` for custom RDMA patterns.
9
10use std::mem::ManuallyDrop;
11use std::net::SocketAddr;
12use std::os::unix::io::RawFd;
13use std::sync::Arc;
14
15use tokio::io::unix::AsyncFd;
16
17use crate::Result;
18use crate::cm::{CmEventType, CmId, ConnParam, EventChannel, PortSpace};
19use crate::cq::CompletionQueue;
20use crate::device::Context;
21use crate::pd::ProtectionDomain;
22use crate::qp::QpInitAttr;
23
24/// Async wrapper around an `EventChannel` for non-blocking CM event delivery.
25///
26/// Uses tokio's `AsyncFd` to await readiness on the event channel fd,
27/// then calls `try_get_event()` for non-blocking event retrieval.
28pub(crate) struct AsyncEventChannel {
29    async_fd: AsyncFd<RawFd>,
30}
31
32impl AsyncEventChannel {
33    /// Create a new async event channel wrapper.
34    ///
35    /// The underlying `EventChannel` must already be set to non-blocking mode.
36    pub(crate) fn new(ch: &EventChannel) -> Result<Self> {
37        let async_fd = AsyncFd::new(ch.fd()).map_err(crate::Error::Verbs)?;
38        Ok(Self { async_fd })
39    }
40
41    /// Wait for the next CM event, returning it when available.
42    pub(crate) async fn get_event(&self, ch: &EventChannel) -> Result<crate::cm::CmEvent> {
43        loop {
44            let mut guard = self
45                .async_fd
46                .readable()
47                .await
48                .map_err(crate::Error::Verbs)?;
49            match ch.try_get_event() {
50                Ok(ev) => return Ok(ev),
51                Err(crate::Error::WouldBlock) => {
52                    guard.clear_ready();
53                    continue;
54                }
55                Err(e) => return Err(e),
56            }
57        }
58    }
59
60    /// Wait for a specific CM event type, ack it, and return.
61    pub(crate) async fn expect_event(
62        &self,
63        ch: &EventChannel,
64        expected: CmEventType,
65    ) -> Result<()> {
66        let ev = self.get_event(ch).await?;
67        let actual = ev.event_type();
68        if actual != expected {
69            ev.ack();
70            return Err(crate::Error::InvalidArg(format!(
71                "expected {expected:?}, got {actual:?}"
72            )));
73        }
74        ev.ack();
75        Ok(())
76    }
77
78    /// Poll for a specific CM event type. Returns `Poll::Ready(Ok(()))` when
79    /// the expected event arrives and is acked, or `Poll::Pending` if not yet
80    /// ready. Suitable for use inside `poll_close` / `poll_*` trait methods.
81    #[allow(dead_code)]
82    pub(crate) fn poll_expect_event(
83        &self,
84        cx: &mut std::task::Context<'_>,
85        ch: &EventChannel,
86        expected: CmEventType,
87    ) -> std::task::Poll<Result<()>> {
88        loop {
89            let mut guard = match self.async_fd.poll_read_ready(cx) {
90                std::task::Poll::Ready(Ok(g)) => g,
91                std::task::Poll::Ready(Err(e)) => {
92                    return std::task::Poll::Ready(Err(crate::Error::Verbs(e)));
93                }
94                std::task::Poll::Pending => return std::task::Poll::Pending,
95            };
96            match ch.try_get_event() {
97                Ok(ev) => {
98                    let actual = ev.event_type();
99                    ev.ack();
100                    if actual != expected {
101                        return std::task::Poll::Ready(Err(crate::Error::InvalidArg(format!(
102                            "expected {expected:?}, got {actual:?}"
103                        ))));
104                    }
105                    return std::task::Poll::Ready(Ok(()));
106                }
107                Err(crate::Error::WouldBlock) => {
108                    guard.clear_ready();
109                    continue;
110                }
111                Err(e) => return std::task::Poll::Ready(Err(e)),
112            }
113        }
114    }
115}
116
117// ---------------------------------------------------------------------------
118// AsyncCmId — async client-side CM ID
119// ---------------------------------------------------------------------------
120
121/// An async RDMA CM ID for client-side connections.
122///
123/// Wraps `CmId` + `EventChannel` and provides async versions of the CM
124/// operations (`resolve_addr`, `resolve_route`, `connect`). Use with
125/// `AsyncQp` for direct async RDMA verb access, or use `AsyncRdmaStream`
126/// for a higher-level TCP-like interface.
127///
128/// # Example
129///
130/// ```no_run
131/// use rdma_io::async_cm::AsyncCmId;
132///
133/// # async fn example() -> rdma_io::Result<()> {
134/// let cm_id = AsyncCmId::connect_to(&"10.0.0.1:9999".parse().unwrap()).await?;
135///
136/// // Access the inner CmId for QP setup, verb operations, etc.
137/// let ctx = cm_id.verbs_context().unwrap();
138/// let pd = cm_id.alloc_pd()?;
139/// # Ok(())
140/// # }
141/// ```
142pub struct AsyncCmId {
143    // ManuallyDrop: cm_id must be destroyed before event_channel.
144    // rdma_destroy_id needs the event channel fd to still be open.
145    cm_id: ManuallyDrop<CmId>,
146    event_channel: ManuallyDrop<EventChannel>,
147}
148
149// Safety: EventChannel + CmId are Send-safe (raw pointers guarded by kernel).
150unsafe impl Send for AsyncCmId {}
151
152impl Drop for AsyncCmId {
153    fn drop(&mut self) {
154        // RDMA teardown order: CM ID first, then event channel.
155        unsafe {
156            ManuallyDrop::drop(&mut self.cm_id);
157            ManuallyDrop::drop(&mut self.event_channel);
158        }
159    }
160}
161
162impl AsyncCmId {
163    /// Create a new async CM ID on its own event channel.
164    pub fn new(port_space: PortSpace) -> Result<Self> {
165        let ch = EventChannel::new()?;
166        ch.set_nonblocking()?;
167        let cm_id = CmId::new(&ch, port_space)?;
168        Ok(Self {
169            cm_id: ManuallyDrop::new(cm_id),
170            event_channel: ManuallyDrop::new(ch),
171        })
172    }
173
174    /// Resolve the destination address asynchronously.
175    pub async fn resolve_addr(
176        &self,
177        src: Option<&SocketAddr>,
178        dst: &SocketAddr,
179        timeout_ms: i32,
180    ) -> Result<()> {
181        let async_ch = AsyncEventChannel::new(&self.event_channel)?;
182        self.cm_id.resolve_addr(src, dst, timeout_ms)?;
183        async_ch
184            .expect_event(&self.event_channel, CmEventType::AddrResolved)
185            .await
186    }
187
188    /// Resolve the route asynchronously.
189    pub async fn resolve_route(&self, timeout_ms: i32) -> Result<()> {
190        let async_ch = AsyncEventChannel::new(&self.event_channel)?;
191        self.cm_id.resolve_route(timeout_ms)?;
192        async_ch
193            .expect_event(&self.event_channel, CmEventType::RouteResolved)
194            .await
195    }
196
197    /// Perform the RDMA connect handshake asynchronously.
198    pub async fn connect(&self, param: &ConnParam) -> Result<()> {
199        let async_ch = AsyncEventChannel::new(&self.event_channel)?;
200        self.cm_id.connect(param)?;
201        async_ch
202            .expect_event(&self.event_channel, CmEventType::Established)
203            .await
204    }
205
206    /// Full async connect: resolve_addr → resolve_route → connect.
207    ///
208    /// Convenience method that performs all three CM phases.
209    pub async fn connect_to(addr: &SocketAddr) -> Result<Self> {
210        let cm = Self::new(PortSpace::Tcp)?;
211        cm.resolve_addr(None, addr, 2000).await?;
212        cm.resolve_route(2000).await?;
213        cm.connect(&ConnParam::default()).await?;
214        Ok(cm)
215    }
216
217    /// Access the inner `CmId` for QP setup, verb operations, etc.
218    pub fn cm_id(&self) -> &CmId {
219        &self.cm_id
220    }
221
222    /// Access the inner `EventChannel`.
223    pub fn event_channel(&self) -> &EventChannel {
224        &self.event_channel
225    }
226
227    // --- Delegate common CmId methods for convenience ---
228
229    /// Get the verbs context (device) associated with this CM ID.
230    pub fn verbs_context(&self) -> Option<Arc<Context>> {
231        self.cm_id.verbs_context()
232    }
233
234    /// Allocate a protection domain on this CM ID's device.
235    pub fn alloc_pd(&self) -> Result<Arc<ProtectionDomain>> {
236        self.cm_id.alloc_pd()
237    }
238
239    /// Create a QP with separate send/recv CQs on this CM ID.
240    ///
241    /// Returns an owned [`CmQueuePair`] that the caller must keep alive.
242    pub fn create_qp_with_cq(
243        &self,
244        pd: &Arc<ProtectionDomain>,
245        init_attr: &QpInitAttr,
246        send_cq: Option<&Arc<CompletionQueue>>,
247        recv_cq: Option<&Arc<CompletionQueue>>,
248    ) -> Result<crate::cm::CmQueuePair> {
249        self.cm_id
250            .create_qp_with_cq(pd, init_attr, send_cq, recv_cq)
251    }
252
253    /// Raw QP pointer (from the cm_id, for low-level use before QP is transferred).
254    pub fn qp_raw(&self) -> *mut rdma_io_sys::ibverbs::ibv_qp {
255        self.cm_id.qp_raw()
256    }
257
258    /// Disconnect the connection (fire-and-forget, synchronous).
259    pub fn disconnect(&self) -> Result<()> {
260        self.cm_id.disconnect()
261    }
262
263    /// Disconnect and await the `DISCONNECTED` event from the peer.
264    ///
265    /// This performs a graceful disconnect: sends the disconnect request,
266    /// then waits for the peer to acknowledge it. Analogous to TCP's
267    /// `shutdown()` + await FIN-ACK.
268    pub async fn disconnect_async(&self) -> Result<()> {
269        let async_ch = AsyncEventChannel::new(&self.event_channel)?;
270        self.cm_id.disconnect()?;
271        async_ch
272            .expect_event(&self.event_channel, CmEventType::Disconnected)
273            .await
274    }
275
276    /// Await the next CM event on this connection's event channel.
277    ///
278    /// Returns any event (disconnect, error, etc.). The caller must ack the
279    /// event via [`CmEvent::ack()`]. Useful for monitoring connection
280    /// lifecycle (e.g., detecting peer disconnect via `select!`).
281    pub async fn next_event(&self) -> Result<crate::cm::CmEvent> {
282        let async_ch = AsyncEventChannel::new(&self.event_channel)?;
283        async_ch.get_event(&self.event_channel).await
284    }
285
286    /// Decompose into the inner `EventChannel` and `CmId`.
287    ///
288    /// Used when transferring ownership to a higher-level type (e.g., `AsyncRdmaStream`).
289    pub fn into_parts(self) -> (EventChannel, CmId) {
290        let mut this = ManuallyDrop::new(self);
291        unsafe {
292            let cm_id = ManuallyDrop::take(&mut this.cm_id);
293            let event_channel = ManuallyDrop::take(&mut this.event_channel);
294            (event_channel, cm_id)
295        }
296    }
297}
298
299// ---------------------------------------------------------------------------
300// AsyncCmListener — async server-side listener
301// ---------------------------------------------------------------------------
302
303/// An async RDMA CM listener for accepting incoming connections.
304///
305/// Binds to a local address and provides an async `accept()` that returns
306/// an [`AsyncCmId`] for each incoming connection. The accepted ID is fully
307/// connected and migrated to its own event channel.
308///
309/// # Example
310///
311/// ```no_run
312/// use rdma_io::async_cm::AsyncCmListener;
313///
314/// # async fn example() -> rdma_io::Result<()> {
315/// let listener = AsyncCmListener::bind(&"0.0.0.0:9999".parse().unwrap())?;
316/// let cm_id = listener.accept().await?;
317///
318/// // Set up QP and start using AsyncQp
319/// let ctx = cm_id.verbs_context().unwrap();
320/// # Ok(())
321/// # }
322/// ```
323pub struct AsyncCmListener {
324    // ManuallyDrop: _cm_id must be destroyed before event_channel.
325    _cm_id: ManuallyDrop<CmId>,
326    event_channel: ManuallyDrop<EventChannel>,
327    async_ch: AsyncEventChannel,
328    /// ConnectRequest events consumed while waiting for Established.
329    /// Drained by `get_request` / `poll_get_request` before polling the fd.
330    pending_requests: std::sync::Mutex<std::collections::VecDeque<CmId>>,
331}
332
333// Safety: EventChannel + CmId are Send-safe (raw pointers guarded by kernel).
334unsafe impl Send for AsyncCmListener {}
335
336// Safety: All fields are Sync:
337// - CmId: `unsafe impl Sync` (kernel operations are atomic)
338// - EventChannel: `unsafe impl Sync` (kernel fd queue is serialized)
339// - AsyncEventChannel: contains `AsyncFd<RawFd>` which is Sync
340// - pending_requests: `Mutex<VecDeque<CmId>>` is Sync
341// Concurrent &self callers are safe; the kernel serializes event delivery.
342unsafe impl Sync for AsyncCmListener {}
343
344impl Drop for AsyncCmListener {
345    fn drop(&mut self) {
346        // RDMA teardown order: CM ID first, then event channel.
347        // async_ch has no Drop (just wraps an AsyncFd).
348        unsafe {
349            ManuallyDrop::drop(&mut self._cm_id);
350            ManuallyDrop::drop(&mut self.event_channel);
351        }
352    }
353}
354
355impl AsyncCmListener {
356    /// Bind to a local address and start listening.
357    pub fn bind(addr: &SocketAddr) -> Result<Self> {
358        Self::bind_with_backlog(addr, 128)
359    }
360
361    /// Bind with a custom backlog.
362    pub fn bind_with_backlog(addr: &SocketAddr, backlog: i32) -> Result<Self> {
363        let ch = EventChannel::new()?;
364        ch.set_nonblocking()?;
365        let async_ch = AsyncEventChannel::new(&ch)?;
366        let cm_id = CmId::new(&ch, PortSpace::Tcp)?;
367        cm_id.listen(addr, backlog)?;
368        Ok(Self {
369            _cm_id: ManuallyDrop::new(cm_id),
370            event_channel: ManuallyDrop::new(ch),
371            async_ch,
372            pending_requests: std::sync::Mutex::new(std::collections::VecDeque::new()),
373        })
374    }
375
376    /// Get the local socket address the listener is bound to.
377    ///
378    /// Useful when binding to port 0 to discover the assigned port.
379    pub fn local_addr(&self) -> Option<std::net::SocketAddr> {
380        self._cm_id.local_addr()
381    }
382
383    /// Accept an incoming connection asynchronously.
384    ///
385    /// Returns an [`AsyncCmId`] that is fully connected and migrated to its
386    /// own event channel. The caller can then set up QP resources and use
387    /// `AsyncQp` for data transfer.
388    ///
389    /// Note: QP and CQ must be set up by the caller BEFORE calling this,
390    /// or use the lower-level `accept_raw()` for more control.
391    pub async fn accept(&self) -> Result<AsyncCmId> {
392        self.accept_with_param(&ConnParam::default()).await
393    }
394
395    /// Accept with custom connection parameters.
396    ///
397    /// The accepted connection goes through: await CONNECT_REQUEST →
398    /// accept handshake → await ESTABLISHED → migrate to own event channel.
399    ///
400    /// **Important**: The caller must set up QP resources (PD, CQ, QP, MRs)
401    /// on the returned `AsyncCmId` before the connection can transfer data.
402    /// For a batteries-included experience, use `AsyncRdmaStream` instead.
403    pub async fn accept_with_param(&self, param: &ConnParam) -> Result<AsyncCmId> {
404        let conn_id = self.get_request().await?;
405
406        // Accept the connection (non-blocking kernel call)
407        conn_id.accept(param)?;
408
409        // Await ESTABLISHED, stashing interleaved ConnectRequest events.
410        loop {
411            let ev = self.async_ch.get_event(&self.event_channel).await?;
412            let etype = ev.event_type();
413            match etype {
414                CmEventType::Established => {
415                    ev.ack();
416                    break;
417                }
418                CmEventType::ConnectRequest => {
419                    let stashed_id = unsafe { CmId::from_raw(ev.cm_id_raw(), true) };
420                    ev.ack();
421                    self.pending_requests.lock().unwrap().push_back(stashed_id);
422                }
423                _ => {
424                    ev.ack();
425                    return Err(crate::Error::InvalidArg(format!(
426                        "expected Established, got {etype:?}"
427                    )));
428                }
429            }
430        }
431
432        // Migrate to its own event channel
433        let conn_ch = EventChannel::new()?;
434        conn_ch.set_nonblocking()?;
435        conn_id.migrate(&conn_ch)?;
436
437        Ok(AsyncCmId {
438            cm_id: ManuallyDrop::new(conn_id),
439            event_channel: ManuallyDrop::new(conn_ch),
440        })
441    }
442
443    /// Await the next CONNECT_REQUEST and return the raw `CmId`.
444    ///
445    /// This is the first phase of a two-phase accept. The caller can set up
446    /// QP resources on the returned `CmId`, then call
447    /// [`complete_accept`](Self::complete_accept) to finish the handshake.
448    pub async fn get_request(&self) -> Result<CmId> {
449        // Drain stashed requests from interleaved events during complete_accept.
450        if let Some(conn_id) = self.pending_requests.lock().unwrap().pop_front() {
451            return Ok(conn_id);
452        }
453        let ev = self.async_ch.get_event(&self.event_channel).await?;
454        let etype = ev.event_type();
455        if etype != CmEventType::ConnectRequest {
456            ev.ack();
457            return Err(crate::Error::InvalidArg(format!(
458                "expected ConnectRequest, got {etype:?}"
459            )));
460        }
461        let conn_id = unsafe { CmId::from_raw(ev.cm_id_raw(), true) };
462        ev.ack();
463        Ok(conn_id)
464    }
465
466    /// Poll for a CONNECT_REQUEST. Non-blocking poll variant of [`get_request`](Self::get_request).
467    ///
468    /// Returns `Poll::Ready(Ok(CmId))` when a new connection request arrives,
469    /// or `Poll::Pending` if none is ready (waker registered on CM fd).
470    ///
471    /// Used by `RdmaUdpSocket::poll_accept` to drive accept from within `poll_recv`.
472    pub fn poll_get_request(
473        &self,
474        cx: &mut std::task::Context<'_>,
475    ) -> std::task::Poll<Result<CmId>> {
476        // Drain stashed requests from interleaved events during complete_accept.
477        if let Some(conn_id) = self.pending_requests.lock().unwrap().pop_front() {
478            return std::task::Poll::Ready(Ok(conn_id));
479        }
480        loop {
481            let mut guard = match self.async_ch.async_fd.poll_read_ready(cx) {
482                std::task::Poll::Ready(Ok(g)) => g,
483                std::task::Poll::Ready(Err(e)) => {
484                    return std::task::Poll::Ready(Err(crate::Error::Verbs(e)));
485                }
486                std::task::Poll::Pending => return std::task::Poll::Pending,
487            };
488            match self.event_channel.try_get_event() {
489                Ok(ev) => {
490                    let etype = ev.event_type();
491                    if etype != CmEventType::ConnectRequest {
492                        ev.ack();
493                        return std::task::Poll::Ready(Err(crate::Error::InvalidArg(format!(
494                            "expected ConnectRequest, got {etype:?}"
495                        ))));
496                    }
497                    let conn_id = unsafe { CmId::from_raw(ev.cm_id_raw(), true) };
498                    ev.ack();
499                    return std::task::Poll::Ready(Ok(conn_id));
500                }
501                Err(crate::Error::WouldBlock) => {
502                    guard.clear_ready();
503                    continue;
504                }
505                Err(e) => return std::task::Poll::Ready(Err(e)),
506            }
507        }
508    }
509
510    /// Complete the accept handshake after QP setup.
511    ///
512    /// Second phase of a two-phase accept: sends the accept reply, awaits
513    /// ESTABLISHED, and migrates the connection to its own event channel.
514    ///
515    /// If a `ConnectRequest` event arrives while waiting for `Established`
516    /// (concurrent client connecting), it is stashed for later retrieval
517    /// by [`get_request`](Self::get_request) / [`poll_get_request`](Self::poll_get_request).
518    pub async fn complete_accept(&self, conn_id: CmId, param: &ConnParam) -> Result<AsyncCmId> {
519        conn_id.accept(param)?;
520
521        // Wait for Established, stashing any interleaved ConnectRequest events.
522        loop {
523            let ev = self.async_ch.get_event(&self.event_channel).await?;
524            let etype = ev.event_type();
525            match etype {
526                CmEventType::Established => {
527                    ev.ack();
528                    break;
529                }
530                CmEventType::ConnectRequest => {
531                    let stashed_id = unsafe { CmId::from_raw(ev.cm_id_raw(), true) };
532                    ev.ack();
533                    self.pending_requests.lock().unwrap().push_back(stashed_id);
534                }
535                _ => {
536                    ev.ack();
537                    return Err(crate::Error::InvalidArg(format!(
538                        "expected Established, got {etype:?}"
539                    )));
540                }
541            }
542        }
543
544        let conn_ch = EventChannel::new()?;
545        conn_ch.set_nonblocking()?;
546        conn_id.migrate(&conn_ch)?;
547
548        Ok(AsyncCmId {
549            cm_id: ManuallyDrop::new(conn_id),
550            event_channel: ManuallyDrop::new(conn_ch),
551        })
552    }
553
554    /// Await the next CM event on the listener's event channel.
555    ///
556    /// Returns any event (connect request, disconnect, error, etc.).
557    /// The caller must ack the event via [`CmEvent::ack()`].
558    /// For simple accept loops, prefer [`accept()`](Self::accept) instead.
559    pub async fn next_event(&self) -> Result<crate::cm::CmEvent> {
560        self.async_ch.get_event(&self.event_channel).await
561    }
562}