ctf_pwn/io/util/timeout/
read_timeout.rs1use crate::io::util::timeout::{eof, get_deadline, timeout};
2use pin_project_lite::pin_project;
3use std::future::Future;
4use std::io;
5use std::io::ErrorKind;
6use std::marker::PhantomPinned;
7use std::marker::Unpin;
8use std::pin::Pin;
9use std::task::{ready, Context, Poll};
10use std::time::Duration;
11use tokio::io::{AsyncRead, ReadBuf};
12use tokio::time::Instant;
13
14pub(crate) fn read_timeout<'a, A>(
16 reader: &'a mut A,
17 buf: &'a mut [u8],
18 timeout: Duration,
19) -> ReadTimeout<'a, A>
20where
21 A: AsyncRead + Unpin + ?Sized,
22{
23 let deadline = get_deadline(timeout);
24 ReadTimeout {
25 reader,
26 buf: ReadBuf::new(buf),
27 deadline,
28 _pin: PhantomPinned,
29 }
30}
31
32pin_project! {
33 #[derive(Debug)]
39 #[must_use = "futures do nothing unless you `.await` or poll them"]
40 pub struct ReadTimeout<'a, A: ?Sized> {
41 reader: &'a mut A,
42 buf: ReadBuf<'a>,
43 deadline: Instant,
44 #[pin]
46 _pin: PhantomPinned,
47 }
48}
49
50impl<A> Future for ReadTimeout<'_, A>
51where
52 A: AsyncRead + Unpin + ?Sized,
53{
54 type Output = io::Result<usize>;
55
56 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
57 let me = self.project();
58
59 loop {
60 if *me.deadline < Instant::now() {
61 return Poll::Ready(Err(timeout()));
62 }
63
64 let old_remaining = me.buf.remaining();
65 match Pin::new(&mut *me.reader).poll_read(cx, me.buf) {
66 Poll::Ready(Ok(_)) => {
67 if me.buf.remaining() == old_remaining {
68 return Err(eof()).into();
69 }
70 return Poll::Ready(Ok(me.buf.filled().len()));
71 }
72 Poll::Ready(Err(e)) if e.kind() == ErrorKind::TimedOut => {
73 continue;
74 }
75 Poll::Ready(Err(e)) => {
76 return Poll::Ready(Err(e.into()));
77 }
78 Poll::Pending => continue,
79 };
80 }
81 }
82}