chuchi_core/body/
async_reader.rs

1use super::{
2	size_limit_reached, timed_out, BoxedSyncRead, Constraints,
3	HyperBodyAsAsyncBytesStream, PinnedAsyncBytesStream, PinnedAsyncRead,
4};
5
6use std::future::Future;
7use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
11use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
12use tokio::time::Sleep;
13use tokio_util::io::StreamReader;
14
15use bytes::Bytes;
16use pin_project_lite::pin_project;
17
18pin_project! {
19	pub struct BodyAsyncReader {
20		#[pin]
21		reader: ConstrainedAsyncReader<Inner>
22	}
23}
24
25impl BodyAsyncReader {
26	pub(super) fn new(inner: super::Inner, constraints: Constraints) -> Self {
27		let inner = match inner {
28			super::Inner::Empty => Inner::Bytes(Bytes::new()),
29			super::Inner::Bytes(b) => Inner::Bytes(b),
30			super::Inner::Hyper(i) => Inner::Hyper(StreamReader::new(
31				HyperBodyAsAsyncBytesStream::new(i),
32			)),
33			super::Inner::SyncReader(r) => Inner::SyncReader(r),
34			super::Inner::AsyncReader(r) => Inner::AsyncReader(r),
35			super::Inner::AsyncBytesStreamer(s) => {
36				Inner::AsyncBytesStreamer(StreamReader::new(s))
37			}
38		};
39
40		Self {
41			reader: ConstrainedAsyncReader::new(inner, constraints),
42		}
43	}
44}
45
46impl AsyncRead for BodyAsyncReader {
47	fn poll_read(
48		self: Pin<&mut Self>,
49		cx: &mut Context,
50		buf: &mut ReadBuf,
51	) -> Poll<io::Result<()>> {
52		let me = self.project();
53		me.reader.poll_read(cx, buf)
54	}
55}
56
57enum Inner {
58	Bytes(Bytes),
59	Hyper(StreamReader<HyperBodyAsAsyncBytesStream, Bytes>),
60	SyncReader(BoxedSyncRead),
61	AsyncReader(PinnedAsyncRead),
62	AsyncBytesStreamer(StreamReader<PinnedAsyncBytesStream, Bytes>),
63}
64
65impl AsyncRead for Inner {
66	fn poll_read(
67		self: Pin<&mut Self>,
68		cx: &mut Context,
69		buf: &mut ReadBuf,
70	) -> Poll<io::Result<()>> {
71		let me = self.get_mut();
72
73		match me {
74			Self::Bytes(b) => {
75				if b.is_empty() {
76					return Poll::Ready(Ok(()));
77				}
78
79				let read = buf.remaining().min(b.len());
80				buf.put_slice(&b.split_to(read));
81				Poll::Ready(Ok(()))
82			}
83			Self::Hyper(i) => Pin::new(i).poll_read(cx, buf),
84			Self::SyncReader(r) => {
85				// todo implement this without blocking the current thread
86				let filled = match r.read(buf.initialize_unfilled()) {
87					Ok(o) => o,
88					Err(e) => return Poll::Ready(Err(e)),
89				};
90
91				buf.advance(filled);
92
93				Poll::Ready(Ok(()))
94			}
95			Self::AsyncReader(r) => Pin::new(r).poll_read(cx, buf),
96			Self::AsyncBytesStreamer(s) => Pin::new(s).poll_read(cx, buf),
97		}
98	}
99}
100
101pin_project! {
102	pub(super) struct ConstrainedAsyncReader<R> {
103		#[pin]
104		inner: R,
105		#[pin]
106		timeout: Option<Sleep>,
107		size_limit: Option<usize>
108	}
109}
110
111impl<R> ConstrainedAsyncReader<R> {
112	pub fn new(reader: R, constraints: Constraints) -> Self {
113		Self {
114			inner: reader,
115			timeout: constraints.timeout.map(tokio::time::sleep),
116			size_limit: constraints.size,
117		}
118	}
119}
120
121impl<R: AsyncRead> AsyncRead for ConstrainedAsyncReader<R> {
122	fn poll_read(
123		self: Pin<&mut Self>,
124		cx: &mut Context,
125		buf: &mut ReadBuf,
126	) -> Poll<io::Result<()>> {
127		let mut me = self.project();
128
129		let prev_filled = buf.filled().len();
130
131		if let Poll::Ready(r) = me.inner.poll_read(cx, buf) {
132			if let Err(e) = r {
133				return Poll::Ready(Err(e));
134			}
135
136			// validate size_limit
137			if let Some(size_limit) = &mut me.size_limit {
138				let read = buf.filled().len() - prev_filled;
139				match size_limit.checked_sub(read) {
140					Some(ns) => *size_limit = ns,
141					None => {
142						return Poll::Ready(Err(size_limit_reached(
143							"async reader to big",
144						)))
145					}
146				}
147			}
148
149			return Poll::Ready(Ok(()));
150		}
151
152		// pending
153		if let Some(timeout) = Option::as_pin_mut(me.timeout) {
154			if timeout.poll(cx).is_ready() {
155				return Poll::Ready(Err(timed_out("async reader took to long")));
156			}
157		}
158
159		Poll::Pending
160	}
161}
162
163pub(super) async fn async_reader_into_bytes(
164	r: PinnedAsyncRead,
165	constraints: Constraints,
166) -> io::Result<Bytes> {
167	let reader = ConstrainedAsyncReader::new(r, constraints);
168	tokio::pin!(reader);
169
170	let mut v = vec![];
171	reader.read_to_end(&mut v).await?;
172
173	Ok(v.into())
174}