1use crate::raw::{BodyPart, DefaultRawBody};
15use crate::BaseBody;
16use bytes::{Buf, Bytes, BytesMut};
17use conjure_error::Error;
18use futures::channel::mpsc;
19use futures::{ready, SinkExt, Stream};
20use http_body::{Body, Frame, SizeHint};
21use pin_project::pin_project;
22use std::marker::PhantomPinned;
23use std::pin::Pin;
24use std::task::{Context, Poll};
25use std::{error, io, mem};
26use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
27
28#[pin_project]
31pub struct BodyWriter {
32 #[pin]
33 sender: mpsc::Sender<BodyPart>,
34 buf: BytesMut,
35 #[pin]
36 _p: PhantomPinned,
37}
38
39impl BodyWriter {
40 pub(crate) fn new(sender: mpsc::Sender<BodyPart>) -> BodyWriter {
41 BodyWriter {
42 sender,
43 buf: BytesMut::new(),
44 _p: PhantomPinned,
45 }
46 }
47
48 pub(crate) async fn finish(mut self: Pin<&mut Self>) -> io::Result<()> {
49 self.flush().await?;
50 self.project()
51 .sender
52 .send(BodyPart::Done)
53 .await
54 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
55 Ok(())
56 }
57
58 pub async fn write_bytes(mut self: Pin<&mut Self>, bytes: Bytes) -> io::Result<()> {
63 self.flush().await?;
64 self.project()
65 .sender
66 .send(BodyPart::Frame(Frame::data(bytes)))
67 .await
68 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
69 Ok(())
70 }
71}
72
73impl AsyncWrite for BodyWriter {
74 fn poll_write(
75 mut self: Pin<&mut Self>,
76 cx: &mut Context<'_>,
77 buf: &[u8],
78 ) -> Poll<io::Result<usize>> {
79 if self.buf.len() > 4096 {
80 ready!(self.as_mut().poll_flush(cx))?;
81 }
82
83 self.project().buf.extend_from_slice(buf);
84 Poll::Ready(Ok(buf.len()))
85 }
86
87 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
88 let mut this = self.project();
89
90 if this.buf.is_empty() {
91 return Poll::Ready(Ok(()));
92 }
93
94 ready!(this.sender.poll_ready(cx)).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
95 let chunk = this.buf.split().freeze();
96 this.sender
97 .start_send(BodyPart::Frame(Frame::data(chunk)))
98 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
99
100 Poll::Ready(Ok(()))
101 }
102
103 fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
104 Poll::Ready(Ok(()))
105 }
106}
107
108#[pin_project]
110pub struct ResponseBody<B = DefaultRawBody> {
111 #[pin]
112 body: FuseBody<BaseBody<B>>,
113 cur: Bytes,
114 #[pin]
116 _p: PhantomPinned,
117}
118
119impl<B> ResponseBody<B> {
120 pub(crate) fn new(body: BaseBody<B>) -> Self {
121 ResponseBody {
122 body: FuseBody::new(body),
123 cur: Bytes::new(),
124 _p: PhantomPinned,
125 }
126 }
127
128 pub(crate) fn buffer(&self) -> &[u8] {
129 &self.cur
130 }
131}
132
133impl<B> Stream for ResponseBody<B>
134where
135 B: Body<Data = Bytes>,
136 B::Error: Into<Box<dyn error::Error + Sync + Send>>,
137{
138 type Item = Result<Bytes, Error>;
139
140 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
141 let mut this = self.project();
142
143 if this.cur.has_remaining() {
144 return Poll::Ready(Some(Ok(mem::take(this.cur))));
145 }
146
147 loop {
148 match ready!(this.body.as_mut().poll_frame(cx))
149 .transpose()
150 .map_err(Error::internal_safe)?
151 {
152 Some(frame) => {
153 if let Ok(data) = frame.into_data() {
154 return Poll::Ready(Some(Ok(data)));
155 }
156 }
157 None => return Poll::Ready(None),
158 }
159 }
160 }
161}
162
163impl<B> AsyncRead for ResponseBody<B>
164where
165 B: Body<Data = Bytes>,
166 B::Error: Into<Box<dyn error::Error + Sync + Send>>,
167{
168 fn poll_read(
169 mut self: Pin<&mut Self>,
170 cx: &mut Context<'_>,
171 buf: &mut ReadBuf<'_>,
172 ) -> Poll<io::Result<()>> {
173 let in_buf = ready!(self.as_mut().poll_fill_buf(cx))?;
174 let len = usize::min(in_buf.len(), buf.remaining());
175 buf.put_slice(&in_buf[..len]);
176 self.consume(len);
177
178 Poll::Ready(Ok(()))
179 }
180}
181
182impl<B> AsyncBufRead for ResponseBody<B>
183where
184 B: Body<Data = Bytes>,
185 B::Error: Into<Box<dyn error::Error + Sync + Send>>,
186{
187 fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
188 while !self.cur.has_remaining() {
189 match ready!(self.as_mut().project().body.poll_frame(cx))
190 .transpose()
191 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
192 {
193 Some(frame) => {
194 if let Ok(data) = frame.into_data() {
195 *self.as_mut().project().cur = data;
196 }
197 }
198 None => break,
199 }
200 }
201
202 Poll::Ready(Ok(self.project().cur))
203 }
204
205 fn consume(self: Pin<&mut Self>, amt: usize) {
206 self.project().cur.advance(amt)
207 }
208}
209
210#[pin_project]
211struct FuseBody<B> {
212 #[pin]
213 body: B,
214 done: bool,
215}
216
217impl<B> FuseBody<B> {
218 fn new(body: B) -> FuseBody<B> {
219 FuseBody { body, done: false }
220 }
221}
222
223impl<B> Body for FuseBody<B>
224where
225 B: Body,
226{
227 type Data = B::Data;
228 type Error = B::Error;
229
230 fn poll_frame(
231 self: Pin<&mut Self>,
232 cx: &mut Context<'_>,
233 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
234 let this = self.project();
235
236 if *this.done {
237 return Poll::Ready(None);
238 }
239
240 let frame = ready!(this.body.poll_frame(cx));
241 if frame.is_none() {
242 *this.done = true;
243 }
244
245 Poll::Ready(frame)
246 }
247
248 fn is_end_stream(&self) -> bool {
249 self.done || self.body.is_end_stream()
250 }
251
252 fn size_hint(&self) -> SizeHint {
253 if self.done {
254 SizeHint::with_exact(0)
255 } else {
256 self.body.size_hint()
257 }
258 }
259}