ctf_pwn/io/util/timeout/
read_exact_timeout.rs

use crate::io::util::timeout::{eof, get_deadline, timeout};
use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
use std::io::ErrorKind;
use std::marker::PhantomPinned;
use std::marker::Unpin;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, ReadBuf};
use tokio::time::{sleep_until, Instant, Sleep};
13
14/// A future which can be used to easily read bytes until timeout or buf is fully filled
15pub(crate) fn read_exact_timeout<'a, A>(
16    reader: &'a mut A,
17    buf: &'a mut [u8],
18    timeout: Duration,
19    throw_on_timeout: bool,
20) -> ReadExactTimeout<'a, A>
21where
22    A: AsyncRead + Unpin + ?Sized,
23{
24    let deadline = get_deadline(timeout);
25    ReadExactTimeout {
26        reader,
27        buf: ReadBuf::new(buf),
28        deadline,
29        _pin: PhantomPinned,
30        throw_on_timeout,
31    }
32}
33
34pin_project! {
35    /// Creates a future which will read exactly enough bytes to fill `buf`,
36    /// stops if Timeout,
37    /// returning an error if EOF is hit sooner.
38    ///
39    /// On success the number of bytes is returned
40    #[derive(Debug)]
41    #[must_use = "futures do nothing unless you `.await` or poll them"]
42    pub struct ReadExactTimeout<'a, A: ?Sized> {
43        reader: &'a mut A,
44        buf: ReadBuf<'a>,
45        deadline: Instant,
46        // Make this future `!Unpin` for compatibility with async trait methods.
47        #[pin]
48        _pin: PhantomPinned,
49        throw_on_timeout: bool,
50    }
51}
52
53impl<A> Future for ReadExactTimeout<'_, A>
54where
55    A: AsyncRead + Unpin + ?Sized,
56{
57    type Output = io::Result<usize>;
58
59    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
60        let me = self.project();
61
62        loop {
63            if *me.deadline < Instant::now() {
64                if *me.throw_on_timeout {
65                    return Poll::Ready(Err(timeout().into()));
66                }
67                return Poll::Ready(Ok(me.buf.filled().len()));
68            }
69
70            // if our buffer is empty, then we need to read some data to continue.
71            let rem = me.buf.remaining();
72            if rem != 0 {
73                match Pin::new(&mut *me.reader).poll_read(cx, me.buf) {
74                    Poll::Ready(Ok(_)) => {}
75                    Poll::Ready(Err(e)) if e.kind() == ErrorKind::TimedOut => {
76                        continue;
77                    }
78                    Poll::Ready(Err(e)) => {
79                        return Poll::Ready(Err(e.into()));
80                    }
81                    Poll::Pending => continue,
82                };
83                if me.buf.remaining() == rem {
84                    return Err(eof()).into();
85                }
86            } else {
87                return Poll::Ready(Ok(me.buf.capacity()));
88            }
89        }
90    }
91}