async_http_codec/internal/
buffer_decode.rs

1use crate::internal::io_future::{IoFutureWithOutput, IoFutureWithOutputState};
2use crate::RequestHead;
3use futures::prelude::*;
4use std::io;
5use std::io::ErrorKind::InvalidData;
6use std::marker::PhantomData;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
/// State used to accumulate bytes from a transport until a `\r\n\r\n` terminator is seen,
/// and to decode the accumulated bytes into an `O`.
pub struct BufferDecodeState<O>
where
    O: 'static,
{
    /// Bytes read from the transport so far.
    buffer: Vec<u8>,
    /// Number of terminator bytes currently matched, as tracked by `poll`.
    completion: usize,
    /// Passed unchanged as the second argument to `decode_func`.
    max_headers: usize,
    /// Invoked with the buffered bytes and `max_headers` once the terminator has been read.
    decode_func: &'static (dyn Fn(&[u8], usize) -> io::Result<O> + Sync),
    /// Zero-sized marker mentioning `O`; no value of type `O` is stored.
    _phantom: PhantomData<&'static O>,
}
17
18impl<O> BufferDecodeState<O> {
19    pub fn new(
20        max_buffer: usize,
21        max_headers: usize,
22        decode_func: &'static (dyn Fn(&[u8], usize) -> io::Result<O> + Sync),
23    ) -> Self {
24        Self {
25            buffer: Vec::with_capacity(max_buffer),
26            completion: 0,
27            max_headers,
28            decode_func,
29            _phantom: Default::default(),
30        }
31    }
32}
33
34impl<IO: AsyncRead + Unpin, O> IoFutureWithOutputState<IO, O> for BufferDecodeState<O> {
35    fn poll(&mut self, cx: &mut Context<'_>, transport: &mut IO) -> Poll<io::Result<O>> {
36        const END: &[u8; 4] = b"\r\n\r\n";
37        let mut chunk = [0u8; END.len()];
38        loop {
39            let chunk = &mut chunk[self.completion..4];
40            if self.buffer.len() + chunk.len() > self.buffer.capacity() {
41                return Poll::Ready(Err(io::Error::new(InvalidData, "head too long")));
42            }
43            match Pin::new(&mut *transport).poll_read(cx, chunk) {
44                Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())),
45                Poll::Ready(Ok(n)) => {
46                    let mut chunk = &chunk[0..n];
47                    self.buffer.extend_from_slice(chunk);
48                    while self.completion == 0 && chunk.len() > 0 {
49                        if chunk[0] == END[0] {
50                            self.completion = 1
51                        }
52                        chunk = &chunk[1..];
53                    }
54                    match chunk == &END[self.completion..self.completion + chunk.len()] {
55                        true => self.completion += chunk.len(),
56                        false => self.completion = 0,
57                    }
58                    if self.completion == END.len() {
59                        break;
60                    }
61                }
62                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
63                Poll::Pending => return Poll::Pending,
64            }
65        }
66        return Poll::Ready((self.decode_func)(&self.buffer, self.max_headers));
67    }
68}
69
70pub type BufferDecode<IO, O> = IoFutureWithOutput<BufferDecodeState<O>, IO, O>;
71
72#[allow(dead_code)]
73const fn check_if_send<T: Send>() {}
74const _: () = check_if_send::<BufferDecode<Box<dyn AsyncRead + Send + Unpin>, RequestHead>>();