// fire_http_representation/body/async_bytes_streamer.rs

1use super::{
2	size_limit_reached, timed_out, BoxedSyncRead, Constraints,
3	HyperBodyAsAsyncBytesStream, PinnedAsyncBytesStream, PinnedAsyncRead,
4};
5
6use std::future::Future;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::{io, mem};
10
11use tokio::time::Sleep;
12use tokio_stream::StreamExt;
13use tokio_util::io::ReaderStream;
14
15use futures_core::Stream;
16
17use pin_project_lite::pin_project;
18
19use bytes::{Bytes, BytesMut};
20
21pin_project! {
22	pub struct BodyAsyncBytesStreamer {
23		#[pin]
24		inner: ConstrainedAsyncBytesStreamer<Inner>
25	}
26}
27
28impl BodyAsyncBytesStreamer {
29	pub(super) fn new(inner: super::Inner, constraints: Constraints) -> Self {
30		let inner = match inner {
31			super::Inner::Empty => Inner::Empty,
32			super::Inner::Bytes(b) => Inner::Bytes(b),
33			super::Inner::Hyper(i) => {
34				Inner::Hyper(HyperBodyAsAsyncBytesStream::new(i))
35			}
36			super::Inner::SyncReader(r) => Inner::SyncReader {
37				reader: r,
38				buf: BytesMut::zeroed(DEFAULT_CAP),
39			},
40			super::Inner::AsyncReader(r) => {
41				Inner::AsyncReader(ReaderStream::new(r))
42			}
43			super::Inner::AsyncBytesStreamer(s) => Inner::AsyncBytesStreamer(s),
44		};
45
46		Self {
47			inner: ConstrainedAsyncBytesStreamer::new(inner, constraints),
48		}
49	}
50}
51
52impl Stream for BodyAsyncBytesStreamer {
53	type Item = io::Result<Bytes>;
54
55	fn poll_next(
56		self: Pin<&mut Self>,
57		cx: &mut Context,
58	) -> Poll<Option<io::Result<Bytes>>> {
59		self.project().inner.poll_next(cx)
60	}
61}
62
/// Size, in bytes, of the scratch buffer used when reading from a
/// synchronous reader; the buffer is re-created at this size once it has
/// been fully handed out.
const DEFAULT_CAP: usize = 4 * 1024;
64
65enum Inner {
66	Empty,
67	Bytes(Bytes),
68	Hyper(HyperBodyAsAsyncBytesStream),
69	SyncReader {
70		reader: BoxedSyncRead,
71		buf: BytesMut,
72	},
73	AsyncReader(ReaderStream<PinnedAsyncRead>),
74	AsyncBytesStreamer(PinnedAsyncBytesStream),
75}
76
77impl Stream for Inner {
78	type Item = io::Result<Bytes>;
79
80	fn poll_next(
81		self: Pin<&mut Self>,
82		cx: &mut Context,
83	) -> Poll<Option<io::Result<Bytes>>> {
84		let me = self.get_mut();
85
86		match me {
87			Self::Empty => Poll::Ready(None),
88			Self::Bytes(b) => {
89				let bytes = mem::take(b);
90				*me = Self::Empty;
91				Poll::Ready(Some(Ok(bytes)))
92			}
93			Self::Hyper(i) => Pin::new(i).poll_next(cx),
94			Self::SyncReader { reader, buf } => {
95				if buf.len() == 0 {
96					*buf = BytesMut::zeroed(DEFAULT_CAP);
97				}
98
99				// todo make this non blocking
100
101				let read = match reader.read(buf) {
102					Ok(r) => r,
103					Err(e) => return Poll::Ready(Some(Err(e))),
104				};
105
106				Poll::Ready(Some(Ok(buf.split_to(read).into())))
107			}
108			Self::AsyncReader(s) => Pin::new(s).poll_next(cx),
109			Self::AsyncBytesStreamer(s) => Pin::new(s).poll_next(cx),
110		}
111	}
112}
113
114pin_project! {
115	pub(super) struct ConstrainedAsyncBytesStreamer<S> {
116		#[pin]
117		inner: S,
118		#[pin]
119		timeout: Option<Sleep>,
120		size_limit: Option<usize>
121	}
122}
123
124impl<S> ConstrainedAsyncBytesStreamer<S> {
125	pub fn new(streamer: S, constraints: Constraints) -> Self {
126		Self {
127			inner: streamer,
128			timeout: constraints.timeout.map(tokio::time::sleep),
129			size_limit: constraints.size,
130		}
131	}
132}
133
134impl<S> Stream for ConstrainedAsyncBytesStreamer<S>
135where
136	S: Stream<Item = io::Result<Bytes>>,
137{
138	type Item = io::Result<Bytes>;
139
140	fn poll_next(
141		self: Pin<&mut Self>,
142		cx: &mut Context,
143	) -> Poll<Option<io::Result<Bytes>>> {
144		let mut me = self.project();
145
146		if let Poll::Ready(r) = me.inner.poll_next(cx) {
147			let bytes = match r {
148				Some(Ok(b)) => b,
149				Some(Err(e)) => return Poll::Ready(Some(Err(e))),
150				None => return Poll::Ready(None),
151			};
152
153			// validate size_limit
154			if let Some(size_limit) = &mut me.size_limit {
155				match size_limit.checked_sub(bytes.len()) {
156					Some(ns) => *size_limit = ns,
157					None => {
158						return Poll::Ready(Some(Err(size_limit_reached(
159							"async bytes streamer to big",
160						))))
161					}
162				}
163			}
164
165			return Poll::Ready(Some(Ok(bytes)));
166		}
167
168		// pending
169		if let Some(timeout) = Option::as_pin_mut(me.timeout) {
170			if let Poll::Ready(_) = timeout.poll(cx) {
171				return Poll::Ready(Some(Err(timed_out(
172					"async bytes streamer took to long",
173				))));
174			}
175		}
176
177		Poll::Pending
178	}
179}
180
181pub(super) async fn async_bytes_streamer_into_bytes(
182	s: impl Stream<Item = io::Result<Bytes>>,
183	constraints: Constraints,
184) -> io::Result<Bytes> {
185	let stream = ConstrainedAsyncBytesStreamer::new(s, constraints);
186	tokio::pin!(stream);
187
188	let mut v = BytesMut::new();
189	while let Some(bytes) = stream.next().await {
190		let bytes = bytes?;
191		v.extend(bytes);
192	}
193
194	Ok(v.into())
195}