chuchi_core/body/
async_reader.rs1use super::{
2 size_limit_reached, timed_out, BoxedSyncRead, Constraints,
3 HyperBodyAsAsyncBytesStream, PinnedAsyncBytesStream, PinnedAsyncRead,
4};
5
6use std::future::Future;
7use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
11use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
12use tokio::time::Sleep;
13use tokio_util::io::StreamReader;
14
15use bytes::Bytes;
16use pin_project_lite::pin_project;
17
18pin_project! {
19 pub struct BodyAsyncReader {
20 #[pin]
21 reader: ConstrainedAsyncReader<Inner>
22 }
23}
24
25impl BodyAsyncReader {
26 pub(super) fn new(inner: super::Inner, constraints: Constraints) -> Self {
27 let inner = match inner {
28 super::Inner::Empty => Inner::Bytes(Bytes::new()),
29 super::Inner::Bytes(b) => Inner::Bytes(b),
30 super::Inner::Hyper(i) => Inner::Hyper(StreamReader::new(
31 HyperBodyAsAsyncBytesStream::new(i),
32 )),
33 super::Inner::SyncReader(r) => Inner::SyncReader(r),
34 super::Inner::AsyncReader(r) => Inner::AsyncReader(r),
35 super::Inner::AsyncBytesStreamer(s) => {
36 Inner::AsyncBytesStreamer(StreamReader::new(s))
37 }
38 };
39
40 Self {
41 reader: ConstrainedAsyncReader::new(inner, constraints),
42 }
43 }
44}
45
46impl AsyncRead for BodyAsyncReader {
47 fn poll_read(
48 self: Pin<&mut Self>,
49 cx: &mut Context,
50 buf: &mut ReadBuf,
51 ) -> Poll<io::Result<()>> {
52 let me = self.project();
53 me.reader.poll_read(cx, buf)
54 }
55}
56
57enum Inner {
58 Bytes(Bytes),
59 Hyper(StreamReader<HyperBodyAsAsyncBytesStream, Bytes>),
60 SyncReader(BoxedSyncRead),
61 AsyncReader(PinnedAsyncRead),
62 AsyncBytesStreamer(StreamReader<PinnedAsyncBytesStream, Bytes>),
63}
64
65impl AsyncRead for Inner {
66 fn poll_read(
67 self: Pin<&mut Self>,
68 cx: &mut Context,
69 buf: &mut ReadBuf,
70 ) -> Poll<io::Result<()>> {
71 let me = self.get_mut();
72
73 match me {
74 Self::Bytes(b) => {
75 if b.is_empty() {
76 return Poll::Ready(Ok(()));
77 }
78
79 let read = buf.remaining().min(b.len());
80 buf.put_slice(&b.split_to(read));
81 Poll::Ready(Ok(()))
82 }
83 Self::Hyper(i) => Pin::new(i).poll_read(cx, buf),
84 Self::SyncReader(r) => {
85 let filled = match r.read(buf.initialize_unfilled()) {
87 Ok(o) => o,
88 Err(e) => return Poll::Ready(Err(e)),
89 };
90
91 buf.advance(filled);
92
93 Poll::Ready(Ok(()))
94 }
95 Self::AsyncReader(r) => Pin::new(r).poll_read(cx, buf),
96 Self::AsyncBytesStreamer(s) => Pin::new(s).poll_read(cx, buf),
97 }
98 }
99}
100
101pin_project! {
102 pub(super) struct ConstrainedAsyncReader<R> {
103 #[pin]
104 inner: R,
105 #[pin]
106 timeout: Option<Sleep>,
107 size_limit: Option<usize>
108 }
109}
110
111impl<R> ConstrainedAsyncReader<R> {
112 pub fn new(reader: R, constraints: Constraints) -> Self {
113 Self {
114 inner: reader,
115 timeout: constraints.timeout.map(tokio::time::sleep),
116 size_limit: constraints.size,
117 }
118 }
119}
120
121impl<R: AsyncRead> AsyncRead for ConstrainedAsyncReader<R> {
122 fn poll_read(
123 self: Pin<&mut Self>,
124 cx: &mut Context,
125 buf: &mut ReadBuf,
126 ) -> Poll<io::Result<()>> {
127 let mut me = self.project();
128
129 let prev_filled = buf.filled().len();
130
131 if let Poll::Ready(r) = me.inner.poll_read(cx, buf) {
132 if let Err(e) = r {
133 return Poll::Ready(Err(e));
134 }
135
136 if let Some(size_limit) = &mut me.size_limit {
138 let read = buf.filled().len() - prev_filled;
139 match size_limit.checked_sub(read) {
140 Some(ns) => *size_limit = ns,
141 None => {
142 return Poll::Ready(Err(size_limit_reached(
143 "async reader to big",
144 )))
145 }
146 }
147 }
148
149 return Poll::Ready(Ok(()));
150 }
151
152 if let Some(timeout) = Option::as_pin_mut(me.timeout) {
154 if timeout.poll(cx).is_ready() {
155 return Poll::Ready(Err(timed_out("async reader took to long")));
156 }
157 }
158
159 Poll::Pending
160 }
161}
162
163pub(super) async fn async_reader_into_bytes(
164 r: PinnedAsyncRead,
165 constraints: Constraints,
166) -> io::Result<Bytes> {
167 let reader = ConstrainedAsyncReader::new(r, constraints);
168 tokio::pin!(reader);
169
170 let mut v = vec![];
171 reader.read_to_end(&mut v).await?;
172
173 Ok(v.into())
174}