http_proxy_client_async/
prepend_io_stream.rs

1use futures_io::{AsyncRead, AsyncWrite, IoSlice, IoSliceMut};
2use futures_util::io::{AsyncReadExt, Chain, Cursor};
3use std::io::Result;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7#[derive(Debug)]
8pub enum PrependIoStream<T>
9where
10    T: AsyncRead + AsyncWrite + Unpin,
11{
12    Chain(Chain<Cursor<Vec<u8>>, T>),
13    Plain(T),
14}
15
16impl<T> PrependIoStream<T>
17where
18    T: AsyncRead + AsyncWrite + Unpin,
19{
20    pub fn from_vec(stream: T, read_prepend: Option<Vec<u8>>) -> Self {
21        let read_prepend = match read_prepend {
22            None => None,
23            Some(ref boxed_buf) if boxed_buf.is_empty() => None,
24            Some(boxed_buf) => Some(boxed_buf),
25        };
26        match read_prepend {
27            Some(read_prepend) => Self::from_cursor(stream, Cursor::new(read_prepend)),
28            None => Self::plain(stream),
29        }
30    }
31
32    pub fn from_cursor(stream: T, read_prepend: Cursor<Vec<u8>>) -> Self {
33        Self::chain(read_prepend.chain(stream))
34    }
35
36    pub fn chain(chain: Chain<Cursor<Vec<u8>>, T>) -> Self {
37        PrependIoStream::Chain(chain)
38    }
39
40    pub fn plain(stream: T) -> Self {
41        PrependIoStream::Plain(stream)
42    }
43
44    pub fn into_inner(self) -> (T, Option<Cursor<Vec<u8>>>) {
45        match self {
46            PrependIoStream::Chain(chain) => {
47                let (cursor, stream) = chain.into_inner();
48                (stream, Some(cursor))
49            }
50            PrependIoStream::Plain(stream) => (stream, None),
51        }
52    }
53
54    pub fn pending_prepend_data(&self) -> &[u8] {
55        match self {
56            PrependIoStream::Chain(chain) => {
57                let (cursor, _) = chain.get_ref();
58                let pos = cursor.position() as usize;
59                let vec = cursor.get_ref();
60                &vec[pos..]
61            }
62            PrependIoStream::Plain(_) => &[],
63        }
64    }
65}
66
67impl<T> AsyncRead for PrependIoStream<T>
68where
69    T: AsyncRead + AsyncWrite + Unpin,
70{
71    fn poll_read(
72        self: Pin<&mut Self>,
73        cx: &mut Context<'_>,
74        buf: &mut [u8],
75    ) -> Poll<Result<usize>> {
76        match self.get_mut() {
77            PrependIoStream::Plain(ref mut stream) => {
78                AsyncRead::poll_read(Pin::new(stream), cx, buf)
79            }
80            PrependIoStream::Chain(ref mut chain) => AsyncRead::poll_read(Pin::new(chain), cx, buf),
81        }
82    }
83
84    fn poll_read_vectored(
85        self: Pin<&mut Self>,
86        cx: &mut Context<'_>,
87        bufs: &mut [IoSliceMut<'_>],
88    ) -> Poll<Result<usize>> {
89        match self.get_mut() {
90            PrependIoStream::Plain(ref mut stream) => {
91                AsyncRead::poll_read_vectored(Pin::new(stream), cx, bufs)
92            }
93            PrependIoStream::Chain(ref mut chain) => {
94                AsyncRead::poll_read_vectored(Pin::new(chain), cx, bufs)
95            }
96        }
97    }
98}
99
100impl<T> AsyncWrite for PrependIoStream<T>
101where
102    T: AsyncRead + AsyncWrite + Unpin,
103{
104    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
105        match self.get_mut() {
106            PrependIoStream::Plain(ref mut stream) => {
107                AsyncWrite::poll_write(Pin::new(stream), cx, buf)
108            }
109            PrependIoStream::Chain(chain) => {
110                let (_, stream) = chain.get_mut();
111                AsyncWrite::poll_write(Pin::new(stream), cx, buf)
112            }
113        }
114    }
115
116    fn poll_write_vectored(
117        self: Pin<&mut Self>,
118        cx: &mut Context<'_>,
119        bufs: &[IoSlice<'_>],
120    ) -> Poll<Result<usize>> {
121        match self.get_mut() {
122            PrependIoStream::Plain(ref mut stream) => {
123                AsyncWrite::poll_write_vectored(Pin::new(stream), cx, bufs)
124            }
125            PrependIoStream::Chain(chain) => {
126                let (_, stream) = chain.get_mut();
127                AsyncWrite::poll_write_vectored(Pin::new(stream), cx, bufs)
128            }
129        }
130    }
131
132    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
133        match self.get_mut() {
134            PrependIoStream::Plain(ref mut stream) => AsyncWrite::poll_flush(Pin::new(stream), cx),
135            PrependIoStream::Chain(chain) => {
136                let (_, stream) = chain.get_mut();
137                AsyncWrite::poll_flush(Pin::new(stream), cx)
138            }
139        }
140    }
141
142    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
143        match self.get_mut() {
144            PrependIoStream::Plain(ref mut stream) => AsyncWrite::poll_close(Pin::new(stream), cx),
145            PrependIoStream::Chain(chain) => {
146                let (_, stream) = chain.get_mut();
147                AsyncWrite::poll_close(Pin::new(stream), cx)
148            }
149        }
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor;
    use merge_io::MergeIO;

    /// Builds the stream shared by the tests: a reader yielding `1, 2, 3, 4`
    /// merged with a scratch writer, with `50, 60, 70, 80` prepended.
    fn prepended_stream() -> PrependIoStream<impl AsyncRead + AsyncWrite + Unpin> {
        let inner = MergeIO::new(Cursor::new(vec![1, 2, 3, 4]), Cursor::new(vec![0u8; 1024]));
        PrependIoStream::from_vec(inner, Some(vec![50, 60, 70, 80]))
    }

    #[test]
    fn simple_prepended_read_test() -> Result<()> {
        executor::block_on(async {
            let mut stream = prepended_stream();

            // Reading everything must yield the prepended bytes first.
            let mut collected = Vec::new();
            stream.read_to_end(&mut collected).await?;
            assert_eq!(collected.as_slice(), &[50, 60, 70, 80, 1, 2, 3, 4]);

            Ok(())
        })
    }

    #[test]
    fn small_buffer_prepended_read_test() -> Result<()> {
        executor::block_on(async {
            let mut stream = prepended_stream();

            // Each step is (read buffer capacity, bytes expected from that read).
            let steps: [(usize, &[u8]); 3] = [
                // A buffer smaller than the prepend data gets only a part of it.
                (2, &[50u8, 60][..]),
                // A large buffer still stops at the end of the prepend data.
                (1024, &[70u8, 80][..]),
                // After that, data comes from the wrapped stream as usual.
                (1024, &[1u8, 2, 3, 4][..]),
            ];

            for (capacity, expected) in steps.iter() {
                let mut buf = vec![0u8; *capacity];
                let n = stream.read(&mut buf).await?;
                assert_eq!(n, expected.len());
                assert_eq!(&buf[..n], *expected);
            }

            Ok(())
        })
    }
}