conjure_runtime/
body.rs

// Copyright 2020 Palantir Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14use crate::service::raw::RequestBodyPart;
15use crate::BaseBody;
16use bytes::{Buf, Bytes, BytesMut};
17use conjure_error::Error;
18use futures::channel::mpsc;
19use futures::{ready, SinkExt, Stream};
20use http_body::{Body, Frame, SizeHint};
21use pin_project::pin_project;
22use std::marker::PhantomPinned;
23use std::pin::Pin;
24use std::task::{Context, Poll};
25use std::{io, mem};
26use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
27
28/// The asynchronous writer passed to
29/// [`AsyncWriteBody::write_body()`](conjure_http::client::AsyncWriteBody::write_body()).
30#[pin_project]
31pub struct BodyWriter {
32    #[pin]
33    sender: mpsc::Sender<RequestBodyPart>,
34    buf: BytesMut,
35    #[pin]
36    _p: PhantomPinned,
37}
38
39impl BodyWriter {
40    pub(crate) fn new(sender: mpsc::Sender<RequestBodyPart>) -> BodyWriter {
41        BodyWriter {
42            sender,
43            buf: BytesMut::new(),
44            _p: PhantomPinned,
45        }
46    }
47
48    pub(crate) async fn finish(mut self: Pin<&mut Self>) -> io::Result<()> {
49        self.flush().await?;
50        self.project()
51            .sender
52            .send(RequestBodyPart::Done)
53            .await
54            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
55        Ok(())
56    }
57
58    /// Writes a block of body bytes.
59    ///
60    /// Compared to the [`AsyncWrite`] implementation, this method avoids some copies if the caller already has the body
61    /// in [`Bytes`] objects.
62    pub async fn write_bytes(mut self: Pin<&mut Self>, bytes: Bytes) -> io::Result<()> {
63        self.flush().await?;
64        self.project()
65            .sender
66            .send(RequestBodyPart::Frame(Frame::data(bytes)))
67            .await
68            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
69        Ok(())
70    }
71}
72
73impl AsyncWrite for BodyWriter {
74    fn poll_write(
75        mut self: Pin<&mut Self>,
76        cx: &mut Context<'_>,
77        buf: &[u8],
78    ) -> Poll<io::Result<usize>> {
79        if self.buf.len() > 4096 {
80            ready!(self.as_mut().poll_flush(cx))?;
81        }
82
83        self.project().buf.extend_from_slice(buf);
84        Poll::Ready(Ok(buf.len()))
85    }
86
87    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
88        let mut this = self.project();
89
90        if this.buf.is_empty() {
91            return Poll::Ready(Ok(()));
92        }
93
94        ready!(this.sender.poll_ready(cx)).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
95        let chunk = this.buf.split().freeze();
96        this.sender
97            .start_send(RequestBodyPart::Frame(Frame::data(chunk)))
98            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
99
100        Poll::Ready(Ok(()))
101    }
102
103    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
104        Poll::Ready(Ok(()))
105    }
106}
107
108/// An asynchronous streaming response body.
109#[pin_project]
110pub struct ResponseBody {
111    #[pin]
112    body: FuseBody<BaseBody>,
113    cur: Bytes,
114    // Make sure we can make our internal BaseBody !Unpin in the future if we want
115    #[pin]
116    _p: PhantomPinned,
117}
118
119impl ResponseBody {
120    pub(crate) fn new(body: BaseBody) -> Self {
121        ResponseBody {
122            body: FuseBody::new(body),
123            cur: Bytes::new(),
124            _p: PhantomPinned,
125        }
126    }
127
128    #[cfg(not(target_arch = "wasm32"))]
129    pub(crate) fn buffer(&self) -> &[u8] {
130        &self.cur
131    }
132}
133
134impl Stream for ResponseBody {
135    type Item = Result<Bytes, Error>;
136
137    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
138        let mut this = self.project();
139
140        if this.cur.has_remaining() {
141            return Poll::Ready(Some(Ok(mem::take(this.cur))));
142        }
143
144        loop {
145            match ready!(this.body.as_mut().poll_frame(cx))
146                .transpose()
147                .map_err(Error::internal_safe)?
148            {
149                Some(frame) => {
150                    if let Ok(data) = frame.into_data() {
151                        return Poll::Ready(Some(Ok(data)));
152                    }
153                }
154                None => return Poll::Ready(None),
155            }
156        }
157    }
158}
159
160impl AsyncRead for ResponseBody {
161    fn poll_read(
162        mut self: Pin<&mut Self>,
163        cx: &mut Context<'_>,
164        buf: &mut ReadBuf<'_>,
165    ) -> Poll<io::Result<()>> {
166        let in_buf = ready!(self.as_mut().poll_fill_buf(cx))?;
167        let len = usize::min(in_buf.len(), buf.remaining());
168        buf.put_slice(&in_buf[..len]);
169        self.consume(len);
170
171        Poll::Ready(Ok(()))
172    }
173}
174
175impl AsyncBufRead for ResponseBody {
176    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
177        while !self.cur.has_remaining() {
178            match ready!(self.as_mut().project().body.poll_frame(cx))
179                .transpose()
180                .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
181            {
182                Some(frame) => {
183                    if let Ok(data) = frame.into_data() {
184                        *self.as_mut().project().cur = data;
185                    }
186                }
187                None => break,
188            }
189        }
190
191        Poll::Ready(Ok(self.project().cur))
192    }
193
194    fn consume(self: Pin<&mut Self>, amt: usize) {
195        self.project().cur.advance(amt)
196    }
197}
198
199#[pin_project]
200struct FuseBody<B> {
201    #[pin]
202    body: B,
203    done: bool,
204}
205
206impl<B> FuseBody<B> {
207    fn new(body: B) -> FuseBody<B> {
208        FuseBody { body, done: false }
209    }
210}
211
212impl<B> Body for FuseBody<B>
213where
214    B: Body,
215{
216    type Data = B::Data;
217    type Error = B::Error;
218
219    fn poll_frame(
220        self: Pin<&mut Self>,
221        cx: &mut Context<'_>,
222    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
223        let this = self.project();
224
225        if *this.done {
226            return Poll::Ready(None);
227        }
228
229        let frame = ready!(this.body.poll_frame(cx));
230        if frame.is_none() {
231            *this.done = true;
232        }
233
234        Poll::Ready(frame)
235    }
236
237    fn is_end_stream(&self) -> bool {
238        self.done || self.body.is_end_stream()
239    }
240
241    fn size_hint(&self) -> SizeHint {
242        if self.done {
243            SizeHint::with_exact(0)
244        } else {
245            self.body.size_hint()
246        }
247    }
248}