Skip to main content

rdma_io/
async_stream.rs

1//! Async RDMA Stream — async `read` + `write` over a [`Transport`].
2//!
3//! Provides TCP-like async semantics over an RDMA transport, built on
4//! the [`Transport`](crate::transport::Transport) trait for completion-driven I/O.
5//!
6//! # Architecture
7//!
8//! `AsyncRdmaStream<T>` is generic over `T: Transport`. Callers construct
9//! the transport directly (e.g. [`SendRecvTransport::connect`] or
10//! [`CreditRingTransport::connect`]) and wrap it with [`AsyncRdmaStream::new`].
11//!
12//! [`SendRecvTransport::connect`]: crate::send_recv_transport::SendRecvTransport::connect
13//! [`CreditRingTransport::connect`]: crate::credit_ring_transport::CreditRingTransport::connect
14//!
15//! # Protocol
16//!
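//! No application-level framing. Each `write()` becomes one transport send.
//! Each `read()` returns bytes from at most one recv completion; a completion
//! larger than the caller's buffer is kept and drained by subsequent reads.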
19//!
20//! # Example
21//!
22//! ```no_run
23//! use rdma_io::async_cm::AsyncCmListener;
24//! use rdma_io::async_stream::AsyncRdmaStream;
25//! use rdma_io::send_recv_transport::{SendRecvTransport, SendRecvConfig};
26//!
27//! # async fn example() -> rdma_io::Result<()> {
28//! // Server
29//! let listener = AsyncCmListener::bind(&"0.0.0.0:9999".parse().unwrap())?;
30//! let transport = SendRecvTransport::accept(&listener, SendRecvConfig::default()).await?;
31//! let mut server = AsyncRdmaStream::new(transport);
32//!
33//! // Client
34//! let addr = "10.0.0.1:9999".parse().unwrap();
35//! let transport = SendRecvTransport::connect(&addr, SendRecvConfig::default()).await?;
36//! let mut client = AsyncRdmaStream::new(transport);
37//! # Ok(())
38//! # }
39//! ```
40
41use std::fmt;
42use std::io;
43use std::net::SocketAddr;
44use std::pin::Pin;
45use std::task::{Context, Poll};
46
47use futures_io::{AsyncRead, AsyncWrite};
48
49use crate::transport::{RecvCompletion, Transport};
50
51/// An async RDMA stream with `read` and `write` methods.
52///
53/// Generic over `T: Transport`. Construct via [`AsyncRdmaStream::new`] with
54/// a pre-built transport.
55pub struct AsyncRdmaStream<T: Transport> {
56    transport: T,
57    /// Partially consumed recv: (buf_index, offset, total_len).
58    recv_pending: Option<(usize, usize, usize)>,
59    /// In-flight send length. None if send slot is free.
60    write_pending: Option<usize>,
61    /// Set when transport returns Err from poll_recv (QP entered ERROR state).
62    /// Once set, poll_read always returns Ok(0) — the QP will never
63    /// produce another recv completion.
64    eof: bool,
65}
66
67// T: Transport is Send + Sync, and our own fields are trivially Unpin/Send/Sync.
68impl<T: Transport> Unpin for AsyncRdmaStream<T> {}
69
70impl<T: Transport> fmt::Debug for AsyncRdmaStream<T> {
71    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72        f.debug_struct("AsyncRdmaStream")
73            .field("local_addr", &self.transport.local_addr())
74            .field("peer_addr", &self.transport.peer_addr())
75            .field("eof", &self.eof)
76            .field("recv_pending", &self.recv_pending.is_some())
77            .field("write_pending", &self.write_pending.is_some())
78            .finish()
79    }
80}
81
82impl<T: Transport> AsyncRdmaStream<T> {
83    /// Wrap a pre-constructed transport as a byte stream.
84    pub fn new(transport: T) -> Self {
85        Self {
86            transport,
87            recv_pending: None,
88            write_pending: None,
89            eof: false,
90        }
91    }
92
93    /// Read data from the stream asynchronously.
94    ///
95    /// Returns the number of bytes read. Returns `Ok(0)` on disconnect (EOF).
96    pub async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
97        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_read(cx, buf)).await
98    }
99
100    /// Write data to the stream asynchronously.
101    ///
102    /// Returns the number of bytes written (bounded by buffer size).
103    pub async fn write(&mut self, data: &[u8]) -> io::Result<usize> {
104        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_write(cx, data)).await
105    }
106
107    /// Write all data to the stream, looping if necessary.
108    pub async fn write_all(&mut self, mut data: &[u8]) -> io::Result<()> {
109        while !data.is_empty() {
110            let n = self.write(data).await?;
111            data = &data[n..];
112        }
113        Ok(())
114    }
115
116    /// Disconnect the stream gracefully.
117    pub async fn shutdown(&mut self) -> io::Result<()> {
118        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_close(cx)).await
119    }
120
121    /// Get the peer's socket address (remote end of the connection).
122    pub fn peer_addr(&self) -> Option<SocketAddr> {
123        self.transport.peer_addr()
124    }
125
126    /// Get the local socket address.
127    pub fn local_addr(&self) -> Option<SocketAddr> {
128        self.transport.local_addr()
129    }
130}
131
132impl<T: Transport> Drop for AsyncRdmaStream<T> {
133    fn drop(&mut self) {
134        let _ = self.transport.disconnect();
135    }
136}
137
138// --- futures::io trait implementations ---
139
140impl<T: Transport> AsyncRead for AsyncRdmaStream<T> {
141    fn poll_read(
142        self: Pin<&mut Self>,
143        cx: &mut Context<'_>,
144        buf: &mut [u8],
145    ) -> Poll<io::Result<usize>> {
146        let this = self.get_mut();
147
148        if buf.is_empty() || this.eof {
149            return Poll::Ready(Ok(0));
150        }
151
152        // Phase 1: Return buffered recv data (partial read from previous completion)
153        if let Some((buf_idx, offset, total_len)) = this.recv_pending {
154            let remaining = total_len - offset;
155            let copy_len = remaining.min(buf.len());
156            buf[..copy_len]
157                .copy_from_slice(&this.transport.recv_buf(buf_idx)[offset..offset + copy_len]);
158            if copy_len < remaining {
159                this.recv_pending = Some((buf_idx, offset + copy_len, total_len));
160            } else {
161                this.recv_pending = None;
162                // Repost failure is non-fatal — data is already delivered.
163                // The buffer is lost, but the QP ERROR will be detected on next poll.
164                let _ = this.transport.repost_recv(buf_idx);
165            }
166            return Poll::Ready(Ok(copy_len));
167        }
168
169        // Phase 2: Poll transport for new recv completion.
170        // Loop handles credit-only batches (Ok(0)) from ring transport —
171        // credits update internal state but produce no data. Re-polling
172        // ensures the CQ waker is properly registered when no data remains.
173        let mut completions = [RecvCompletion::default(); 1];
174        loop {
175            match this.transport.poll_recv(cx, &mut completions) {
176                Poll::Pending => {
177                    if this.transport.poll_disconnect(cx) {
178                        this.eof = true;
179                        return Poll::Ready(Ok(0));
180                    }
181                    return Poll::Pending;
182                }
183                Poll::Ready(Err(_)) => {
184                    // FLUSH_ERR etc. — transport marked itself dead
185                    this.eof = true;
186                    return Poll::Ready(Ok(0));
187                }
188                Poll::Ready(Ok(0)) => {
189                    // Credit-only or internal-state-only batch — re-poll.
190                    continue;
191                }
192                Poll::Ready(Ok(_)) => {
193                    let c = &completions[0];
194                    if c.byte_len == 0 {
195                        return Poll::Ready(Ok(0));
196                    }
197                    let copy_len = c.byte_len.min(buf.len());
198                    buf[..copy_len]
199                        .copy_from_slice(&this.transport.recv_buf(c.buf_idx)[..copy_len]);
200                    if copy_len < c.byte_len {
201                        this.recv_pending = Some((c.buf_idx, copy_len, c.byte_len));
202                    } else {
203                        let _ = this.transport.repost_recv(c.buf_idx);
204                    }
205                    return Poll::Ready(Ok(copy_len));
206                }
207            }
208        }
209    }
210}
211
212impl<T: Transport> AsyncWrite for AsyncRdmaStream<T> {
213    fn poll_write(
214        self: Pin<&mut Self>,
215        cx: &mut Context<'_>,
216        buf: &[u8],
217    ) -> Poll<io::Result<usize>> {
218        let this = self.get_mut();
219
220        if buf.is_empty() {
221            return Poll::Ready(Ok(0));
222        }
223
224        if this.eof {
225            this.write_pending = None;
226            return Poll::Ready(Err(io::Error::new(
227                io::ErrorKind::BrokenPipe,
228                "connection closed",
229            )));
230        }
231
232        // Post send if not already in progress
233        if this.write_pending.is_none() {
234            match this.transport.send_copy(buf) {
235                Ok(0) => {
236                    // All send buffers occupied or credits exhausted.
237                    // Poll recv CQ to register a waker for incoming credit
238                    // updates (ring transport sends credits via Send+Imm on
239                    // the recv CQ). Without this, credit-blocked writes deadlock
240                    // because poll_send_completion only watches the send CQ.
241                    let mut completions = [RecvCompletion::default(); 1];
242                    let _ = this.transport.poll_recv(cx, &mut completions);
243                }
244                Ok(n) => {
245                    this.write_pending = Some(n);
246                }
247                Err(e) => return Poll::Ready(Err(io::Error::other(e))),
248            }
249        }
250
251        // If we haven't posted yet (buffers full), wait then retry
252        if this.write_pending.is_none() {
253            match this.transport.poll_send_completion(cx) {
254                Poll::Pending => {
255                    if this.transport.poll_disconnect(cx) {
256                        this.eof = true;
257                        return Poll::Ready(Err(io::Error::new(
258                            io::ErrorKind::BrokenPipe,
259                            "connection closed",
260                        )));
261                    }
262                    return Poll::Pending;
263                }
264                Poll::Ready(Err(e)) => {
265                    this.eof = true;
266                    return Poll::Ready(Err(io::Error::other(e)));
267                }
268                Poll::Ready(Ok(())) => match this.transport.send_copy(buf) {
269                    Ok(0) => {
270                        // Still blocked (credit exhaustion for ring transport).
271                        // Poll recv CQ to register a waker — credit updates
272                        // arrive as Send+Imm on the recv CQ, not the send CQ.
273                        let mut completions = [RecvCompletion::default(); 1];
274                        let _ = this.transport.poll_recv(cx, &mut completions);
275                        return Poll::Pending;
276                    }
277                    Ok(n) => this.write_pending = Some(n),
278                    Err(e) => return Poll::Ready(Err(io::Error::other(e))),
279                },
280            }
281        }
282        let len = this.write_pending.unwrap();
283
284        // Wait for THIS send's completion
285        match this.transport.poll_send_completion(cx) {
286            Poll::Pending => {
287                if this.transport.poll_disconnect(cx) {
288                    this.eof = true;
289                    this.write_pending = None;
290                    return Poll::Ready(Err(io::Error::new(
291                        io::ErrorKind::BrokenPipe,
292                        "connection closed",
293                    )));
294                }
295                Poll::Pending
296            }
297            Poll::Ready(Err(e)) => {
298                this.eof = true;
299                this.write_pending = None;
300                Poll::Ready(Err(io::Error::other(e)))
301            }
302            Poll::Ready(Ok(())) => {
303                this.write_pending = None;
304                Poll::Ready(Ok(len))
305            }
306        }
307    }
308
309    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
310        Poll::Ready(Ok(()))
311    }
312
313    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
314        let this = self.get_mut();
315
316        if this.eof {
317            this.write_pending = None;
318            return Poll::Ready(Ok(()));
319        }
320
321        // Phase 1: Drain pending send completion before disconnecting.
322        if this.write_pending.is_some() {
323            match this.transport.poll_send_completion(cx) {
324                Poll::Pending => {
325                    if this.transport.poll_disconnect(cx) {
326                        this.eof = true;
327                        this.write_pending = None;
328                        return Poll::Ready(Ok(()));
329                    }
330                    return Poll::Pending;
331                }
332                Poll::Ready(_) => {
333                    this.write_pending = None;
334                }
335            }
336        }
337
338        // Phase 2: Send DREQ and complete.
339        //
340        // We don't wait for the peer's DREP or DISCONNECTED event.
341        // On rxe, the peer may not process DREQ promptly (e.g. idle
342        // server), causing an 80+ second CM timeout. The kernel handles
343        // the DREP exchange asynchronously, and Drop performs final
344        // QP/CQ cleanup.
345        let _ = this.transport.disconnect();
346        this.eof = true;
347        Poll::Ready(Ok(()))
348    }
349}