1use std::future::Future;
2use std::io::Result;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use aliasable::AliasableMut;
7use completion_core::CompletionFuture;
8use completion_io::{
9 AsyncBufRead, AsyncBufReadWith, AsyncRead, AsyncReadWith, ReadBuf, ReadBufMut,
10};
11use futures_core::ready;
12use pin_project_lite::pin_project;
13
14use super::extend_lifetime_mut;
15
/// Reader that yields everything from `first` and then everything from `second`.
///
/// Operations are forwarded to `first` until it is observed at end-of-stream (a
/// successful read that added no bytes to a buffer that still had room, or an empty
/// `fill_buf` result); from then on they are forwarded to `second`.
#[derive(Debug)]
pub struct Chain<T, U> {
    // Reader consumed first.
    first: T,
    // Reader consumed once `first` is exhausted.
    second: U,
    // Set once `first` has been observed at end-of-stream; nothing in this file
    // ever resets it to `false` after construction.
    done_first: bool,
}
23
24impl<T, U> Chain<T, U> {
25 pub(super) fn new(first: T, second: U) -> Self {
26 Self {
27 first,
28 second,
29 done_first: false,
30 }
31 }
32
33 pub fn into_inner(self) -> (T, U) {
35 (self.first, self.second)
36 }
37
38 pub fn get_ref(&self) -> (&T, &U) {
40 (&self.first, &self.second)
41 }
42
43 pub fn get_mut(&mut self) -> (&mut T, &mut U) {
48 (&mut self.first, &mut self.second)
49 }
50}
51
52impl<'a, T: AsyncRead, U: AsyncRead + 'static> AsyncReadWith<'a> for Chain<T, U> {
53 type ReadFuture = ReadChain<'a, T, U>;
54
55 fn read(&'a mut self, buf: ReadBufMut<'a>) -> Self::ReadFuture {
56 let state = if self.done_first {
57 ReadChainState::Second {
58 fut: self.second.read(buf),
59 }
60 } else {
61 let mut buf = AliasableMut::from_unique(unsafe { buf.into_mut() });
62 ReadChainState::First {
63 fut: self
64 .first
65 .read(unsafe { extend_lifetime_mut(&mut *buf) }.as_mut()),
66 second: &mut self.second,
67 initial_filled: buf.filled().len(),
68 buf,
69 done_first: &mut self.done_first,
70 }
71 };
72 ReadChain { state }
73 }
74}
75
pin_project! {
    /// Future returned by [`AsyncReadWith::read`] for [`Chain`].
    pub struct ReadChain<'a, T: AsyncRead, U: AsyncRead>
    where
        U: 'static,
    {
        // Current stage of the read: first reader, second reader, or finished.
        #[pin]
        state: ReadChainState<'a, T, U>,
    }
}
pin_project! {
    // State machine driving `ReadChain`.
    #[project = ReadChainStateProj]
    #[project_replace = ReadChainStateProjReplace]
    enum ReadChainState<'a, T: AsyncRead, U: AsyncRead>
    where
        U: 'static,
    {
        // Reading from the first reader.
        First {
            // The first reader's read future; it holds a lifetime-extended reborrow of
            // the same `ReadBuf` that `buf` points to.
            #[pin]
            fut: <T as AsyncReadWith<'a>>::ReadFuture,
            // Borrow of the second reader, used if the first one is exhausted.
            second: &'a mut U,
            // Filled length of the buffer before the first reader ran; compared
            // against afterwards to detect whether the first reader made progress.
            initial_filled: usize,
            // Aliasable handle to the caller's buffer, reclaimed once `fut` finishes.
            buf: AliasableMut<'a, ReadBuf<'a>>,
            // The chain's `done_first` flag, set when switching to the second reader.
            done_first: &'a mut bool,
        },
        // Reading from the second reader.
        Second {
            #[pin]
            fut: <U as AsyncReadWith<'a>>::ReadFuture,
        },
        // Placeholder swapped in while moving fields out of `First`; also left in
        // place when the first reader's result completes the whole read.
        Temporary,
    }
}
108
impl<T: AsyncRead, U: AsyncRead> CompletionFuture for ReadChain<'_, T, U> {
    type Output = Result<()>;

    /// Drives the first reader; once it finishes, either completes (it read something,
    /// or the buffer had no room) or switches to the second reader with the same buffer.
    unsafe fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        if let ReadChainStateProj::First { fut, .. } = this.state.as_mut().project() {
            // An error from the first reader is returned as-is; the state stays `First`.
            ready!(fut.poll(cx))?;

            // The first future is done: swap in `Temporary` and take the unpinned
            // fields by value. The pinned `fut` is dropped by `project_replace`, which
            // is what allows `buf` to be used as a unique reference again below.
            let (second, initial_filled, buf, done_first) = match this
                .state
                .as_mut()
                .project_replace(ReadChainState::Temporary)
            {
                ReadChainStateProjReplace::First {
                    second,
                    initial_filled,
                    buf,
                    done_first,
                    ..
                } => (second, initial_filled, buf, done_first),
                _ => unreachable!(),
            };
            let buf = AliasableMut::into_unique(buf).as_mut();

            // Progress was made, or there was no room to make any: this is not
            // end-of-stream for the first reader, so finish without touching `second`.
            if buf.filled().len() > initial_filled || buf.capacity() - initial_filled == 0 {
                return Poll::Ready(Ok(()));
            }

            // The first reader is exhausted: remember that for future reads and retry
            // this same buffer on the second reader.
            *done_first = true;
            this.state.set(ReadChainState::Second {
                fut: second.read(buf),
            });
        }
        match this.state.project() {
            ReadChainStateProj::Second { fut } => fut.poll(cx),
            // Normally only reachable after this future has already returned `Ready`.
            ReadChainStateProj::Temporary => panic!("polled after completion"),
            // `First` either returned early above or was replaced by `Second`.
            _ => unreachable!(),
        }
    }
    /// Forwards cancellation to whichever reader's future is active; in the
    /// `Temporary` state there is nothing to cancel.
    unsafe fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        match self.project().state.project() {
            ReadChainStateProj::First { fut, .. } => fut.poll_cancel(cx),
            ReadChainStateProj::Second { fut } => fut.poll_cancel(cx),
            _ => Poll::Ready(()),
        }
    }
}
157
158impl<'a, T: AsyncRead, U: AsyncRead> Future for ReadChain<'_, T, U>
159where
160 <T as AsyncReadWith<'a>>::ReadFuture: Future<Output = Result<()>>,
161 <U as AsyncReadWith<'a>>::ReadFuture: Future<Output = Result<()>>,
162{
163 type Output = Result<()>;
164
165 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
166 unsafe { CompletionFuture::poll(self, cx) }
167 }
168}
169
170impl<'a, T: AsyncBufRead, U: AsyncBufRead + 'static> AsyncBufReadWith<'a> for Chain<T, U> {
171 type FillBufFuture = FillBufChain<'a, T, U>;
172
173 fn fill_buf(&'a mut self) -> Self::FillBufFuture {
174 FillBufChain {
175 state: if self.done_first {
176 FillBufChainState::Second {
177 fut: self.second.fill_buf(),
178 }
179 } else {
180 FillBufChainState::First {
181 fut: self.first.fill_buf(),
182 done_first: &mut self.done_first,
183 second: &mut self.second,
184 }
185 },
186 }
187 }
188 fn consume(&mut self, amt: usize) {
189 if self.done_first {
190 self.second.consume(amt);
191 } else {
192 self.first.consume(amt);
193 }
194 }
195}
196
pin_project! {
    /// Future returned by [`AsyncBufReadWith::fill_buf`] for [`Chain`].
    pub struct FillBufChain<'a, T: AsyncBufRead, U: AsyncBufRead>
    where
        U: 'static,
    {
        // Current stage: filling from the first reader or from the second.
        #[pin]
        state: FillBufChainState<'a, T, U>,
    }
}
206}
pin_project! {
    // State machine driving `FillBufChain`.
    #[project = FillBufChainStateProj]
    #[project_replace = FillBufChainStateProjReplace]
    enum FillBufChainState<'a, T: AsyncBufRead, U: AsyncBufRead>
    where
        U: 'static,
    {
        // Filling from the first reader.
        First {
            #[pin]
            fut: <T as AsyncBufReadWith<'a>>::FillBufFuture,
            // The chain's `done_first` flag, set when switching to the second reader.
            done_first: &'a mut bool,
            // Borrow of the second reader, used if the first yields an empty buffer.
            second: &'a mut U,
        },
        // Filling from the second reader.
        Second {
            #[pin]
            fut: <U as AsyncBufReadWith<'a>>::FillBufFuture,
        },
        // Placeholder swapped in while moving fields out of `First`.
        Temporary,
    }
}
227
impl<'a, T: AsyncBufRead, U: AsyncBufRead> CompletionFuture for FillBufChain<'a, T, U> {
    type Output = Result<&'a [u8]>;

    /// Drives the first reader's `fill_buf`; if it yields an empty buffer, marks the
    /// first reader as exhausted and continues with the second reader's `fill_buf`.
    unsafe fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        if let FillBufChainStateProj::First { fut, .. } = this.state.as_mut().project() {
            // An error from the first reader is returned unchanged.
            let buf = ready!(fut.poll(cx))?;

            // Non-empty data from the first reader is the result; the state is left
            // as `First` (this future has completed and is not expected to be polled
            // again).
            if !buf.is_empty() {
                return Poll::Ready(Ok(buf));
            }

            // Empty buffer means the first reader is exhausted. Swap in `Temporary`
            // and take the unpinned fields by value; the finished first future is
            // dropped by `project_replace`.
            let (done_first, second) = match this
                .state
                .as_mut()
                .project_replace(FillBufChainState::Temporary)
            {
                FillBufChainStateProjReplace::First {
                    done_first, second, ..
                } => (done_first, second),
                _ => unreachable!(),
            };

            // Remember the switch for later `fill_buf`/`consume` calls on the chain.
            *done_first = true;
            this.state.set(FillBufChainState::Second {
                fut: second.fill_buf(),
            });
        }
        match this.state.project() {
            FillBufChainStateProj::Second { fut } => fut.poll(cx),
            // `First` returned or was replaced above; `Temporary` is normally only a
            // transient state inside the block above.
            _ => unreachable!(),
        }
    }
    /// Forwards cancellation to whichever reader's future is active; in the
    /// `Temporary` state there is nothing to cancel.
    unsafe fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        match self.project().state.project() {
            FillBufChainStateProj::First { fut, .. } => fut.poll_cancel(cx),
            FillBufChainStateProj::Second { fut } => fut.poll_cancel(cx),
            _ => Poll::Ready(()),
        }
    }
}
270impl<'a, T: AsyncBufRead, U: AsyncBufRead> Future for FillBufChain<'a, T, U>
271where
272 <T as AsyncBufReadWith<'a>>::FillBufFuture: Future<Output = Result<&'a [u8]>>,
273 <U as AsyncBufReadWith<'a>>::FillBufFuture: Future<Output = Result<&'a [u8]>>,
274{
275 type Output = Result<&'a [u8]>;
276 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
277 unsafe { CompletionFuture::poll(self, cx) }
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    use std::io::{Error, ErrorKind};
    use std::mem::MaybeUninit;

    use crate::future::block_on;

    use super::super::{test_utils::YieldingReader, AsyncReadExt};

    /// `read` drains the first reader, then forwards to the second, passing errors through.
    #[test]
    fn read() {
        let first = YieldingReader::new(vec![Ok(&[1, 2, 3][..]), Ok(&[4])]);
        let second = YieldingReader::new(vec![
            Err(Error::new(ErrorKind::Other, "Some error")),
            Ok(&[5, 6, 7][..]),
        ]);

        let mut storage = [MaybeUninit::uninit(); 20];
        let mut buf = ReadBuf::uninit(&mut storage);

        let mut chain = first.chain(second);

        block_on(chain.read(buf.as_mut())).unwrap();
        assert_eq!(buf.as_mut().filled(), [1, 2, 3]);

        block_on(chain.read(buf.as_mut())).unwrap();
        assert_eq!(buf.as_mut().filled(), [1, 2, 3, 4]);

        assert_eq!(
            block_on(chain.read(buf.as_mut())).unwrap_err().to_string(),
            "Some error"
        );
        assert_eq!(buf.as_mut().filled(), [1, 2, 3, 4]);

        block_on(chain.read(buf.as_mut())).unwrap();
        assert_eq!(buf.as_mut().filled(), [1, 2, 3, 4, 5, 6, 7]);
    }

    /// A read into a buffer with no spare capacity must not be mistaken for the first
    /// reader reaching end-of-stream (the `capacity - initial_filled == 0` branch).
    #[test]
    fn read_into_full_buffer_stays_on_first() {
        let first = YieldingReader::new(vec![Ok(&[1, 2][..])]);
        let second = YieldingReader::new(vec![Ok(&[3][..])]);

        let mut storage = [MaybeUninit::uninit(); 2];
        let mut buf = ReadBuf::uninit(&mut storage);

        let mut chain = first.chain(second);

        block_on(chain.read(buf.as_mut())).unwrap();
        assert_eq!(buf.as_mut().filled(), [1, 2]);

        // The buffer is now full, so the first reader cannot make progress; the chain
        // must complete without switching over to the second reader.
        block_on(chain.read(buf.as_mut())).unwrap();
        assert_eq!(buf.as_mut().filled(), [1, 2]);
        assert!(!chain.done_first);
    }

    /// `fill_buf`/`consume` walk through the first reader, then the second, passing
    /// errors through and ending with an empty buffer.
    #[test]
    fn buf_read() {
        let first = YieldingReader::new(vec![Ok(&[1, 2, 3][..]), Ok(&[4])]);
        let second = YieldingReader::new(vec![
            Err(Error::new(ErrorKind::Other, "Some error")),
            Ok(&[5, 6, 7][..]),
        ]);

        let mut chain = first.chain(second);

        assert_eq!(block_on(chain.fill_buf()).unwrap(), [1, 2, 3]);
        assert_eq!(block_on(chain.fill_buf()).unwrap(), [1, 2, 3]);

        chain.consume(2);
        assert_eq!(block_on(chain.fill_buf()).unwrap(), [3]);
        assert_eq!(block_on(chain.fill_buf()).unwrap(), [3]);

        chain.consume(1);
        assert_eq!(block_on(chain.fill_buf()).unwrap(), [4]);
        assert_eq!(block_on(chain.fill_buf()).unwrap(), [4]);

        chain.consume(1);
        assert_eq!(
            block_on(chain.fill_buf()).unwrap_err().to_string(),
            "Some error"
        );
        assert_eq!(block_on(chain.fill_buf()).unwrap(), [5, 6, 7]);
        assert_eq!(block_on(chain.fill_buf()).unwrap(), [5, 6, 7]);

        chain.consume(3);
        assert_eq!(block_on(chain.fill_buf()).unwrap(), []);
        assert_eq!(block_on(chain.fill_buf()).unwrap(), []);
    }
}