async_http_codec/transaction/server/
body_decode_with_continue.rs1use crate::common::length_from_headers;
2use crate::internal::buffer_write::BufferWriteState;
3use crate::internal::io_future::IoFutureState;
4use crate::{BodyDecodeState, RequestHead, ResponseHead};
5use futures::prelude::*;
6use http::header::EXPECT;
7use http::{HeaderMap, StatusCode, Version};
8use std::borrow::{BorrowMut, Cow};
9use std::io;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13pub struct BodyDecodeWithContinueState {
14 cont: Option<BufferWriteState>,
15 flushed_cont: bool,
16 body: BodyDecodeState,
17}
18
19impl BodyDecodeWithContinueState {
20 pub fn from_head(head: &RequestHead) -> anyhow::Result<Self> {
21 Ok(Self::from_headers(head.headers(), head.version())?)
22 }
23 pub fn new(version: Version, length: Option<u64>, send_continue: bool) -> Self {
24 Self {
25 cont: match send_continue {
26 true => Some(
27 ResponseHead::new(StatusCode::CONTINUE, version, Cow::Owned(HeaderMap::new()))
28 .encode_state(),
29 ),
30 false => None,
31 },
32 flushed_cont: false,
33 body: BodyDecodeState::new(length),
34 }
35 }
36 pub fn from_headers(
37 headers: &http::header::HeaderMap,
38 version: Version,
39 ) -> anyhow::Result<Self> {
40 Ok(Self::new(
41 version,
42 length_from_headers(headers)?,
43 contains_continue(headers),
44 ))
45 }
46 pub fn into_async_read<IO: AsyncRead + AsyncWrite + Unpin>(
47 self,
48 io: IO,
49 ) -> BodyDecodeWithContinue<Self, IO> {
50 BodyDecodeWithContinue { io, state: self }
51 }
52 pub fn as_async_read<IO: AsyncRead + AsyncWrite + Unpin>(
53 &mut self,
54 io: IO,
55 ) -> BodyDecodeWithContinue<&mut Self, IO> {
56 BodyDecodeWithContinue { io, state: self }
57 }
58 pub fn poll_read<IO: AsyncRead + AsyncWrite + Unpin>(
59 &mut self,
60 cx: &mut Context<'_>,
61 buf: &mut [u8],
62 io: &mut IO,
63 ) -> Poll<io::Result<usize>> {
64 loop {
65 if let Some(cont) = &mut self.cont {
66 match cont.poll(cx, io) {
67 Poll::Ready(Ok(())) => self.cont.take(),
68 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
69 Poll::Pending => return Poll::Pending,
70 };
71 }
72 if !self.flushed_cont {
73 match Pin::new(&mut *io).poll_flush(cx) {
74 Poll::Ready(Ok(())) => self.flushed_cont = true,
75 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
76 Poll::Pending => return Poll::Pending,
77 }
78 }
79 return self.body.poll_read(io, cx, buf);
80 }
81 }
82}
83
84pub struct BodyDecodeWithContinue<
85 T: BorrowMut<BodyDecodeWithContinueState> + Unpin,
86 IO: AsyncRead + AsyncWrite + Unpin,
87> {
88 io: IO,
89 state: T,
90}
91
92impl<IO: AsyncRead + AsyncWrite + Unpin> BodyDecodeWithContinue<BodyDecodeWithContinueState, IO> {
93 pub fn from_head(head: &RequestHead, io: IO) -> anyhow::Result<Self> {
94 Ok(BodyDecodeWithContinueState::from_head(head)?.into_async_read(io))
95 }
96 pub fn from_headers(
97 headers: &http::header::HeaderMap,
98 version: Version,
99 io: IO,
100 ) -> anyhow::Result<Self> {
101 Ok(BodyDecodeWithContinueState::from_headers(headers, version)?.into_async_read(io))
102 }
103 pub fn new(io: IO, version: Version, length: Option<u64>, send_continue: bool) -> Self {
104 BodyDecodeWithContinueState::new(version, length, send_continue).into_async_read(io)
105 }
106}
107
108impl<T: BorrowMut<BodyDecodeWithContinueState> + Unpin, IO: AsyncRead + AsyncWrite + Unpin>
109 BodyDecodeWithContinue<T, IO>
110{
111 pub fn into_inner(self) -> (T, IO) {
112 (self.state, self.io)
113 }
114}
115
116impl<T: BorrowMut<BodyDecodeWithContinueState> + Unpin, IO: AsyncRead + AsyncWrite + Unpin>
117 AsyncRead for BodyDecodeWithContinue<T, IO>
118{
119 fn poll_read(
120 self: Pin<&mut Self>,
121 cx: &mut Context<'_>,
122 buf: &mut [u8],
123 ) -> Poll<io::Result<usize>> {
124 let this = self.get_mut();
125 this.state
126 .borrow_mut()
127 .poll_read(cx, buf, this.io.borrow_mut())
128 }
129}
130
131pub(crate) fn contains_continue(headers: &HeaderMap) -> bool {
132 headers
133 .get_all(EXPECT)
134 .iter()
135 .find(|v| v == &"100-continue")
136 .is_some()
137}