conjure_runtime/
body.rs

1// Copyright 2020 Palantir Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use crate::raw::{BodyPart, DefaultRawBody};
15use crate::BaseBody;
16use bytes::{Buf, Bytes, BytesMut};
17use conjure_error::Error;
18use futures::channel::mpsc;
19use futures::{ready, SinkExt, Stream};
20use http_body::{Body, Frame, SizeHint};
21use pin_project::pin_project;
22use std::marker::PhantomPinned;
23use std::pin::Pin;
24use std::task::{Context, Poll};
25use std::{error, io, mem};
26use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
27
28/// The asynchronous writer passed to
29/// [`AsyncWriteBody::write_body()`](conjure_http::client::AsyncWriteBody::write_body()).
30#[pin_project]
31pub struct BodyWriter {
32    #[pin]
33    sender: mpsc::Sender<BodyPart>,
34    buf: BytesMut,
35    #[pin]
36    _p: PhantomPinned,
37}
38
39impl BodyWriter {
40    pub(crate) fn new(sender: mpsc::Sender<BodyPart>) -> BodyWriter {
41        BodyWriter {
42            sender,
43            buf: BytesMut::new(),
44            _p: PhantomPinned,
45        }
46    }
47
48    pub(crate) async fn finish(mut self: Pin<&mut Self>) -> io::Result<()> {
49        self.flush().await?;
50        self.project()
51            .sender
52            .send(BodyPart::Done)
53            .await
54            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
55        Ok(())
56    }
57
58    /// Writes a block of body bytes.
59    ///
60    /// Compared to the [`AsyncWrite`] implementation, this method avoids some copies if the caller already has the body
61    /// in [`Bytes`] objects.
62    pub async fn write_bytes(mut self: Pin<&mut Self>, bytes: Bytes) -> io::Result<()> {
63        self.flush().await?;
64        self.project()
65            .sender
66            .send(BodyPart::Frame(Frame::data(bytes)))
67            .await
68            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
69        Ok(())
70    }
71}
72
73impl AsyncWrite for BodyWriter {
74    fn poll_write(
75        mut self: Pin<&mut Self>,
76        cx: &mut Context<'_>,
77        buf: &[u8],
78    ) -> Poll<io::Result<usize>> {
79        if self.buf.len() > 4096 {
80            ready!(self.as_mut().poll_flush(cx))?;
81        }
82
83        self.project().buf.extend_from_slice(buf);
84        Poll::Ready(Ok(buf.len()))
85    }
86
87    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
88        let mut this = self.project();
89
90        if this.buf.is_empty() {
91            return Poll::Ready(Ok(()));
92        }
93
94        ready!(this.sender.poll_ready(cx)).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
95        let chunk = this.buf.split().freeze();
96        this.sender
97            .start_send(BodyPart::Frame(Frame::data(chunk)))
98            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
99
100        Poll::Ready(Ok(()))
101    }
102
103    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
104        Poll::Ready(Ok(()))
105    }
106}
107
108/// An asynchronous streaming response body.
109#[pin_project]
110pub struct ResponseBody<B = DefaultRawBody> {
111    #[pin]
112    body: FuseBody<BaseBody<B>>,
113    cur: Bytes,
114    // Make sure we can make our internal BaseBody !Unpin in the future if we want
115    #[pin]
116    _p: PhantomPinned,
117}
118
119impl<B> ResponseBody<B> {
120    pub(crate) fn new(body: BaseBody<B>) -> Self {
121        ResponseBody {
122            body: FuseBody::new(body),
123            cur: Bytes::new(),
124            _p: PhantomPinned,
125        }
126    }
127
128    pub(crate) fn buffer(&self) -> &[u8] {
129        &self.cur
130    }
131}
132
133impl<B> Stream for ResponseBody<B>
134where
135    B: Body<Data = Bytes>,
136    B::Error: Into<Box<dyn error::Error + Sync + Send>>,
137{
138    type Item = Result<Bytes, Error>;
139
140    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
141        let mut this = self.project();
142
143        if this.cur.has_remaining() {
144            return Poll::Ready(Some(Ok(mem::take(this.cur))));
145        }
146
147        loop {
148            match ready!(this.body.as_mut().poll_frame(cx))
149                .transpose()
150                .map_err(Error::internal_safe)?
151            {
152                Some(frame) => {
153                    if let Ok(data) = frame.into_data() {
154                        return Poll::Ready(Some(Ok(data)));
155                    }
156                }
157                None => return Poll::Ready(None),
158            }
159        }
160    }
161}
162
163impl<B> AsyncRead for ResponseBody<B>
164where
165    B: Body<Data = Bytes>,
166    B::Error: Into<Box<dyn error::Error + Sync + Send>>,
167{
168    fn poll_read(
169        mut self: Pin<&mut Self>,
170        cx: &mut Context<'_>,
171        buf: &mut ReadBuf<'_>,
172    ) -> Poll<io::Result<()>> {
173        let in_buf = ready!(self.as_mut().poll_fill_buf(cx))?;
174        let len = usize::min(in_buf.len(), buf.remaining());
175        buf.put_slice(&in_buf[..len]);
176        self.consume(len);
177
178        Poll::Ready(Ok(()))
179    }
180}
181
182impl<B> AsyncBufRead for ResponseBody<B>
183where
184    B: Body<Data = Bytes>,
185    B::Error: Into<Box<dyn error::Error + Sync + Send>>,
186{
187    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
188        while !self.cur.has_remaining() {
189            match ready!(self.as_mut().project().body.poll_frame(cx))
190                .transpose()
191                .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
192            {
193                Some(frame) => {
194                    if let Ok(data) = frame.into_data() {
195                        *self.as_mut().project().cur = data;
196                    }
197                }
198                None => break,
199            }
200        }
201
202        Poll::Ready(Ok(self.project().cur))
203    }
204
205    fn consume(self: Pin<&mut Self>, amt: usize) {
206        self.project().cur.advance(amt)
207    }
208}
209
210#[pin_project]
211struct FuseBody<B> {
212    #[pin]
213    body: B,
214    done: bool,
215}
216
217impl<B> FuseBody<B> {
218    fn new(body: B) -> FuseBody<B> {
219        FuseBody { body, done: false }
220    }
221}
222
223impl<B> Body for FuseBody<B>
224where
225    B: Body,
226{
227    type Data = B::Data;
228    type Error = B::Error;
229
230    fn poll_frame(
231        self: Pin<&mut Self>,
232        cx: &mut Context<'_>,
233    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
234        let this = self.project();
235
236        if *this.done {
237            return Poll::Ready(None);
238        }
239
240        let frame = ready!(this.body.poll_frame(cx));
241        if frame.is_none() {
242            *this.done = true;
243        }
244
245        Poll::Ready(frame)
246    }
247
248    fn is_end_stream(&self) -> bool {
249        self.done || self.body.is_end_stream()
250    }
251
252    fn size_hint(&self) -> SizeHint {
253        if self.done {
254            SizeHint::with_exact(0)
255        } else {
256            self.body.size_hint()
257        }
258    }
259}