read_write_ext_tokio/
lib.rs

1//! [![crates.io version](https://img.shields.io/crates/v/read-write-ext-tokio.svg)](https://crates.io/crates/read-write-ext-tokio)
2//! [![license: Apache 2.0](https://gitlab.com/leonhard-llc/fixed-buffer-rs/-/raw/main/license-apache-2.0.svg)](http://www.apache.org/licenses/LICENSE-2.0)
3//! [![unsafe forbidden](https://gitlab.com/leonhard-llc/fixed-buffer-rs/-/raw/main/unsafe-forbidden-success.svg)](https://github.com/rust-secure-code/safety-dance/)
4//! [![pipeline status](https://gitlab.com/leonhard-llc/fixed-buffer-rs/badges/main/pipeline.svg)](https://gitlab.com/leonhard-llc/fixed-buffer-rs/-/pipelines)
5//!
//! `AsyncReadWriteExt` trait with `chain_after` and `take_rw` for `tokio::io::AsyncRead + AsyncWrite` structs.
7//!
8//! # Features
9//! - `forbid(unsafe_code)`
10//! - Good test coverage (100%)
11//! - Like [`tokio::io::AsyncReadExt::chain`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.chain)
12//!   and [`tokio::io::AsyncReadExt::take`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.take)
13//!   but also passes through writes.
//! - Useful with `tokio::io::AsyncRead + AsyncWrite` objects like
15//!   [`tokio::net::TcpStream`](https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html)
16//!   and [`tokio_rustls::TlsStream`](https://docs.rs/tokio-rustls/latest/tokio_rustls/enum.TlsStream.html).
17//!
18//! # Changelog
19//! - v1.0.0 - Stable API.
20//! - v0.1.0 - Initial release.  Moved code from `fixed-buffer-tokio`.
21//!
22//! # TO DO
23//!
24//! # Release Process
25//! 1. Edit `Cargo.toml` and bump version number.
26//! 1. Run `../release.sh`
27#![forbid(unsafe_code)]
28
29use std::pin::Pin;
30use std::task::{Context, Poll};
31use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
32
/// A wrapper for a pair of structs, R and RW.
///
/// Implements `tokio::io::AsyncRead`.  Reads from `R` until it reaches EOF, then reads from `RW`.
///
/// Implements `tokio::io::AsyncWrite`.  Passes all writes through to `RW`.
///
/// This is like the struct returned by
/// [`tokio::io::AsyncReadExt::chain`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.chain)
/// that also passes through writes.
///
/// Create one with [`AsyncReadWriteChain::new`] or [`AsyncReadWriteExt::chain_after`].
///
/// The trait bounds live on the `impl` blocks rather than on the struct, so the type can be
/// named in generic code without repeating them.
#[derive(Debug)]
pub struct AsyncReadWriteChain<R, RW> {
    /// The reader that is drained first; set to `None` once it reports EOF.
    reader: Option<R>,
    /// Receives every write, and every read once `reader` is exhausted.
    read_writer: RW,
}
46impl<R: AsyncRead, RW: AsyncRead + AsyncWrite> AsyncReadWriteChain<R, RW> {
47    /// See [`AsyncReadWriteChain`](struct.AsyncReadWriteChain.html).
48    pub fn new(reader: R, read_writer: RW) -> AsyncReadWriteChain<R, RW> {
49        Self {
50            reader: Some(reader),
51            read_writer,
52        }
53    }
54}
55impl<R: AsyncRead + Unpin, RW: AsyncRead + AsyncWrite + Unpin> AsyncRead
56    for AsyncReadWriteChain<R, RW>
57{
58    fn poll_read(
59        mut self: Pin<&mut Self>,
60        cx: &mut Context<'_>,
61        buf: &mut ReadBuf<'_>,
62    ) -> Poll<Result<(), std::io::Error>> {
63        if let Some(ref mut reader) = self.reader {
64            let before_len = buf.filled().len();
65            match Pin::new(&mut *reader).poll_read(cx, buf) {
66                Poll::Pending => return Poll::Pending,
67                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
68                Poll::Ready(Ok(())) => {
69                    let num_read = buf.filled().len() - before_len;
70                    if num_read > 0 {
71                        return Poll::Ready(Ok(()));
72                    }
73                    // EOF
74                    self.reader = None;
75                    // Fall through.
76                }
77            }
78        }
79        Pin::new(&mut self.read_writer).poll_read(cx, buf)
80    }
81}
82impl<R: AsyncRead + Unpin, RW: AsyncRead + AsyncWrite + Unpin> AsyncWrite
83    for AsyncReadWriteChain<R, RW>
84{
85    fn poll_write(
86        mut self: Pin<&mut Self>,
87        cx: &mut Context<'_>,
88        buf: &[u8],
89    ) -> Poll<Result<usize, std::io::Error>> {
90        Pin::new(&mut self.read_writer).poll_write(cx, buf)
91    }
92
93    fn poll_flush(
94        mut self: Pin<&mut Self>,
95        cx: &mut Context<'_>,
96    ) -> Poll<Result<(), std::io::Error>> {
97        Pin::new(&mut self.read_writer).poll_flush(cx)
98    }
99
100    fn poll_shutdown(
101        mut self: Pin<&mut Self>,
102        cx: &mut Context<'_>,
103    ) -> Poll<Result<(), std::io::Error>> {
104        Pin::new(&mut self.read_writer).poll_shutdown(cx)
105    }
106}
107
/// Wraps a `tokio::io::AsyncRead + tokio::io::AsyncWrite` struct.
/// Passes through reads and writes to the struct.
/// Limits the number of bytes that can be read; writes are not limited.
///
/// This is like [`tokio::io::Take`](https://docs.rs/tokio/latest/tokio/io/struct.Take.html)
/// that also passes through writes.
///
/// Create one with [`AsyncReadWriteTake::new`] or [`AsyncReadWriteExt::take_rw`].
///
/// The trait bounds live on the `impl` blocks rather than on the struct, so the type can be
/// named in generic code without repeating them.
#[derive(Debug)]
pub struct AsyncReadWriteTake<RW> {
    /// The wrapped stream; receives every read and write.
    read_writer: RW,
    /// How many more bytes reads may return.  Once it reaches 0, reads return without
    /// filling the buffer.
    remaining_bytes: u64,
}
118impl<RW: AsyncRead + AsyncWrite + Unpin> AsyncReadWriteTake<RW> {
119    /// See [`AsyncReadWriteTake`](struct.AsyncReadWriteTake.html).
120    pub fn new(read_writer: RW, len: u64) -> AsyncReadWriteTake<RW> {
121        Self {
122            read_writer,
123            remaining_bytes: len,
124        }
125    }
126}
127impl<RW: AsyncRead + AsyncWrite + Unpin> AsyncRead for AsyncReadWriteTake<RW> {
128    fn poll_read(
129        mut self: Pin<&mut Self>,
130        cx: &mut Context<'_>,
131        buf: &mut ReadBuf<'_>,
132    ) -> Poll<Result<(), std::io::Error>> {
133        if self.remaining_bytes == 0 {
134            return Poll::Ready(Ok(()));
135        }
136        let num_to_read =
137            usize::try_from(self.remaining_bytes.min(buf.remaining() as u64)).unwrap_or(usize::MAX);
138        let dest = &mut buf.initialize_unfilled()[0..num_to_read];
139        let mut buf2 = ReadBuf::new(dest);
140        match Pin::new(&mut self.read_writer).poll_read(cx, &mut buf2) {
141            Poll::Ready(Ok(())) => {
142                let num_read = buf2.filled().len();
143                buf.advance(num_read);
144                self.remaining_bytes -= num_read as u64;
145                Poll::Ready(Ok(()))
146            }
147            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
148            Poll::Pending => Poll::Pending,
149        }
150    }
151}
152impl<RW: AsyncRead + AsyncWrite + Unpin> AsyncWrite for AsyncReadWriteTake<RW> {
153    fn poll_write(
154        mut self: Pin<&mut Self>,
155        cx: &mut Context<'_>,
156        buf: &[u8],
157    ) -> Poll<Result<usize, std::io::Error>> {
158        Pin::new(&mut self.read_writer).poll_write(cx, buf)
159    }
160
161    fn poll_flush(
162        mut self: Pin<&mut Self>,
163        cx: &mut Context<'_>,
164    ) -> Poll<Result<(), std::io::Error>> {
165        Pin::new(&mut self.read_writer).poll_flush(cx)
166    }
167
168    fn poll_shutdown(
169        mut self: Pin<&mut Self>,
170        cx: &mut Context<'_>,
171    ) -> Poll<Result<(), std::io::Error>> {
172        Pin::new(&mut self.read_writer).poll_shutdown(cx)
173    }
174}
175pub trait AsyncReadWriteExt: AsyncRead + AsyncWrite + Unpin {
176    /// Returns a struct that implements `tokio::io::AsyncRead` and `tokio::io::AsyncWrite`.
177    ///
178    /// It reads from `reader` until it is empty, then reads from `self`.
179    ///
180    /// It passes all writes through to `self`.
181    ///
182    /// This is like [`tokio::io::AsyncReadExt::chain`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.chain)
183    /// that also passes through writes.
184    fn chain_after<R: AsyncRead>(&mut self, reader: R) -> AsyncReadWriteChain<R, &mut Self> {
185        AsyncReadWriteChain::new(reader, self)
186    }
187
188    /// Wraps a struct that implements `tokio::io::AsyncRead` and `tokio::io::AsyncWrite`.
189    ///
190    /// The returned struct passes through reads and writes to the struct.
191    ///
192    /// It limits the number of bytes that can be read.
193    ///
194    /// This is like [`tokio::io::AsyncReadExt::take`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.take)
195    /// that also passes through writes.
196    fn take_rw(&mut self, len: u64) -> AsyncReadWriteTake<&mut Self> {
197        AsyncReadWriteTake::new(self, len)
198    }
199}
200impl<RW: AsyncRead + AsyncWrite + ?Sized + Unpin> AsyncReadWriteExt for RW {}