1use bytes::Bytes;
2use pin_project_lite::pin_project;
3use std::{
4 fmt,
5 io::Cursor,
6 pin::Pin,
7 task::{Context, Poll, ready},
8};
9use tokio::io::{self, AsyncBufRead, AsyncRead, ReadBuf};
10
pin_project! {
    /// An async reader over a byte buffer that it owns on the heap (`Vec<u8>`).
    ///
    /// Reads are forwarded to an internal [`Cursor`], which tracks how far into the
    /// buffer previous reads have progressed.
    #[derive(Debug, Clone)]
    pub struct HeapReader {
        // Owned bytes plus the current read position. `Cursor<Vec<u8>>` is `Unpin`, so
        // pinning is not strictly required; projecting it as `Pin<&mut _>` simply lets
        // `poll_read` forward directly to the cursor's `AsyncRead` implementation.
        #[pin]
        inner: Cursor<Vec<u8>>,
    }
}
19
20impl HeapReader {
21 pub const fn new(data: Vec<u8>) -> Self {
23 Self {
24 inner: Cursor::new(data),
25 }
26 }
27}
28
29impl From<Vec<u8>> for HeapReader {
30 fn from(data: Vec<u8>) -> Self {
31 Self::new(data)
32 }
33}
34
35impl From<&[u8]> for HeapReader {
36 fn from(data: &[u8]) -> Self {
37 Self::new(data.to_vec())
38 }
39}
40
41impl From<&str> for HeapReader {
42 fn from(data: &str) -> Self {
43 Self::new(data.as_bytes().to_vec())
44 }
45}
46
47impl Default for HeapReader {
48 fn default() -> Self {
49 Self::new(Vec::new())
50 }
51}
52
53impl From<Bytes> for HeapReader {
54 fn from(data: Bytes) -> Self {
55 Self::new(data.to_vec())
56 }
57}
58
59impl AsyncRead for HeapReader {
60 fn poll_read(
61 self: Pin<&mut Self>,
62 cx: &mut Context<'_>,
63 buf: &mut ReadBuf<'_>,
64 ) -> Poll<io::Result<()>> {
65 self.project().inner.poll_read(cx, buf)
66 }
67}
68
pin_project! {
    /// Async reader that yields everything from `first`, then everything from `second`.
    ///
    /// `first` is treated as exhausted the first time a read or buffer fill from it
    /// produces no data. From then on, all I/O is directed to `second`.
    #[must_use = "streams do nothing unless polled"]
    // NOTE(review): confirm this crate actually defines an `io-util` feature; if it does
    // not, this docs.rs availability badge is misleading.
    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
    pub struct ChainReader<T, U> {
        // Reader that is drained first.
        #[pin]
        first: T,
        // Reader that takes over once `first` has reported EOF.
        #[pin]
        second: U,
        // Becomes `true` once `first` has reported EOF; nothing ever resets it.
        done_first: bool,
    }
}
81
82impl<T, U> ChainReader<T, U>
83where
84 T: AsyncRead,
85 U: AsyncRead,
86{
87 pub const fn new(first: T, second: U) -> Self {
89 Self {
90 first,
91 second,
92 done_first: false,
93 }
94 }
95
96 pub fn get_ref(&self) -> (&T, &U) {
98 (&self.first, &self.second)
99 }
100
101 pub fn get_mut(&mut self) -> (&mut T, &mut U) {
107 (&mut self.first, &mut self.second)
108 }
109
110 pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) {
116 let me = self.project();
117 (me.first, me.second)
118 }
119
120 pub fn into_inner(self) -> (T, U) {
122 (self.first, self.second)
123 }
124}
125
126impl<T, U> fmt::Debug for ChainReader<T, U>
127where
128 T: fmt::Debug,
129 U: fmt::Debug,
130{
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 f.debug_struct("ChainReader")
133 .field("t", &self.first)
134 .field("u", &self.second)
135 .finish()
136 }
137}
138
139impl<T, U> AsyncRead for ChainReader<T, U>
140where
141 T: AsyncRead,
142 U: AsyncRead,
143{
144 fn poll_read(
145 self: Pin<&mut Self>,
146 cx: &mut Context<'_>,
147 buf: &mut ReadBuf<'_>,
148 ) -> Poll<io::Result<()>> {
149 let me = self.project();
150
151 if !*me.done_first {
152 let rem = buf.remaining();
153 ready!(me.first.poll_read(cx, buf))?;
154 if buf.remaining() == rem {
155 *me.done_first = true;
156 } else {
157 return Poll::Ready(Ok(()));
158 }
159 }
160 me.second.poll_read(cx, buf)
161 }
162}
163
164impl<T, U> AsyncBufRead for ChainReader<T, U>
165where
166 T: AsyncBufRead,
167 U: AsyncBufRead,
168{
169 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
170 let me = self.project();
171
172 if !*me.done_first {
173 match ready!(me.first.poll_fill_buf(cx)?) {
174 [] => {
175 *me.done_first = true;
176 }
177 buf => return Poll::Ready(Ok(buf)),
178 }
179 }
180 me.second.poll_fill_buf(cx)
181 }
182
183 fn consume(self: Pin<&mut Self>, amt: usize) {
184 let me = self.project();
185 if !*me.done_first {
186 me.first.consume(amt)
187 } else {
188 me.second.consume(amt)
189 }
190 }
191}