read_write_ext_tokio/lib.rs
//! [crates.io](https://crates.io/crates/read-write-ext-tokio)
//! [license: Apache 2.0](http://www.apache.org/licenses/LICENSE-2.0)
//! [unsafe forbidden](https://github.com/rust-secure-code/safety-dance/)
//! [pipeline status](https://gitlab.com/leonhard-llc/fixed-buffer-rs/-/pipelines)
5//!
//! `AsyncReadWriteExt` trait with `chain_after` and `take_rw` for structs that implement `tokio::io::AsyncRead + tokio::io::AsyncWrite`.
7//!
8//! # Features
9//! - `forbid(unsafe_code)`
10//! - Good test coverage (100%)
11//! - Like [`tokio::io::AsyncReadExt::chain`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.chain)
12//! and [`tokio::io::AsyncReadExt::take`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.take)
13//! but also passes through writes.
//! - Useful with `AsyncRead + AsyncWrite` objects like
15//! [`tokio::net::TcpStream`](https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html)
16//! and [`tokio_rustls::TlsStream`](https://docs.rs/tokio-rustls/latest/tokio_rustls/enum.TlsStream.html).
17//!
18//! # Changelog
19//! - v1.0.0 - Stable API.
20//! - v0.1.0 - Initial release. Moved code from `fixed-buffer-tokio`.
21//!
22//! # TO DO
23//!
24//! # Release Process
25//! 1. Edit `Cargo.toml` and bump version number.
26//! 1. Run `../release.sh`
27#![forbid(unsafe_code)]
28
29use std::pin::Pin;
30use std::task::{Context, Poll};
31use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
32
33/// A wrapper for a pair of structs, R and RW.
34///
35/// Implements `tokio::io::AsyncRead`. Reads from `R` until it is empty, then reads from `RW`.
36///
37/// Implements `tokio::io::AsyncWrite`. Passes all writes through to `RW`.
38///
39/// This is like the struct returned by
40/// [`tokio::io::AsyncReadExt::chain`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.chain)
41/// that also passes through writes.
42pub struct AsyncReadWriteChain<R: AsyncRead, RW: AsyncRead + AsyncWrite> {
43 reader: Option<R>,
44 read_writer: RW,
45}
46impl<R: AsyncRead, RW: AsyncRead + AsyncWrite> AsyncReadWriteChain<R, RW> {
47 /// See [`AsyncReadWriteChain`](struct.AsyncReadWriteChain.html).
48 pub fn new(reader: R, read_writer: RW) -> AsyncReadWriteChain<R, RW> {
49 Self {
50 reader: Some(reader),
51 read_writer,
52 }
53 }
54}
55impl<R: AsyncRead + Unpin, RW: AsyncRead + AsyncWrite + Unpin> AsyncRead
56 for AsyncReadWriteChain<R, RW>
57{
58 fn poll_read(
59 mut self: Pin<&mut Self>,
60 cx: &mut Context<'_>,
61 buf: &mut ReadBuf<'_>,
62 ) -> Poll<Result<(), std::io::Error>> {
63 if let Some(ref mut reader) = self.reader {
64 let before_len = buf.filled().len();
65 match Pin::new(&mut *reader).poll_read(cx, buf) {
66 Poll::Pending => return Poll::Pending,
67 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
68 Poll::Ready(Ok(())) => {
69 let num_read = buf.filled().len() - before_len;
70 if num_read > 0 {
71 return Poll::Ready(Ok(()));
72 }
73 // EOF
74 self.reader = None;
75 // Fall through.
76 }
77 }
78 }
79 Pin::new(&mut self.read_writer).poll_read(cx, buf)
80 }
81}
82impl<R: AsyncRead + Unpin, RW: AsyncRead + AsyncWrite + Unpin> AsyncWrite
83 for AsyncReadWriteChain<R, RW>
84{
85 fn poll_write(
86 mut self: Pin<&mut Self>,
87 cx: &mut Context<'_>,
88 buf: &[u8],
89 ) -> Poll<Result<usize, std::io::Error>> {
90 Pin::new(&mut self.read_writer).poll_write(cx, buf)
91 }
92
93 fn poll_flush(
94 mut self: Pin<&mut Self>,
95 cx: &mut Context<'_>,
96 ) -> Poll<Result<(), std::io::Error>> {
97 Pin::new(&mut self.read_writer).poll_flush(cx)
98 }
99
100 fn poll_shutdown(
101 mut self: Pin<&mut Self>,
102 cx: &mut Context<'_>,
103 ) -> Poll<Result<(), std::io::Error>> {
104 Pin::new(&mut self.read_writer).poll_shutdown(cx)
105 }
106}
107
108/// Wraps a `tokio::io::AsyncRead + tokio::io::AsyncWrite` struct.
109/// Passes through reads and writes to the struct.
110/// Limits the number of bytes that can be read.
111///
112/// This is like [`tokio::io::Take`](https://docs.rs/tokio/latest/tokio/io/struct.Take.html)
113/// that also passes through writes.
114pub struct AsyncReadWriteTake<RW: AsyncRead + AsyncWrite> {
115 read_writer: RW,
116 remaining_bytes: u64,
117}
118impl<RW: AsyncRead + AsyncWrite + Unpin> AsyncReadWriteTake<RW> {
119 /// See [`AsyncReadWriteTake`](struct.AsyncReadWriteTake.html).
120 pub fn new(read_writer: RW, len: u64) -> AsyncReadWriteTake<RW> {
121 Self {
122 read_writer,
123 remaining_bytes: len,
124 }
125 }
126}
127impl<RW: AsyncRead + AsyncWrite + Unpin> AsyncRead for AsyncReadWriteTake<RW> {
128 fn poll_read(
129 mut self: Pin<&mut Self>,
130 cx: &mut Context<'_>,
131 buf: &mut ReadBuf<'_>,
132 ) -> Poll<Result<(), std::io::Error>> {
133 if self.remaining_bytes == 0 {
134 return Poll::Ready(Ok(()));
135 }
136 let num_to_read =
137 usize::try_from(self.remaining_bytes.min(buf.remaining() as u64)).unwrap_or(usize::MAX);
138 let dest = &mut buf.initialize_unfilled()[0..num_to_read];
139 let mut buf2 = ReadBuf::new(dest);
140 match Pin::new(&mut self.read_writer).poll_read(cx, &mut buf2) {
141 Poll::Ready(Ok(())) => {
142 let num_read = buf2.filled().len();
143 buf.advance(num_read);
144 self.remaining_bytes -= num_read as u64;
145 Poll::Ready(Ok(()))
146 }
147 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
148 Poll::Pending => Poll::Pending,
149 }
150 }
151}
152impl<RW: AsyncRead + AsyncWrite + Unpin> AsyncWrite for AsyncReadWriteTake<RW> {
153 fn poll_write(
154 mut self: Pin<&mut Self>,
155 cx: &mut Context<'_>,
156 buf: &[u8],
157 ) -> Poll<Result<usize, std::io::Error>> {
158 Pin::new(&mut self.read_writer).poll_write(cx, buf)
159 }
160
161 fn poll_flush(
162 mut self: Pin<&mut Self>,
163 cx: &mut Context<'_>,
164 ) -> Poll<Result<(), std::io::Error>> {
165 Pin::new(&mut self.read_writer).poll_flush(cx)
166 }
167
168 fn poll_shutdown(
169 mut self: Pin<&mut Self>,
170 cx: &mut Context<'_>,
171 ) -> Poll<Result<(), std::io::Error>> {
172 Pin::new(&mut self.read_writer).poll_shutdown(cx)
173 }
174}
175pub trait AsyncReadWriteExt: AsyncRead + AsyncWrite + Unpin {
176 /// Returns a struct that implements `tokio::io::AsyncRead` and `tokio::io::AsyncWrite`.
177 ///
178 /// It reads from `reader` until it is empty, then reads from `self`.
179 ///
180 /// It passes all writes through to `self`.
181 ///
182 /// This is like [`tokio::io::AsyncReadExt::chain`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.chain)
183 /// that also passes through writes.
184 fn chain_after<R: AsyncRead>(&mut self, reader: R) -> AsyncReadWriteChain<R, &mut Self> {
185 AsyncReadWriteChain::new(reader, self)
186 }
187
188 /// Wraps a struct that implements `tokio::io::AsyncRead` and `tokio::io::AsyncWrite`.
189 ///
190 /// The returned struct passes through reads and writes to the struct.
191 ///
192 /// It limits the number of bytes that can be read.
193 ///
194 /// This is like [`tokio::io::AsyncReadExt::take`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.take)
195 /// that also passes through writes.
196 fn take_rw(&mut self, len: u64) -> AsyncReadWriteTake<&mut Self> {
197 AsyncReadWriteTake::new(self, len)
198 }
199}
200impl<RW: AsyncRead + AsyncWrite + ?Sized + Unpin> AsyncReadWriteExt for RW {}