1use std::io;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16
17use nexus_net::{ParserSink, WireStream};
18
#[cfg(feature = "tokio-rt")]
/// Adapts a stream implementing tokio's `AsyncRead + AsyncWrite` to [`WireStream`].
///
/// Reads are issued directly into the parser sink's spare region; the adapter itself holds
/// nothing but the wrapped stream.
#[derive(Debug)]
pub struct AsyncReadAdapter<S> {
    // The wrapped tokio stream; all I/O is forwarded to it.
    inner: S,
}
41
#[cfg(feature = "tokio-rt")]
impl<S> AsyncReadAdapter<S> {
    /// Wraps `inner` so it can be driven through [`WireStream`].
    pub fn new(inner: S) -> Self {
        AsyncReadAdapter { inner }
    }

    /// Returns a shared reference to the wrapped stream.
    pub fn get_ref(&self) -> &S {
        let AsyncReadAdapter { inner } = self;
        inner
    }

    /// Returns a mutable reference to the wrapped stream.
    ///
    /// The adapter's only field is the stream, so there is no adapter-side state to get out
    /// of sync when the stream is used directly.
    pub fn get_mut(&mut self) -> &mut S {
        let AsyncReadAdapter { inner } = self;
        inner
    }

    /// Consumes the adapter and hands back the wrapped stream.
    pub fn into_inner(self) -> S {
        let AsyncReadAdapter { inner } = self;
        inner
    }
}
64
#[cfg(feature = "tokio-rt")]
impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> WireStream for AsyncReadAdapter<S> {
    /// Performs one read from the wrapped stream straight into `sink`'s spare region.
    ///
    /// At most `min(max, sink.spare().len())` bytes are requested. Once the read completes,
    /// the number of bytes the stream placed in the buffer is committed via `sink.filled(n)`
    /// (skipped when `n == 0`) and returned. Per tokio's `AsyncRead` documentation, a
    /// completed read that fills nothing signals EOF, so `Ok(0)` means EOF here.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` without polling the stream when `max == 0` or the sink has no
    /// spare space; otherwise any error from the stream is propagated unchanged.
    fn poll_fill_into<P: ParserSink>(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        sink: &mut P,
        max: usize,
    ) -> Poll<io::Result<usize>> {
        let stream = &mut Pin::get_mut(self).inner;
        let window = sink.spare();

        // Reject a zero-sized request up front: a read into an empty buffer would complete
        // with zero bytes and be indistinguishable from EOF.
        let no_room = max == 0 || window.is_empty();
        if no_room {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "poll_fill_into called with no buffer space \
                 (max == 0 or sink.spare() is empty)",
            )));
        }

        let limit = window.len().min(max);
        // `ReadBuf::new` treats the slice as already initialized, so the stream writes
        // directly into the sink's memory; `filled()` can never exceed `limit`.
        let mut read_buf = tokio::io::ReadBuf::new(&mut window[..limit]);
        std::task::ready!(Pin::new(stream).poll_read(cx, &mut read_buf))?;

        let got = read_buf.filled().len();
        if got != 0 {
            sink.filled(got);
        }
        Poll::Ready(Ok(got))
    }

    /// Forwards the write to the wrapped stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let stream = &mut Pin::get_mut(self).inner;
        Pin::new(stream).poll_write(cx, buf)
    }

    /// Forwards the flush to the wrapped stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let stream = &mut Pin::get_mut(self).inner;
        Pin::new(stream).poll_flush(cx)
    }

    /// Forwards the shutdown to the wrapped stream.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let stream = &mut Pin::get_mut(self).inner;
        Pin::new(stream).poll_shutdown(cx)
    }
}
119
#[cfg(feature = "nexus")]
/// Adapts a stream implementing `nexus_async_rt::AsyncRead + AsyncWrite` to [`WireStream`].
///
/// Reads are issued directly into the parser sink's spare region; the adapter itself holds
/// nothing but the wrapped stream.
#[derive(Debug)]
pub struct NexusAsyncReadAdapter<S> {
    // The wrapped nexus stream; all I/O is forwarded to it.
    inner: S,
}
143
#[cfg(feature = "nexus")]
impl<S> NexusAsyncReadAdapter<S> {
    /// Wraps `inner` so it can be driven through [`WireStream`].
    pub fn new(inner: S) -> Self {
        NexusAsyncReadAdapter { inner }
    }

    /// Returns a shared reference to the wrapped stream.
    pub fn get_ref(&self) -> &S {
        let NexusAsyncReadAdapter { inner } = self;
        inner
    }

    /// Returns a mutable reference to the wrapped stream.
    ///
    /// The adapter's only field is the stream, so there is no adapter-side state to get out
    /// of sync when the stream is used directly.
    pub fn get_mut(&mut self) -> &mut S {
        let NexusAsyncReadAdapter { inner } = self;
        inner
    }

    /// Consumes the adapter and hands back the wrapped stream.
    pub fn into_inner(self) -> S {
        let NexusAsyncReadAdapter { inner } = self;
        inner
    }
}
166
#[cfg(feature = "nexus")]
impl<S: nexus_async_rt::AsyncRead + nexus_async_rt::AsyncWrite + Unpin> WireStream
    for NexusAsyncReadAdapter<S>
{
    /// Performs one read from the wrapped stream straight into `sink`'s spare region.
    ///
    /// At most `min(max, sink.spare().len())` bytes are requested. When the read completes
    /// with `Ok(n)`, `n` bytes are committed via `sink.filled(n)` (skipped when `n == 0`)
    /// and `n` is returned.
    ///
    /// # Errors
    ///
    /// - `InvalidInput`, without polling the stream, when `max == 0` or the sink has no
    ///   spare space.
    /// - `InvalidData` when the stream reports more bytes than the slice it was handed;
    ///   nothing is committed to the sink in that case.
    /// - Any error from the stream, propagated unchanged.
    fn poll_fill_into<P: ParserSink>(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        sink: &mut P,
        max: usize,
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        let spare = sink.spare();
        // Reject a zero-sized request up front: a read into an empty buffer would complete
        // with zero bytes and be indistinguishable from EOF.
        if max == 0 || spare.is_empty() {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "poll_fill_into called with no buffer space \
                 (max == 0 or sink.spare() is empty)",
            )));
        }
        let cap = spare.len().min(max);
        match Pin::new(&mut this.inner).poll_read(cx, &mut spare[..cap]) {
            // The byte count comes from the reader's implementation and nothing else bounds
            // it. An out-of-contract reader could report more than `cap`; committing that
            // would advance the sink past bytes that were never written.
            Poll::Ready(Ok(n)) if n > cap => Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "poll_read reported more bytes than the buffer it was given",
            ))),
            Poll::Ready(Ok(n)) => {
                if n > 0 {
                    sink.filled(n);
                }
                Poll::Ready(Ok(n))
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Forwards the write to the wrapped stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_write(cx, buf)
    }

    /// Forwards the flush to the wrapped stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_flush(cx)
    }

    /// Forwards the shutdown to the wrapped stream.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_shutdown(cx)
    }
}
221
#[cfg(test)]
mod tests {
    // Precondition tests for both adapters: a `poll_fill_into` call that offers no buffer
    // space (empty sink or `max == 0`) must fail with `InvalidInput` before the wrapped
    // stream is ever polled.
    use super::*;
    use std::future::poll_fn;

    /// Minimal `ParserSink` backed by a fixed-size, zero-initialized `Vec<u8>`.
    ///
    /// `spare()` exposes every byte after `committed`; `filled(n)` advances `committed`.
    struct StubSink {
        buf: Vec<u8>,
        committed: usize,
    }

    impl StubSink {
        /// Creates a sink with `cap` bytes of spare space and nothing committed.
        fn with_capacity(cap: usize) -> Self {
            Self {
                buf: vec![0u8; cap],
                committed: 0,
            }
        }
    }

    impl ParserSink for StubSink {
        fn spare(&mut self) -> &mut [u8] {
            &mut self.buf[self.committed..]
        }
        fn filled(&mut self, n: usize) {
            self.committed += n;
        }
    }

    /// Stream whose I/O methods all panic, so a test fails loudly if an adapter reaches
    /// the stream instead of rejecting the call up front.
    struct UnpolledStream;

    #[cfg(feature = "tokio-rt")]
    impl tokio::io::AsyncRead for UnpolledStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            panic!("UnpolledStream::poll_read should not be reached")
        }
    }

    // Write-side impls exist only to satisfy the adapter's `AsyncWrite` bound.
    #[cfg(feature = "tokio-rt")]
    impl tokio::io::AsyncWrite for UnpolledStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            panic!("unreached")
        }
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            panic!("unreached")
        }
        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            panic!("unreached")
        }
    }

    /// A sink with zero capacity must be rejected with `InvalidInput`.
    #[cfg(feature = "tokio-rt")]
    #[tokio::test]
    async fn tokio_adapter_empty_spare_returns_invalid_input() {
        let mut adapter = AsyncReadAdapter::new(UnpolledStream);
        let mut sink = StubSink::with_capacity(0);
        let err = poll_fn(|cx| Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 8192))
            .await
            .expect_err("must error on empty sink");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    /// `max == 0` must be rejected with `InvalidInput` even when the sink has space.
    #[cfg(feature = "tokio-rt")]
    #[tokio::test]
    async fn tokio_adapter_max_zero_returns_invalid_input() {
        let mut adapter = AsyncReadAdapter::new(UnpolledStream);
        let mut sink = StubSink::with_capacity(64);
        let err = poll_fn(|cx| Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 0))
            .await
            .expect_err("must error on max == 0");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[cfg(feature = "nexus")]
    impl nexus_async_rt::AsyncRead for UnpolledStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            panic!("UnpolledStream::poll_read should not be reached")
        }
    }

    // Write-side impls exist only to satisfy the adapter's `AsyncWrite` bound.
    #[cfg(feature = "nexus")]
    impl nexus_async_rt::AsyncWrite for UnpolledStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            panic!("unreached")
        }
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            panic!("unreached")
        }
        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            panic!("unreached")
        }
    }

    /// Polls `f` exactly once with a no-op waker and returns its output.
    ///
    /// Panics if `f` is not ready on the first poll. This is sufficient for these tests
    /// because the precondition errors are returned synchronously, before any I/O.
    #[cfg(feature = "nexus")]
    fn block_on<F: std::future::Future>(f: F) -> F::Output {
        use std::task::{RawWaker, RawWakerVTable, Waker};
        fn noop(_: *const ()) {}
        fn noop_clone(p: *const ()) -> RawWaker {
            RawWaker::new(p, &VTABLE)
        }
        // Vtable order: clone, wake, wake_by_ref, drop.
        const VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);
        // SAFETY: no vtable function dereferences the data pointer (it is null), `noop_clone`
        // returns a `RawWaker` using this same vtable, and wake/wake_by_ref/drop do nothing,
        // so the `RawWaker`/`RawWakerVTable` contract is upheld.
        let waker = unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) };
        let mut cx = Context::from_waker(&waker);
        let mut f = std::pin::pin!(f);
        // NOTE(review): calling `.poll` needs the `Future` trait in scope; nothing visible
        // here imports it, so this relies on the Rust 2024 prelude — confirm the crate edition.
        match f.as_mut().poll(&mut cx) {
            Poll::Ready(v) => v,
            Poll::Pending => panic!("precondition error must be synchronous"),
        }
    }

    /// A sink with zero capacity must be rejected with `InvalidInput`.
    #[cfg(feature = "nexus")]
    #[test]
    fn nexus_adapter_empty_spare_returns_invalid_input() {
        let mut adapter = NexusAsyncReadAdapter::new(UnpolledStream);
        let mut sink = StubSink::with_capacity(0);
        let err = block_on(poll_fn(|cx| {
            Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 8192)
        }))
        .expect_err("must error on empty sink");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    /// `max == 0` must be rejected with `InvalidInput` even when the sink has space.
    #[cfg(feature = "nexus")]
    #[test]
    fn nexus_adapter_max_zero_returns_invalid_input() {
        let mut adapter = NexusAsyncReadAdapter::new(UnpolledStream);
        let mut sink = StubSink::with_capacity(64);
        let err = block_on(poll_fn(|cx| {
            Pin::new(&mut adapter).poll_fill_into(cx, &mut sink, 0)
        }))
        .expect_err("must error on max == 0");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }
}