ctf_pwn/io/util/timeout/
read_timeout.rs

1use crate::io::util::timeout::{eof, get_deadline, timeout};
2use pin_project_lite::pin_project;
3use std::future::Future;
4use std::io;
5use std::io::ErrorKind;
6use std::marker::PhantomPinned;
7use std::marker::Unpin;
8use std::pin::Pin;
9use std::task::{ready, Context, Poll};
10use std::time::Duration;
11use tokio::io::{AsyncRead, ReadBuf};
12use tokio::time::Instant;
13
14/// A future which can be used to easily read bytes until timeout or buf is fully filled
15pub(crate) fn read_timeout<'a, A>(
16    reader: &'a mut A,
17    buf: &'a mut [u8],
18    timeout: Duration,
19) -> ReadTimeout<'a, A>
20where
21    A: AsyncRead + Unpin + ?Sized,
22{
23    let deadline = get_deadline(timeout);
24    ReadTimeout {
25        reader,
26        buf: ReadBuf::new(buf),
27        deadline,
28        _pin: PhantomPinned,
29    }
30}
31
32pin_project! {
33    /// Creates a future which will read exactly enough bytes to fill `buf`,
34    /// stops if Timeout,
35    /// returning an error if EOF is hit sooner.
36    ///
37    /// On success the number of bytes is returned
38    #[derive(Debug)]
39    #[must_use = "futures do nothing unless you `.await` or poll them"]
40    pub struct ReadTimeout<'a, A: ?Sized> {
41        reader: &'a mut A,
42        buf: ReadBuf<'a>,
43        deadline: Instant,
44        // Make this future `!Unpin` for compatibility with async trait methods.
45        #[pin]
46        _pin: PhantomPinned,
47    }
48}
49
50impl<A> Future for ReadTimeout<'_, A>
51where
52    A: AsyncRead + Unpin + ?Sized,
53{
54    type Output = io::Result<usize>;
55
56    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
57        let me = self.project();
58
59        loop {
60            if *me.deadline < Instant::now() {
61                return Poll::Ready(Err(timeout()));
62            }
63
64            let old_remaining = me.buf.remaining();
65            match Pin::new(&mut *me.reader).poll_read(cx, me.buf) {
66                Poll::Ready(Ok(_)) => {
67                    if me.buf.remaining() == old_remaining {
68                        return Err(eof()).into();
69                    }
70                    return Poll::Ready(Ok(me.buf.filled().len()));
71                }
72                Poll::Ready(Err(e)) if e.kind() == ErrorKind::TimedOut => {
73                    continue;
74                }
75                Poll::Ready(Err(e)) => {
76                    return Poll::Ready(Err(e.into()));
77                }
78                Poll::Pending => continue,
79            };
80        }
81    }
82}