1use futures_io::{AsyncRead, AsyncWrite, IoSlice, IoSliceMut};
2use futures_util::io::{AsyncReadExt, Chain, Cursor};
3use std::io::Result;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7#[derive(Debug)]
8pub enum PrependIoStream<T>
9where
10 T: AsyncRead + AsyncWrite + Unpin,
11{
12 Chain(Chain<Cursor<Vec<u8>>, T>),
13 Plain(T),
14}
15
16impl<T> PrependIoStream<T>
17where
18 T: AsyncRead + AsyncWrite + Unpin,
19{
20 pub fn from_vec(stream: T, read_prepend: Option<Vec<u8>>) -> Self {
21 let read_prepend = match read_prepend {
22 None => None,
23 Some(ref boxed_buf) if boxed_buf.is_empty() => None,
24 Some(boxed_buf) => Some(boxed_buf),
25 };
26 match read_prepend {
27 Some(read_prepend) => Self::from_cursor(stream, Cursor::new(read_prepend)),
28 None => Self::plain(stream),
29 }
30 }
31
32 pub fn from_cursor(stream: T, read_prepend: Cursor<Vec<u8>>) -> Self {
33 Self::chain(read_prepend.chain(stream))
34 }
35
36 pub fn chain(chain: Chain<Cursor<Vec<u8>>, T>) -> Self {
37 PrependIoStream::Chain(chain)
38 }
39
40 pub fn plain(stream: T) -> Self {
41 PrependIoStream::Plain(stream)
42 }
43
44 pub fn into_inner(self) -> (T, Option<Cursor<Vec<u8>>>) {
45 match self {
46 PrependIoStream::Chain(chain) => {
47 let (cursor, stream) = chain.into_inner();
48 (stream, Some(cursor))
49 }
50 PrependIoStream::Plain(stream) => (stream, None),
51 }
52 }
53
54 pub fn pending_prepend_data(&self) -> &[u8] {
55 match self {
56 PrependIoStream::Chain(chain) => {
57 let (cursor, _) = chain.get_ref();
58 let pos = cursor.position() as usize;
59 let vec = cursor.get_ref();
60 &vec[pos..]
61 }
62 PrependIoStream::Plain(_) => &[],
63 }
64 }
65}
66
67impl<T> AsyncRead for PrependIoStream<T>
68where
69 T: AsyncRead + AsyncWrite + Unpin,
70{
71 fn poll_read(
72 self: Pin<&mut Self>,
73 cx: &mut Context<'_>,
74 buf: &mut [u8],
75 ) -> Poll<Result<usize>> {
76 match self.get_mut() {
77 PrependIoStream::Plain(ref mut stream) => {
78 AsyncRead::poll_read(Pin::new(stream), cx, buf)
79 }
80 PrependIoStream::Chain(ref mut chain) => AsyncRead::poll_read(Pin::new(chain), cx, buf),
81 }
82 }
83
84 fn poll_read_vectored(
85 self: Pin<&mut Self>,
86 cx: &mut Context<'_>,
87 bufs: &mut [IoSliceMut<'_>],
88 ) -> Poll<Result<usize>> {
89 match self.get_mut() {
90 PrependIoStream::Plain(ref mut stream) => {
91 AsyncRead::poll_read_vectored(Pin::new(stream), cx, bufs)
92 }
93 PrependIoStream::Chain(ref mut chain) => {
94 AsyncRead::poll_read_vectored(Pin::new(chain), cx, bufs)
95 }
96 }
97 }
98}
99
100impl<T> AsyncWrite for PrependIoStream<T>
101where
102 T: AsyncRead + AsyncWrite + Unpin,
103{
104 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
105 match self.get_mut() {
106 PrependIoStream::Plain(ref mut stream) => {
107 AsyncWrite::poll_write(Pin::new(stream), cx, buf)
108 }
109 PrependIoStream::Chain(chain) => {
110 let (_, stream) = chain.get_mut();
111 AsyncWrite::poll_write(Pin::new(stream), cx, buf)
112 }
113 }
114 }
115
116 fn poll_write_vectored(
117 self: Pin<&mut Self>,
118 cx: &mut Context<'_>,
119 bufs: &[IoSlice<'_>],
120 ) -> Poll<Result<usize>> {
121 match self.get_mut() {
122 PrependIoStream::Plain(ref mut stream) => {
123 AsyncWrite::poll_write_vectored(Pin::new(stream), cx, bufs)
124 }
125 PrependIoStream::Chain(chain) => {
126 let (_, stream) = chain.get_mut();
127 AsyncWrite::poll_write_vectored(Pin::new(stream), cx, bufs)
128 }
129 }
130 }
131
132 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
133 match self.get_mut() {
134 PrependIoStream::Plain(ref mut stream) => AsyncWrite::poll_flush(Pin::new(stream), cx),
135 PrependIoStream::Chain(chain) => {
136 let (_, stream) = chain.get_mut();
137 AsyncWrite::poll_flush(Pin::new(stream), cx)
138 }
139 }
140 }
141
142 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
143 match self.get_mut() {
144 PrependIoStream::Plain(ref mut stream) => AsyncWrite::poll_close(Pin::new(stream), cx),
145 PrependIoStream::Chain(chain) => {
146 let (_, stream) = chain.get_mut();
147 AsyncWrite::poll_close(Pin::new(stream), cx)
148 }
149 }
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor;
    use merge_io::MergeIO;

    /// Builds the fixture shared by the tests below: a stream whose reader
    /// yields `[1, 2, 3, 4]`, with `[50, 60, 70, 80]` prepended.
    fn prepended_fixture() -> PrependIoStream<impl AsyncRead + AsyncWrite + Unpin> {
        let reader = Cursor::new(vec![1, 2, 3, 4]);
        let writer = Cursor::new(vec![0u8; 1024]);
        PrependIoStream::from_vec(MergeIO::new(reader, writer), Some(vec![50, 60, 70, 80]))
    }

    /// Reading to the end yields the prepended bytes followed by the
    /// reader's bytes.
    #[test]
    fn simple_prepended_read_test() -> Result<()> {
        executor::block_on(async {
            let mut stream = prepended_fixture();

            let mut collected = Vec::new();
            stream.read_to_end(&mut collected).await?;

            assert_eq!(collected.as_slice(), &[50, 60, 70, 80, 1, 2, 3, 4]);

            Ok(())
        })
    }

    /// A read into a 2-byte buffer takes only the first two prepended bytes;
    /// the next read returns the remaining prepended bytes, and the one after
    /// that returns the reader's bytes.
    #[test]
    fn small_buffer_prepended_read_test() -> Result<()> {
        executor::block_on(async {
            let mut stream = prepended_fixture();

            let mut small = [0u8; 2];
            let n = stream.read(&mut small).await?;
            assert_eq!(n, 2);
            assert_eq!(&small[..n], &[50, 60]);

            // Each subsequent read gets a fresh, roomy buffer.
            let expected_reads: [&[u8]; 2] = [&[70, 80], &[1, 2, 3, 4]];
            for expected in expected_reads.iter() {
                let mut large = [0u8; 1024];
                let n = stream.read(&mut large).await?;
                assert_eq!(n, expected.len());
                assert_eq!(&large[..n], *expected);
            }

            Ok(())
        })
    }
}