fixed_buffer_tokio/
async_read_write_chain.rs1#![forbid(unsafe_code)]
2
3use std::task::Context;
4use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
5use tokio::macros::support::{Pin, Poll};
6
7pub struct AsyncReadWriteChain<
24 'a,
25 R: tokio::io::AsyncRead + Send + Unpin,
26 RW: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin,
27> {
28 reader: Option<&'a mut R>,
29 read_writer: &'a mut RW,
30}
31
32impl<'a, R: AsyncRead + Send + Unpin, RW: AsyncRead + AsyncWrite + Send + Unpin>
33 AsyncReadWriteChain<'a, R, RW>
34{
35 pub fn new(reader: &'a mut R, read_writer: &'a mut RW) -> AsyncReadWriteChain<'a, R, RW> {
37 Self {
38 reader: Some(reader),
39 read_writer,
40 }
41 }
42}
43
44impl<'a, R: AsyncRead + Send + Unpin, RW: AsyncRead + AsyncWrite + Send + Unpin> AsyncRead
45 for AsyncReadWriteChain<'a, R, RW>
46{
47 fn poll_read(
48 self: Pin<&mut Self>,
49 cx: &mut Context<'_>,
50 buf: &mut ReadBuf<'_>,
51 ) -> Poll<Result<(), std::io::Error>> {
52 let mut_self = self.get_mut();
53 if let Some(ref mut reader) = mut_self.reader {
54 let before_len = buf.filled().len();
55 match Pin::new(&mut *reader).poll_read(cx, buf) {
56 Poll::Pending => return Poll::Pending,
57 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
58 Poll::Ready(Ok(())) => {
59 let num_read = buf.filled().len() - before_len;
60 if num_read > 0 {
61 return Poll::Ready(Ok(()));
62 } else {
63 mut_self.reader = None;
65 }
67 }
68 }
69 }
70 Pin::new(&mut mut_self.read_writer).poll_read(cx, buf)
71 }
72}
73
74impl<'a, R: AsyncRead + Send + Unpin, RW: AsyncRead + AsyncWrite + Send + Unpin> AsyncWrite
75 for AsyncReadWriteChain<'a, R, RW>
76{
77 fn poll_write(
78 self: Pin<&mut Self>,
79 cx: &mut Context<'_>,
80 buf: &[u8],
81 ) -> Poll<Result<usize, std::io::Error>> {
82 let mut_self = self.get_mut();
83 Pin::new(&mut mut_self.read_writer).poll_write(cx, buf)
84 }
85
86 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
87 let mut_self = self.get_mut();
88 Pin::new(&mut mut_self.read_writer).poll_flush(cx)
89 }
90
91 fn poll_shutdown(
92 self: Pin<&mut Self>,
93 cx: &mut Context<'_>,
94 ) -> Poll<Result<(), std::io::Error>> {
95 let mut_self = self.get_mut();
96 Pin::new(&mut mut_self.read_writer).poll_shutdown(cx)
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::super::*;
    use fixed_buffer::escape_ascii;
    // Extension traits supplying the `read` / `write` futures used below.
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// With nothing in either source, a read yields zero bytes and leaves
    /// the destination untouched.
    #[tokio::test]
    async fn both_empty() {
        let mut reader = std::io::Cursor::new(b"");
        let mut read_writer: AsyncFixedBuf<8> = AsyncFixedBuf::new();
        let mut chain = AsyncReadWriteChain::new(&mut reader, &mut read_writer);
        let mut buf = [b'.'; 8];
        let num_read = chain.read(&mut buf).await.unwrap();
        assert_eq!(0, num_read);
        assert_eq!("........", escape_ascii(&buf));
    }

    /// Data available in the first reader is returned by the read.
    #[tokio::test]
    async fn doesnt_read_second_when_first_has_data() {
        let mut reader = std::io::Cursor::new(b"abc");
        let mut read_writer = FakeAsyncReadWriter::empty();
        let mut chain = AsyncReadWriteChain::new(&mut reader, &mut read_writer);
        let mut buf = [b'.'; 4];
        let num_read = chain.read(&mut buf).await.unwrap();
        assert_eq!(3, num_read);
        assert_eq!("abc.", escape_ascii(&buf));
    }

    /// An error from the first reader is surfaced, and a second read fails
    /// as well.
    #[tokio::test]
    async fn doesnt_read_second_when_first_returns_error() {
        let mut reader = FakeAsyncReadWriter::new(vec![Err(err1()), Err(err1())]);
        let mut read_writer = FakeAsyncReadWriter::empty();
        let mut chain = AsyncReadWriteChain::new(&mut reader, &mut read_writer);
        let mut buf = [b'.'; 4];
        let err = chain.read(&mut buf).await.unwrap_err();
        assert_eq!(std::io::ErrorKind::Other, err.kind());
        assert_eq!("err1", err.to_string());
        assert_eq!("....", escape_ascii(&buf));
        chain.read(&mut buf).await.unwrap_err();
    }

    /// An empty first reader is skipped and the read-writer's data is
    /// returned by the same read.
    #[tokio::test]
    async fn reads_second_when_first_empty() {
        let mut reader = std::io::Cursor::new(b"");
        let mut read_writer: AsyncFixedBuf<4> = AsyncFixedBuf::new();
        read_writer.write_str("abc").unwrap();
        let mut chain = AsyncReadWriteChain::new(&mut reader, &mut read_writer);
        let mut buf = [b'.'; 4];
        let num_read = chain.read(&mut buf).await.unwrap();
        assert_eq!(3, num_read);
        assert_eq!("abc.", escape_ascii(&buf));
    }

    /// Successive reads return the first reader's bytes, then the
    /// read-writer's bytes.
    #[tokio::test]
    async fn reads_first_then_second() {
        let mut reader = std::io::Cursor::new(b"ab");
        let mut read_writer: AsyncFixedBuf<4> = AsyncFixedBuf::new();
        read_writer.write_str("cd").unwrap();
        let mut chain = AsyncReadWriteChain::new(&mut reader, &mut read_writer);
        let mut buf = [b'.'; 4];

        let first_len = chain.read(&mut buf).await.unwrap();
        assert_eq!(2, first_len);
        assert_eq!("ab..", escape_ascii(&buf));

        let second_len = chain.read(&mut buf).await.unwrap();
        assert_eq!(2, second_len);
        assert_eq!("cd..", escape_ascii(&buf));
    }

    /// Once the first reader is exhausted, errors from the read-writer are
    /// surfaced to the caller.
    #[tokio::test]
    async fn returns_error_from_second() {
        let mut reader = std::io::Cursor::new(b"");
        let mut read_writer = FakeAsyncReadWriter::new(vec![Err(err1()), Err(err1())]);
        let mut chain = AsyncReadWriteChain::new(&mut reader, &mut read_writer);
        let mut buf = [b'.'; 4];
        let err = chain.read(&mut buf).await.unwrap_err();
        assert_eq!(std::io::ErrorKind::Other, err.kind());
        assert_eq!("err1", err.to_string());
        assert_eq!("....", escape_ascii(&buf));
        chain.read(&mut buf).await.unwrap_err();
    }

    /// Bytes written to the chain end up in the read-writer.
    #[tokio::test]
    async fn passes_writes_through() {
        let mut reader = std::io::Cursor::new(b"");
        let mut read_writer: AsyncFixedBuf<4> = AsyncFixedBuf::new();
        let mut chain = AsyncReadWriteChain::new(&mut reader, &mut read_writer);
        let num_written = chain.write(b"abc").await.unwrap();
        assert_eq!(3, num_written);
        assert_eq!("abc", read_writer.escape_ascii());
    }
}