conjure_runtime/raw/
body.rs

1// Copyright 2020 Palantir Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use crate::BodyWriter;
15use bytes::Bytes;
16use conjure_error::Error;
17use conjure_http::client::{AsyncRequestBody, AsyncWriteBody, BoxAsyncWriteBody};
18use futures::channel::{mpsc, oneshot};
19use futures::{pin_mut, Stream};
20use http_body::{Frame, SizeHint};
21use std::pin::Pin;
22use std::task::{Context, Poll};
23use std::{error, fmt, mem};
24use witchcraft_log::debug;
25
/// The error type returned by `RawBody`.
///
/// The type is opaque: it wraps a private unit value, so code outside this
/// module cannot construct it or inspect anything beyond its `Display` and
/// `Debug` output.
#[derive(Debug)]
pub struct BodyError(());

impl fmt::Display for BodyError {
    /// Writes the fixed, detail-free message for a failed body write.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "error writing body")
    }
}

// No underlying cause is tracked, so the default `Error` methods suffice.
impl error::Error for BodyError {}
37
38pub(crate) enum BodyPart {
39    Frame(Frame<Bytes>),
40    Done,
41}
42
43pub(crate) enum RawBodyInner {
44    Empty,
45    Single(Frame<Bytes>),
46    Stream {
47        receiver: mpsc::Receiver<BodyPart>,
48        polled: Option<oneshot::Sender<()>>,
49    },
50}
51
52/// The request body type passed to the raw HTTP client.
53pub struct RawBody {
54    pub(crate) inner: RawBodyInner,
55}
56
57impl RawBody {
58    pub(crate) fn new(body: AsyncRequestBody<'_, BodyWriter>) -> (RawBody, Writer<'_>) {
59        match body {
60            AsyncRequestBody::Empty => (
61                RawBody {
62                    inner: RawBodyInner::Empty,
63                },
64                Writer::Nop,
65            ),
66            AsyncRequestBody::Fixed(body) => (
67                RawBody {
68                    inner: RawBodyInner::Single(Frame::data(body)),
69                },
70                Writer::Nop,
71            ),
72            AsyncRequestBody::Streaming(body) => {
73                let (body_sender, body_receiver) = mpsc::channel(1);
74                let (polled_sender, polled_receiver) = oneshot::channel();
75                (
76                    RawBody {
77                        inner: RawBodyInner::Stream {
78                            receiver: body_receiver,
79                            polled: Some(polled_sender),
80                        },
81                    },
82                    Writer::Streaming {
83                        polled: polled_receiver,
84                        body,
85                        sender: body_sender,
86                    },
87                )
88            }
89        }
90    }
91}
92
93impl http_body::Body for RawBody {
94    type Data = Bytes;
95    type Error = BodyError;
96
97    fn poll_frame(
98        mut self: Pin<&mut Self>,
99        cx: &mut Context<'_>,
100    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
101        match mem::replace(&mut self.inner, RawBodyInner::Empty) {
102            RawBodyInner::Empty => Poll::Ready(None),
103            RawBodyInner::Single(frame) => Poll::Ready(Some(Ok(frame))),
104            RawBodyInner::Stream {
105                mut receiver,
106                mut polled,
107            } => {
108                if let Some(polled) = polled.take() {
109                    let _ = polled.send(());
110                }
111
112                match Pin::new(&mut receiver).poll_next(cx) {
113                    Poll::Ready(Some(BodyPart::Frame(frame))) => {
114                        self.inner = RawBodyInner::Stream { receiver, polled };
115                        Poll::Ready(Some(Ok(frame)))
116                    }
117                    Poll::Ready(Some(BodyPart::Done)) => Poll::Ready(None),
118                    Poll::Ready(None) => Poll::Ready(Some(Err(BodyError(())))),
119                    Poll::Pending => {
120                        self.inner = RawBodyInner::Stream { receiver, polled };
121                        Poll::Pending
122                    }
123                }
124            }
125        }
126    }
127
128    fn is_end_stream(&self) -> bool {
129        matches!(self.inner, RawBodyInner::Empty)
130    }
131
132    fn size_hint(&self) -> SizeHint {
133        match &self.inner {
134            RawBodyInner::Empty => SizeHint::with_exact(0),
135            RawBodyInner::Single(frame) => {
136                let len = match frame.data_ref() {
137                    Some(buf) => buf.len(),
138                    None => 0,
139                };
140                SizeHint::with_exact(len as u64)
141            }
142            RawBodyInner::Stream { .. } => SizeHint::new(),
143        }
144    }
145}
146
147pub(crate) enum Writer<'a> {
148    Nop,
149    Streaming {
150        polled: oneshot::Receiver<()>,
151        body: BoxAsyncWriteBody<'a, BodyWriter>,
152        sender: mpsc::Sender<BodyPart>,
153    },
154}
155
156impl Writer<'_> {
157    pub async fn write(self) -> Result<(), Error> {
158        match self {
159            Writer::Nop => Ok(()),
160            Writer::Streaming {
161                polled,
162                mut body,
163                sender,
164            } => {
165                // wait for hyper to actually ask for the body so we don't start reading it if the request fails early
166                if polled.await.is_err() {
167                    debug!("hyper hung up before polling request body");
168                    return Ok(());
169                }
170
171                let writer = BodyWriter::new(sender);
172                pin_mut!(writer);
173                Pin::new(&mut body).write_body(writer.as_mut()).await?;
174                writer.finish().await.map_err(Error::internal_safe)?;
175
176                Ok(())
177            }
178        }
179    }
180}