ctf_pwn/io/util/cache/
read_until.rs

1use crate::io::AsyncCacheRead;
2use pin_project_lite::pin_project;
3use std::future::Future;
4use std::io;
5use std::io::ErrorKind;
6use std::marker::PhantomPinned;
7use std::pin::Pin;
8use std::task::{ready, Context, Poll};
9use tokio::io::ReadBuf;
10
11pin_project! {
12    /// The delimiter is included in the resulting vector.
13    #[derive(Debug)]
14    #[must_use = "futures do nothing unless you `.await` or poll them"]
15    pub struct ReadUntil<'a, R: ?Sized, D:AsRef<[u8]>> {
16        reader: &'a mut R,
17        delimiter: D,
18        buf: &'a mut Vec<u8>,
19        internal_buf: Vec<u8>,
20        // The number of bytes appended to buf. This can be less than buf.len() if
21        // the buffer was not empty when the operation was started.
22        read: usize,
23        // Make this future `!Unpin` for compatibility with async trait methods.
24        #[pin]
25        _pin: PhantomPinned,
26    }
27}
28
29pub(crate) fn read_until<'a, R, D: AsRef<[u8]>>(
30    reader: &'a mut R,
31    delimiter: D,
32    buf: &'a mut Vec<u8>,
33) -> ReadUntil<'a, R, D>
34where
35    R: AsyncCacheRead + ?Sized + Unpin,
36{
37    ReadUntil {
38        reader,
39        delimiter,
40        buf,
41        internal_buf: Vec::new(),
42        read: 0,
43        _pin: PhantomPinned,
44    }
45}
46
/// Error reported when the reader reaches end-of-stream before the delimiter.
fn eof() -> io::Error {
    let kind = io::ErrorKind::UnexpectedEof;
    io::Error::new(kind, "early eof")
}
50
51pub(super) fn read_until_internal<R: AsyncCacheRead + ?Sized, D: AsRef<[u8]>>(
52    mut reader: Pin<&mut R>,
53    cx: &mut Context<'_>,
54    delimiter: D,
55    buf: &mut Vec<u8>,
56    internal_buf: &mut Vec<u8>,
57    read: &mut usize,
58) -> Poll<io::Result<usize>> {
59    let delim_len = delimiter.as_ref().len();
60    if delim_len == 0 {
61        return Poll::Ready(Ok(0));
62    }
63
64    let mut read_buf = [0u8; 4096];
65    let mut data = ReadBuf::new(&mut read_buf);
66    loop {
67        data.clear();
68        match ready!(reader.as_mut().poll_read(cx, &mut data)) {
69            Ok(_) => {}
70            Err(e) if e.kind() == ErrorKind::TimedOut => {
71                continue;
72            }
73            Err(e) => {
74                return Poll::Ready(Err(e.into()));
75            }
76        }
77        let read_len = data.filled().len();
78        if read_len == 0 {
79            return Err(eof()).into();
80        }
81        *read += read_len;
82        internal_buf.extend_from_slice(data.filled());
83
84        match kmp::kmp_find(delimiter.as_ref(), &internal_buf) {
85            Some(offset) => {
86                let drain_index = offset + delim_len;
87                buf.extend_from_slice(&internal_buf[..drain_index]);
88                let restore_data = &internal_buf[drain_index..];
89                reader.restore(restore_data);
90                *read -= restore_data.len();
91                return Poll::Ready(Ok(buf.len()));
92            }
93            None => {
94                if internal_buf.len() >= delim_len {
95                    let drain_range = 0..internal_buf.len() - delim_len;
96                    buf.extend_from_slice(&internal_buf[drain_range.clone()]);
97                    internal_buf.drain(drain_range);
98                }
99            }
100        }
101    }
102}
103
104impl<R: AsyncCacheRead + ?Sized + Unpin, D: AsRef<[u8]>> Future for ReadUntil<'_, R, D> {
105    type Output = io::Result<usize>;
106
107    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
108        let me = self.project();
109        read_until_internal(
110            Pin::new(*me.reader),
111            cx,
112            me.delimiter,
113            me.buf,
114            me.internal_buf,
115            me.read,
116        )
117    }
118}