Skip to main content

nexus_async_net/
wire.rs

1//! Backend-specific [`WireStream`](nexus_net::WireStream) adapters.
2//!
3//! `WireStream`/`ParserSink` are defined in `nexus-net`; the adapters
4//! that wrap a runtime's `AsyncRead+AsyncWrite` source live here so
5//! they can use the runtime's own trait shape without forcing
6//! `nexus-net` to depend on the runtime.
7//!
8//! - [`AsyncReadAdapter`] wraps `tokio::io::AsyncRead+AsyncWrite`
9//!   (under `feature = "tokio-rt"`).
10//! - [`NexusAsyncReadAdapter`] wraps `nexus_async_rt::AsyncRead+AsyncWrite`
11//!   (under `feature = "nexus"`).
12
13use std::io;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16
17use nexus_net::{ParserSink, WireStream};
18
19// =============================================================================
20// Tokio adapter
21// =============================================================================
22
/// Wraps a `tokio::io::AsyncRead + AsyncWrite` source as a [`WireStream`].
///
/// Use this when constructing `WsStream`/`HttpConnection` over a custom
/// tokio transport (raw `TcpStream`, mock streams, etc.). The canonical
/// `MaybeTls` transport implements `WireStream` directly.
///
/// Reads are issued through a `tokio::io::ReadBuf` laid over the sink's
/// spare capacity (capped at the caller's `max`), so received bytes land
/// directly in the parser buffer. Writes, flushes and shutdowns are
/// forwarded to the inner stream unchanged.
///
/// ```ignore
/// use nexus_async_net::AsyncReadAdapter;
///
/// let tcp = tokio::net::TcpStream::connect(addr).await?;
/// let ws = WsStreamBuilder::new()
///     .connect_with(AsyncReadAdapter::new(tcp), url)
///     .await?;
/// ```
#[cfg(feature = "tokio-rt")]
#[derive(Debug)]
pub struct AsyncReadAdapter<S> {
    /// The wrapped transport.
    inner: S,
}
41
#[cfg(feature = "tokio-rt")]
impl<S> AsyncReadAdapter<S> {
    /// Builds an adapter that takes ownership of `inner`, a tokio
    /// `AsyncRead + AsyncWrite` stream.
    pub fn new(inner: S) -> Self {
        AsyncReadAdapter { inner }
    }

    /// Borrows the wrapped stream.
    pub fn get_ref(&self) -> &S {
        let Self { inner } = self;
        inner
    }

    /// Mutably borrows the wrapped stream.
    pub fn get_mut(&mut self) -> &mut S {
        let Self { inner } = self;
        inner
    }

    /// Consumes the adapter and hands back the wrapped stream.
    pub fn into_inner(self) -> S {
        let Self { inner } = self;
        inner
    }
}
64
// Pinning: this impl requires `S: Unpin`, which makes the adapter `Unpin`
// as well, so `Pin::get_mut` is safe and the inner stream is simply
// re-pinned with `Pin::new` on every call. No `unsafe` pin projection is
// involved.
#[cfg(feature = "tokio-rt")]
impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> WireStream for AsyncReadAdapter<S> {
    /// Reads at most `max` bytes from the inner stream directly into the
    /// sink's spare capacity and commits them with `ParserSink::filled`.
    ///
    /// A read that fills nothing (EOF under tokio's `AsyncRead` contract)
    /// returns `Ok(0)` and leaves the sink untouched.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if `max == 0` or `sink.spare()` is empty. This is
    ///   checked before the inner stream is polled.
    /// - Any error produced by the inner stream's `poll_read`.
    fn poll_fill_into<P: ParserSink>(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        sink: &mut P,
        max: usize,
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        let spare = sink.spare();
        if max == 0 || spare.is_empty() {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "poll_fill_into called with no buffer space \
                 (max == 0 or sink.spare() is empty)",
            )));
        }
        let cap = spare.len().min(max);
        let mut read_buf = tokio::io::ReadBuf::new(&mut spare[..cap]);
        std::task::ready!(Pin::new(&mut this.inner).poll_read(cx, &mut read_buf))?;
        // The `ReadBuf`, not the inner reader, reports how many bytes were
        // written into the `cap`-sized window.
        let n = read_buf.filled().len();
        if n > 0 {
            sink.filled(n);
        }
        Poll::Ready(Ok(n))
    }

    /// Forwards to the inner stream's `poll_write`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
    }

    /// Forwards to the inner stream's `poll_flush`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
    }

    /// Forwards to the inner stream's `poll_shutdown`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
    }
}
119
120// =============================================================================
121// Nexus runtime adapter
122// =============================================================================
123
/// Wraps a `nexus_async_rt::AsyncRead + AsyncWrite` source as a
/// [`WireStream`].
///
/// Use this when constructing `WsStream`/`HttpConnection` over a
/// custom transport on the nexus-async-rt backend. The canonical
/// `MaybeTls` path implements `WireStream` directly with a faster
/// zero-copy plaintext path; this adapter is the slow-path equivalent
/// for arbitrary transports.
///
/// Reads are issued against a sub-slice of the sink's spare capacity
/// (capped at the caller's `max`). Writes, flushes and shutdowns are
/// forwarded to the inner stream unchanged.
///
/// ```ignore
/// use nexus_async_net::NexusAsyncReadAdapter;
///
/// let adapter = NexusAsyncReadAdapter::new(my_custom_stream);
/// let ws = WsStreamBuilder::new().connect_with(adapter, url).await?;
/// ```
#[cfg(feature = "nexus")]
#[derive(Debug)]
pub struct NexusAsyncReadAdapter<S> {
    /// The wrapped transport.
    inner: S,
}
143
#[cfg(feature = "nexus")]
impl<S> NexusAsyncReadAdapter<S> {
    /// Builds an adapter that takes ownership of `inner`.
    pub fn new(inner: S) -> Self {
        NexusAsyncReadAdapter { inner }
    }

    /// Borrows the wrapped stream.
    pub fn get_ref(&self) -> &S {
        let Self { inner } = self;
        inner
    }

    /// Mutably borrows the wrapped stream.
    pub fn get_mut(&mut self) -> &mut S {
        let Self { inner } = self;
        inner
    }

    /// Consumes the adapter and hands back the wrapped stream.
    pub fn into_inner(self) -> S {
        let Self { inner } = self;
        inner
    }
}
166
// Pinning: this impl requires `S: Unpin`, which makes the adapter `Unpin`
// as well, so `Pin::get_mut` is safe and the inner stream is simply
// re-pinned with `Pin::new` on every call. No `unsafe` pin projection is
// involved.
#[cfg(feature = "nexus")]
impl<S: nexus_async_rt::AsyncRead + nexus_async_rt::AsyncWrite + Unpin> WireStream
    for NexusAsyncReadAdapter<S>
{
    /// Reads at most `max` bytes from the inner stream directly into the
    /// sink's spare capacity and commits them with `ParserSink::filled`.
    ///
    /// A read of zero bytes returns `Ok(0)` and leaves the sink untouched.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if `max == 0` or `sink.spare()` is empty. This is
    ///   checked before the inner stream is polled.
    /// - `InvalidData` if the inner stream reports more bytes than the
    ///   buffer it was handed. Nothing is committed in that case.
    /// - Any error produced by the inner stream's `poll_read`.
    fn poll_fill_into<P: ParserSink>(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        sink: &mut P,
        max: usize,
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        let spare = sink.spare();
        if max == 0 || spare.is_empty() {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "poll_fill_into called with no buffer space \
                 (max == 0 or sink.spare() is empty)",
            )));
        }
        let cap = spare.len().min(max);
        let n = std::task::ready!(Pin::new(&mut this.inner).poll_read(cx, &mut spare[..cap]))?;
        // `AsyncRead` is a safe trait, so an implementation can misreport the
        // byte count. Never commit more than the window actually offered;
        // otherwise the sink's committed length would run past its buffer.
        if n > cap {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "inner reader reported more bytes than the buffer it was given",
            )));
        }
        if n > 0 {
            sink.filled(n);
        }
        Poll::Ready(Ok(n))
    }

    /// Forwards to the inner stream's `poll_write`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
    }

    /// Forwards to the inner stream's `poll_flush`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
    }

    /// Forwards to the inner stream's `poll_shutdown`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
    }
}
221
222// =============================================================================
223// Tests
224// =============================================================================
225
// Every test below needs at least one backend; gating the whole module keeps
// builds without backend features free of unused-item warnings.
#[cfg(all(test, any(feature = "tokio-rt", feature = "nexus")))]
mod tests {
    use super::*;
    use std::future::poll_fn;

    /// `ParserSink` backed by a fixed-size buffer. `committed` counts the
    /// bytes handed over through `filled`, and `spare` is what remains.
    struct StubSink {
        buf: Vec<u8>,
        committed: usize,
    }

    impl StubSink {
        fn with_capacity(cap: usize) -> Self {
            Self {
                buf: vec![0u8; cap],
                committed: 0,
            }
        }

        /// The prefix of the buffer committed so far.
        fn committed_bytes(&self) -> &[u8] {
            &self.buf[..self.committed]
        }
    }

    impl ParserSink for StubSink {
        fn spare(&mut self) -> &mut [u8] {
            &mut self.buf[self.committed..]
        }
        fn filled(&mut self, n: usize) {
            self.committed += n;
        }
    }

    /// Stream stub that panics if polled — proves the precondition
    /// error fires before any I/O is attempted.
    struct UnpolledStream;

    /// Stream stub that serves `data` to readers (as much as fits per poll)
    /// and reads as EOF once drained. Writes are accepted and discarded.
    struct ScriptedStream {
        data: Vec<u8>,
    }

    impl ScriptedStream {
        fn new(data: &[u8]) -> Self {
            Self {
                data: data.to_vec(),
            }
        }

        /// Removes and returns up to `limit` queued bytes.
        fn take_front(&mut self, limit: usize) -> Vec<u8> {
            let n = self.data.len().min(limit);
            self.data.drain(..n).collect()
        }
    }

    // -------------------------------------------------------------------------
    // Tokio adapter
    // -------------------------------------------------------------------------

    #[cfg(feature = "tokio-rt")]
    impl tokio::io::AsyncRead for UnpolledStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            panic!("UnpolledStream::poll_read should not be reached")
        }
    }

    #[cfg(feature = "tokio-rt")]
    impl tokio::io::AsyncWrite for UnpolledStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            panic!("unreached")
        }
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            panic!("unreached")
        }
        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            panic!("unreached")
        }
    }

    #[cfg(feature = "tokio-rt")]
    impl tokio::io::AsyncRead for ScriptedStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let chunk = self.get_mut().take_front(buf.remaining());
            buf.put_slice(&chunk);
            Poll::Ready(Ok(()))
        }
    }

    #[cfg(feature = "tokio-rt")]
    impl tokio::io::AsyncWrite for ScriptedStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Poll::Ready(Ok(buf.len()))
        }
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Empty-spare precondition fires before the stream is polled.
    #[cfg(feature = "tokio-rt")]
    #[tokio::test]
    async fn tokio_adapter_empty_spare_returns_invalid_input() {
        let mut adapter = AsyncReadAdapter::new(UnpolledStream);
        let mut sink = StubSink::with_capacity(0);
        let err = poll_fn(|cx| Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 8192))
            .await
            .expect_err("must error on empty sink");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    /// `max == 0` precondition fires before the stream is polled.
    #[cfg(feature = "tokio-rt")]
    #[tokio::test]
    async fn tokio_adapter_max_zero_returns_invalid_input() {
        let mut adapter = AsyncReadAdapter::new(UnpolledStream);
        let mut sink = StubSink::with_capacity(64);
        let err = poll_fn(|cx| Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 0))
            .await
            .expect_err("must error on max == 0");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    /// Bytes produced by the inner stream are committed to the sink.
    #[cfg(feature = "tokio-rt")]
    #[tokio::test]
    async fn tokio_adapter_commits_read_bytes() {
        let mut adapter = AsyncReadAdapter::new(ScriptedStream::new(b"hello"));
        let mut sink = StubSink::with_capacity(64);
        let n = poll_fn(|cx| Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 8192))
            .await
            .expect("read must succeed");
        assert_eq!(n, 5);
        assert_eq!(sink.committed_bytes(), b"hello");
    }

    /// A single fill never reads more than `max` bytes.
    #[cfg(feature = "tokio-rt")]
    #[tokio::test]
    async fn tokio_adapter_reads_at_most_max_bytes() {
        let mut adapter = AsyncReadAdapter::new(ScriptedStream::new(b"hello world"));
        let mut sink = StubSink::with_capacity(64);
        let n = poll_fn(|cx| Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 4))
            .await
            .expect("read must succeed");
        assert_eq!(n, 4);
        assert_eq!(sink.committed_bytes(), b"hell");
    }

    /// EOF yields `Ok(0)` and commits nothing.
    #[cfg(feature = "tokio-rt")]
    #[tokio::test]
    async fn tokio_adapter_eof_commits_nothing() {
        let mut adapter = AsyncReadAdapter::new(ScriptedStream::new(b""));
        let mut sink = StubSink::with_capacity(64);
        let n = poll_fn(|cx| Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 8192))
            .await
            .expect("EOF is not an error");
        assert_eq!(n, 0);
        assert_eq!(sink.committed, 0);
    }

    // -------------------------------------------------------------------------
    // Nexus adapter
    // -------------------------------------------------------------------------

    #[cfg(feature = "nexus")]
    impl nexus_async_rt::AsyncRead for UnpolledStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            panic!("UnpolledStream::poll_read should not be reached")
        }
    }

    #[cfg(feature = "nexus")]
    impl nexus_async_rt::AsyncWrite for UnpolledStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            panic!("unreached")
        }
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            panic!("unreached")
        }
        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            panic!("unreached")
        }
    }

    #[cfg(feature = "nexus")]
    impl nexus_async_rt::AsyncRead for ScriptedStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let chunk = self.get_mut().take_front(buf.len());
            buf[..chunk.len()].copy_from_slice(&chunk);
            Poll::Ready(Ok(chunk.len()))
        }
    }

    #[cfg(feature = "nexus")]
    impl nexus_async_rt::AsyncWrite for ScriptedStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Poll::Ready(Ok(buf.len()))
        }
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Polls a future once with a no-op waker and returns its output.
    /// Every stub in this module completes synchronously, so no real
    /// runtime is needed.
    ///
    /// # Panics
    ///
    /// Panics if the future returns `Pending` on its first poll.
    #[cfg(feature = "nexus")]
    fn block_on<F: std::future::Future>(f: F) -> F::Output {
        use std::task::{RawWaker, RawWakerVTable, Waker};
        fn noop(_: *const ()) {}
        fn noop_clone(p: *const ()) -> RawWaker {
            RawWaker::new(p, &VTABLE)
        }
        const VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);
        // SAFETY: The vtable functions (clone/wake/wake_by_ref/drop) are all no-ops
        // that never dereference the data pointer, so the null data pointer is sound.
        // The vtable is 'static (const) and correctly returns a valid RawWaker on clone.
        let waker = unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) };
        let mut cx = Context::from_waker(&waker);
        let mut f = std::pin::pin!(f);
        // Fully qualified so this compiles whether or not the crate's edition
        // has `Future` in its prelude.
        match std::future::Future::poll(f.as_mut(), &mut cx) {
            Poll::Ready(v) => v,
            Poll::Pending => panic!("future must complete on first poll (stubs are synchronous)"),
        }
    }

    /// Empty-spare precondition fires before the stream is polled.
    #[cfg(feature = "nexus")]
    #[test]
    fn nexus_adapter_empty_spare_returns_invalid_input() {
        let mut adapter = NexusAsyncReadAdapter::new(UnpolledStream);
        let mut sink = StubSink::with_capacity(0);
        let err = block_on(poll_fn(|cx| {
            Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 8192)
        }))
        .expect_err("must error on empty sink");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    /// `max == 0` precondition fires before the stream is polled.
    #[cfg(feature = "nexus")]
    #[test]
    fn nexus_adapter_max_zero_returns_invalid_input() {
        let mut adapter = NexusAsyncReadAdapter::new(UnpolledStream);
        let mut sink = StubSink::with_capacity(64);
        let err = block_on(poll_fn(|cx| {
            Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 0)
        }))
        .expect_err("must error on max == 0");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    /// Bytes produced by the inner stream are committed to the sink.
    #[cfg(feature = "nexus")]
    #[test]
    fn nexus_adapter_commits_read_bytes() {
        let mut adapter = NexusAsyncReadAdapter::new(ScriptedStream::new(b"hello"));
        let mut sink = StubSink::with_capacity(64);
        let n = block_on(poll_fn(|cx| {
            Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 8192)
        }))
        .expect("read must succeed");
        assert_eq!(n, 5);
        assert_eq!(sink.committed_bytes(), b"hello");
    }

    /// A single fill never reads more than `max` bytes.
    #[cfg(feature = "nexus")]
    #[test]
    fn nexus_adapter_reads_at_most_max_bytes() {
        let mut adapter = NexusAsyncReadAdapter::new(ScriptedStream::new(b"hello world"));
        let mut sink = StubSink::with_capacity(64);
        let n = block_on(poll_fn(|cx| {
            Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 4)
        }))
        .expect("read must succeed");
        assert_eq!(n, 4);
        assert_eq!(sink.committed_bytes(), b"hell");
    }

    /// A zero-byte read yields `Ok(0)` and commits nothing.
    #[cfg(feature = "nexus")]
    #[test]
    fn nexus_adapter_eof_commits_nothing() {
        let mut adapter = NexusAsyncReadAdapter::new(ScriptedStream::new(b""));
        let mut sink = StubSink::with_capacity(64);
        let n = block_on(poll_fn(|cx| {
            Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 8192)
        }))
        .expect("EOF is not an error");
        assert_eq!(n, 0);
        assert_eq!(sink.committed, 0);
    }
}