async_http_codec/transaction/server/body_decode_with_continue.rs

1use crate::common::length_from_headers;
2use crate::internal::buffer_write::BufferWriteState;
3use crate::internal::io_future::IoFutureState;
4use crate::{BodyDecodeState, RequestHead, ResponseHead};
5use futures::prelude::*;
6use http::header::EXPECT;
7use http::{HeaderMap, StatusCode, Version};
8use std::borrow::{BorrowMut, Cow};
9use std::io;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13pub struct BodyDecodeWithContinueState {
14    cont: Option<BufferWriteState>,
15    flushed_cont: bool,
16    body: BodyDecodeState,
17}
18
19impl BodyDecodeWithContinueState {
20    pub fn from_head(head: &RequestHead) -> anyhow::Result<Self> {
21        Ok(Self::from_headers(head.headers(), head.version())?)
22    }
23    pub fn new(version: Version, length: Option<u64>, send_continue: bool) -> Self {
24        Self {
25            cont: match send_continue {
26                true => Some(
27                    ResponseHead::new(StatusCode::CONTINUE, version, Cow::Owned(HeaderMap::new()))
28                        .encode_state(),
29                ),
30                false => None,
31            },
32            flushed_cont: false,
33            body: BodyDecodeState::new(length),
34        }
35    }
36    pub fn from_headers(
37        headers: &http::header::HeaderMap,
38        version: Version,
39    ) -> anyhow::Result<Self> {
40        Ok(Self::new(
41            version,
42            length_from_headers(headers)?,
43            contains_continue(headers),
44        ))
45    }
46    pub fn into_async_read<IO: AsyncRead + AsyncWrite + Unpin>(
47        self,
48        io: IO,
49    ) -> BodyDecodeWithContinue<Self, IO> {
50        BodyDecodeWithContinue { io, state: self }
51    }
52    pub fn as_async_read<IO: AsyncRead + AsyncWrite + Unpin>(
53        &mut self,
54        io: IO,
55    ) -> BodyDecodeWithContinue<&mut Self, IO> {
56        BodyDecodeWithContinue { io, state: self }
57    }
58    pub fn poll_read<IO: AsyncRead + AsyncWrite + Unpin>(
59        &mut self,
60        cx: &mut Context<'_>,
61        buf: &mut [u8],
62        io: &mut IO,
63    ) -> Poll<io::Result<usize>> {
64        loop {
65            if let Some(cont) = &mut self.cont {
66                match cont.poll(cx, io) {
67                    Poll::Ready(Ok(())) => self.cont.take(),
68                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
69                    Poll::Pending => return Poll::Pending,
70                };
71            }
72            if !self.flushed_cont {
73                match Pin::new(&mut *io).poll_flush(cx) {
74                    Poll::Ready(Ok(())) => self.flushed_cont = true,
75                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
76                    Poll::Pending => return Poll::Pending,
77                }
78            }
79            return self.body.poll_read(io, cx, buf);
80        }
81    }
82}
83
/// `AsyncRead` adapter over a request body that sends a pending `100 Continue`
/// response before yielding body bytes.
///
/// `state` is either an owned `BodyDecodeWithContinueState` or a `&mut` borrow of one.
/// `io` is the transport that is both read from and written to. Trait bounds are
/// declared on the `impl` blocks that need them rather than on the type itself.
pub struct BodyDecodeWithContinue<T, IO> {
    io: IO,
    state: T,
}
91
92impl<IO: AsyncRead + AsyncWrite + Unpin> BodyDecodeWithContinue<BodyDecodeWithContinueState, IO> {
93    pub fn from_head(head: &RequestHead, io: IO) -> anyhow::Result<Self> {
94        Ok(BodyDecodeWithContinueState::from_head(head)?.into_async_read(io))
95    }
96    pub fn from_headers(
97        headers: &http::header::HeaderMap,
98        version: Version,
99        io: IO,
100    ) -> anyhow::Result<Self> {
101        Ok(BodyDecodeWithContinueState::from_headers(headers, version)?.into_async_read(io))
102    }
103    pub fn new(io: IO, version: Version, length: Option<u64>, send_continue: bool) -> Self {
104        BodyDecodeWithContinueState::new(version, length, send_continue).into_async_read(io)
105    }
106}
107
108impl<T: BorrowMut<BodyDecodeWithContinueState> + Unpin, IO: AsyncRead + AsyncWrite + Unpin>
109    BodyDecodeWithContinue<T, IO>
110{
111    pub fn into_inner(self) -> (T, IO) {
112        (self.state, self.io)
113    }
114}
115
116impl<T: BorrowMut<BodyDecodeWithContinueState> + Unpin, IO: AsyncRead + AsyncWrite + Unpin>
117    AsyncRead for BodyDecodeWithContinue<T, IO>
118{
119    fn poll_read(
120        self: Pin<&mut Self>,
121        cx: &mut Context<'_>,
122        buf: &mut [u8],
123    ) -> Poll<io::Result<usize>> {
124        let this = self.get_mut();
125        this.state
126            .borrow_mut()
127            .poll_read(cx, buf, this.io.borrow_mut())
128    }
129}
130
131pub(crate) fn contains_continue(headers: &HeaderMap) -> bool {
132    headers
133        .get_all(EXPECT)
134        .iter()
135        .find(|v| v == &"100-continue")
136        .is_some()
137}