completion/io/read/
chain.rs

1use std::future::Future;
2use std::io::Result;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use aliasable::AliasableMut;
7use completion_core::CompletionFuture;
8use completion_io::{
9    AsyncBufRead, AsyncBufReadWith, AsyncRead, AsyncReadWith, ReadBuf, ReadBufMut,
10};
11use futures_core::ready;
12use pin_project_lite::pin_project;
13
14use super::extend_lifetime_mut;
15
/// Reader for [`AsyncReadExt::chain`](super::AsyncReadExt::chain).
///
/// Holds two readers together with a flag recording whether reading has moved on from the first
/// reader to the second one.
#[derive(Debug)]
pub struct Chain<T, U> {
    // Reader that is used for as long as `done_first` is `false`.
    first: T,
    // Reader that is used once `done_first` has become `true`.
    second: U,
    // Starts out `false`; set to `true` by the read / fill_buf futures of this chain when the
    // first reader produces no more data.
    done_first: bool,
}
23
24impl<T, U> Chain<T, U> {
25    pub(super) fn new(first: T, second: U) -> Self {
26        Self {
27            first,
28            second,
29            done_first: false,
30        }
31    }
32
33    /// Consume the chain, returning the wrapped readers.
34    pub fn into_inner(self) -> (T, U) {
35        (self.first, self.second)
36    }
37
38    /// Get shared references to the underlying readers in the chain.
39    pub fn get_ref(&self) -> (&T, &U) {
40        (&self.first, &self.second)
41    }
42
43    /// Get mutable references to the underlying readers in the chain.
44    ///
45    /// Care should be taken to avoid modifying the internal I/O state of the underlying readers as
46    /// doing so may corrupt the internal state of this chain.
47    pub fn get_mut(&mut self) -> (&mut T, &mut U) {
48        (&mut self.first, &mut self.second)
49    }
50}
51
52impl<'a, T: AsyncRead, U: AsyncRead + 'static> AsyncReadWith<'a> for Chain<T, U> {
53    type ReadFuture = ReadChain<'a, T, U>;
54
55    fn read(&'a mut self, buf: ReadBufMut<'a>) -> Self::ReadFuture {
56        let state = if self.done_first {
57            ReadChainState::Second {
58                fut: self.second.read(buf),
59            }
60        } else {
61            let mut buf = AliasableMut::from_unique(unsafe { buf.into_mut() });
62            ReadChainState::First {
63                fut: self
64                    .first
65                    .read(unsafe { extend_lifetime_mut(&mut *buf) }.as_mut()),
66                second: &mut self.second,
67                initial_filled: buf.filled().len(),
68                buf,
69                done_first: &mut self.done_first,
70            }
71        };
72        ReadChain { state }
73    }
74}
75
76pin_project! {
77    /// Future for [`read`](AsyncReadWith::read) on a [`Chain`].
78    pub struct ReadChain<'a, T: AsyncRead, U: AsyncRead>
79    where
80        U: 'static,
81    {
82        #[pin]
83        state: ReadChainState<'a, T, U>,
84    }
85}
86pin_project! {
87    #[project = ReadChainStateProj]
88    #[project_replace = ReadChainStateProjReplace]
89    enum ReadChainState<'a, T: AsyncRead, U: AsyncRead>
90    where
91        U: 'static,
92    {
93        First {
94            #[pin]
95            fut: <T as AsyncReadWith<'a>>::ReadFuture,
96            second: &'a mut U,
97            initial_filled: usize,
98            buf: AliasableMut<'a, ReadBuf<'a>>,
99            done_first: &'a mut bool,
100        },
101        Second {
102            #[pin]
103            fut: <U as AsyncReadWith<'a>>::ReadFuture,
104        },
105        Temporary,
106    }
107}
108
109impl<T: AsyncRead, U: AsyncRead> CompletionFuture for ReadChain<'_, T, U> {
110    type Output = Result<()>;
111
112    unsafe fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
113        let mut this = self.project();
114
115        if let ReadChainStateProj::First { fut, .. } = this.state.as_mut().project() {
116            ready!(fut.poll(cx))?;
117
118            let (second, initial_filled, buf, done_first) = match this
119                .state
120                .as_mut()
121                .project_replace(ReadChainState::Temporary)
122            {
123                ReadChainStateProjReplace::First {
124                    second,
125                    initial_filled,
126                    buf,
127                    done_first,
128                    ..
129                } => (second, initial_filled, buf, done_first),
130                _ => unreachable!(),
131            };
132            let buf = AliasableMut::into_unique(buf).as_mut();
133
134            if buf.filled().len() > initial_filled || buf.capacity() - initial_filled == 0 {
135                return Poll::Ready(Ok(()));
136            }
137
138            *done_first = true;
139            this.state.set(ReadChainState::Second {
140                fut: second.read(buf),
141            });
142        }
143        match this.state.project() {
144            ReadChainStateProj::Second { fut } => fut.poll(cx),
145            ReadChainStateProj::Temporary => panic!("polled after completion"),
146            _ => unreachable!(),
147        }
148    }
149    unsafe fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
150        match self.project().state.project() {
151            ReadChainStateProj::First { fut, .. } => fut.poll_cancel(cx),
152            ReadChainStateProj::Second { fut } => fut.poll_cancel(cx),
153            _ => Poll::Ready(()),
154        }
155    }
156}
157
158impl<'a, T: AsyncRead, U: AsyncRead> Future for ReadChain<'_, T, U>
159where
160    <T as AsyncReadWith<'a>>::ReadFuture: Future<Output = Result<()>>,
161    <U as AsyncReadWith<'a>>::ReadFuture: Future<Output = Result<()>>,
162{
163    type Output = Result<()>;
164
165    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
166        unsafe { CompletionFuture::poll(self, cx) }
167    }
168}
169
170impl<'a, T: AsyncBufRead, U: AsyncBufRead + 'static> AsyncBufReadWith<'a> for Chain<T, U> {
171    type FillBufFuture = FillBufChain<'a, T, U>;
172
173    fn fill_buf(&'a mut self) -> Self::FillBufFuture {
174        FillBufChain {
175            state: if self.done_first {
176                FillBufChainState::Second {
177                    fut: self.second.fill_buf(),
178                }
179            } else {
180                FillBufChainState::First {
181                    fut: self.first.fill_buf(),
182                    done_first: &mut self.done_first,
183                    second: &mut self.second,
184                }
185            },
186        }
187    }
188    fn consume(&mut self, amt: usize) {
189        if self.done_first {
190            self.second.consume(amt);
191        } else {
192            self.first.consume(amt);
193        }
194    }
195}
196
197pin_project! {
198    /// Future for [`fill_buf`](AsyncBufReadWith::fill_buf) on a [`Chain`].
199    pub struct FillBufChain<'a, T: AsyncBufRead, U: AsyncBufRead>
200    where
201        U: 'static,
202    {
203        #[pin]
204        state: FillBufChainState<'a, T, U>,
205    }
206}
207pin_project! {
208    #[project = FillBufChainStateProj]
209    #[project_replace = FillBufChainStateProjReplace]
210    enum FillBufChainState<'a, T: AsyncBufRead, U: AsyncBufRead>
211    where
212        U: 'static,
213    {
214        First {
215            #[pin]
216            fut: <T as AsyncBufReadWith<'a>>::FillBufFuture,
217            done_first: &'a mut bool,
218            second: &'a mut U,
219        },
220        Second {
221            #[pin]
222            fut: <U as AsyncBufReadWith<'a>>::FillBufFuture,
223        },
224        Temporary,
225    }
226}
227
228impl<'a, T: AsyncBufRead, U: AsyncBufRead> CompletionFuture for FillBufChain<'a, T, U> {
229    type Output = Result<&'a [u8]>;
230
231    unsafe fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
232        let mut this = self.project();
233
234        if let FillBufChainStateProj::First { fut, .. } = this.state.as_mut().project() {
235            let buf = ready!(fut.poll(cx))?;
236
237            if !buf.is_empty() {
238                return Poll::Ready(Ok(buf));
239            }
240
241            let (done_first, second) = match this
242                .state
243                .as_mut()
244                .project_replace(FillBufChainState::Temporary)
245            {
246                FillBufChainStateProjReplace::First {
247                    done_first, second, ..
248                } => (done_first, second),
249                _ => unreachable!(),
250            };
251
252            *done_first = true;
253            this.state.set(FillBufChainState::Second {
254                fut: second.fill_buf(),
255            });
256        }
257        match this.state.project() {
258            FillBufChainStateProj::Second { fut } => fut.poll(cx),
259            _ => unreachable!(),
260        }
261    }
262    unsafe fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
263        match self.project().state.project() {
264            FillBufChainStateProj::First { fut, .. } => fut.poll_cancel(cx),
265            FillBufChainStateProj::Second { fut } => fut.poll_cancel(cx),
266            _ => Poll::Ready(()),
267        }
268    }
269}
270impl<'a, T: AsyncBufRead, U: AsyncBufRead> Future for FillBufChain<'a, T, U>
271where
272    <T as AsyncBufReadWith<'a>>::FillBufFuture: Future<Output = Result<&'a [u8]>>,
273    <U as AsyncBufReadWith<'a>>::FillBufFuture: Future<Output = Result<&'a [u8]>>,
274{
275    type Output = Result<&'a [u8]>;
276    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
277        unsafe { CompletionFuture::poll(self, cx) }
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    use std::io::{Error, ErrorKind};
    use std::mem::MaybeUninit;

    use crate::future::block_on;

    use super::super::{test_utils::YieldingReader, AsyncReadExt};

    // Reads through a chain of two readers into one growing buffer. The second reader fails once
    // before producing data.
    #[test]
    fn read() {
        let head = YieldingReader::new(vec![Ok(&[1, 2, 3][..]), Ok(&[4])]);
        let tail = YieldingReader::new(vec![
            Err(Error::new(ErrorKind::Other, "Some error")),
            Ok(&[5, 6, 7][..]),
        ]);
        let mut chain = head.chain(tail);

        let mut storage = [MaybeUninit::uninit(); 20];
        let mut buf = ReadBuf::uninit(&mut storage);

        // Both items come from the first reader.
        block_on(chain.read(buf.as_mut())).unwrap();
        assert_eq!(buf.as_mut().filled(), [1, 2, 3]);

        block_on(chain.read(buf.as_mut())).unwrap();
        assert_eq!(buf.as_mut().filled(), [1, 2, 3, 4]);

        // The second reader's error surfaces and leaves the buffer as it was.
        let failure = block_on(chain.read(buf.as_mut())).unwrap_err();
        assert_eq!(failure.to_string(), "Some error");
        assert_eq!(buf.as_mut().filled(), [1, 2, 3, 4]);

        // The following read gets the second reader's data.
        block_on(chain.read(buf.as_mut())).unwrap();
        assert_eq!(buf.as_mut().filled(), [1, 2, 3, 4, 5, 6, 7]);
    }

    // Same pair of readers through the buffered interface. Every `fill_buf` is issued twice to
    // check that the result is unchanged until `consume` is called.
    #[test]
    fn buf_read() {
        let head = YieldingReader::new(vec![Ok(&[1, 2, 3][..]), Ok(&[4])]);
        let tail = YieldingReader::new(vec![
            Err(Error::new(ErrorKind::Other, "Some error")),
            Ok(&[5, 6, 7][..]),
        ]);
        let mut chain = head.chain(tail);

        for _ in 0..2 {
            assert_eq!(block_on(chain.fill_buf()).unwrap(), [1, 2, 3]);
        }

        chain.consume(2);
        for _ in 0..2 {
            assert_eq!(block_on(chain.fill_buf()).unwrap(), [3]);
        }

        chain.consume(1);
        for _ in 0..2 {
            assert_eq!(block_on(chain.fill_buf()).unwrap(), [4]);
        }

        chain.consume(1);
        // The second reader's error is reported once, then its data follows.
        let failure = block_on(chain.fill_buf()).unwrap_err();
        assert_eq!(failure.to_string(), "Some error");
        for _ in 0..2 {
            assert_eq!(block_on(chain.fill_buf()).unwrap(), [5, 6, 7]);
        }

        chain.consume(3);
        for _ in 0..2 {
            assert_eq!(block_on(chain.fill_buf()).unwrap(), []);
        }
    }
}