1use std::{
2 error::Error as StdError,
3 fs::File,
4 io::Write,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use bytes::{Bytes, BytesMut};
10use futures_util::{
11 io::{self, AsyncRead, AsyncWrite, AsyncWriteExt},
12 stream::{Stream, TryStreamExt},
13};
14use http::{
15 header::{CONTENT_DISPOSITION, CONTENT_TYPE},
16 HeaderValue,
17};
18use tracing::trace;
19
20use crate::{
21 utils::{parse_content_disposition, parse_content_type, parse_part_headers},
22 Error, Field, Flag, FormData, Result, State,
23};
24
25impl<T, B, E> Stream for State<T>
26where
27 T: Stream<Item = Result<B, E>> + Unpin,
28 B: Into<Bytes>,
29 E: Into<Box<dyn StdError + Send + Sync>>,
30{
31 type Item = Result<Bytes>;
32
33 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
34 loop {
35 if self.is_readable {
36 trace!("attempting to decode a part");
38
39 if let Some(data) = self.decode() {
41 trace!("part decoded from buffer");
42 return Poll::Ready(Some(Ok(data)));
43 }
44
45 if Flag::Next == self.flag {
47 return Poll::Ready(None);
48 }
49
50 if Flag::Eof == self.flag {
52 self.length -= self.buffer.len() as u64;
53 self.buffer.clear();
54 self.eof = true;
55 return Poll::Ready(None);
56 }
57
58 self.is_readable = false;
59 }
60
61 trace!("polling data from stream");
62
63 if self.eof {
64 self.is_readable = true;
65 continue;
66 }
67
68 self.buffer.reserve(1);
69 let bytect = match Pin::new(self.io_mut()).poll_next(cx) {
70 Poll::Pending => {
71 return Poll::Pending;
72 }
73 Poll::Ready(Some(Ok(b))) => {
74 let b = b.into();
75 let l = b.len() as u64;
76
77 if let Some(max) = self.limits.checked_stream_size(self.length + l) {
78 return Poll::Ready(Some(Err(Error::PayloadTooLarge(max))));
79 }
80
81 self.buffer.extend_from_slice(&b);
82 self.length += l;
83 l
84 }
85 Poll::Ready(Some(Err(e))) => {
86 return Poll::Ready(Some(Err(Error::BoxError(e.into()))))
87 }
88 Poll::Ready(None) => 0,
89 };
90
91 if bytect == 0 {
92 self.eof = true;
93 }
94
95 self.is_readable = true;
96 }
97 }
98}
99
100impl<T, B, E> Field<T>
101where
102 T: Stream<Item = Result<B, E>> + Unpin,
103 B: Into<Bytes>,
104 E: Into<Box<dyn StdError + Send + Sync>>,
105{
106 pub async fn bytes(&mut self) -> Result<Bytes> {
108 let mut bytes = BytesMut::new();
109 while let Some(buf) = self.try_next().await? {
110 bytes.extend_from_slice(&buf);
111 }
112 Ok(bytes.freeze())
113 }
114
115 pub async fn copy_to<W>(&mut self, writer: &mut W) -> Result<u64>
120 where
121 W: AsyncWrite + Send + Unpin + 'static,
122 {
123 let mut n = 0;
124 while let Some(buf) = self.try_next().await? {
125 writer.write_all(&buf).await?;
126 n += buf.len();
127 }
128 writer.flush().await?;
129 Ok(n as u64)
130 }
131
132 pub async fn copy_to_file(&mut self, file: &mut File) -> Result<u64> {
137 let mut n = 0;
138 while let Some(buf) = self.try_next().await? {
139 n += file.write(&buf)?;
140 }
141 file.flush()?;
142 Ok(n as u64)
143 }
144
145 pub async fn ignore(&mut self) -> Result<()> {
147 while let Some(buf) = self.try_next().await? {
148 drop(buf);
149 }
150 Ok(())
151 }
152}
153
/// Exposes a field's body as an [`AsyncRead`].
///
/// Each call polls the next chunk from the field's [`Stream`] implementation
/// and copies it into `buf`. The end of the field is reported as `Ok(0)`, and
/// parser errors are wrapped in an [`io::Error`] of kind `Other`.
impl<T, B, E> AsyncRead for Field<T>
where
    T: Stream<Item = Result<B, E>> + Unpin,
    B: Into<Bytes>,
    E: Into<Box<dyn StdError + Send + Sync>>,
{
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        match self.poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => Poll::Ready(Ok(0)),
            // NOTE(review): `Write for &mut [u8]` copies only
            // `min(b.len(), buf.len())` bytes. If the decoded chunk is larger
            // than `buf`, the remainder is discarded and never reaches the
            // reader. Fixing this needs storage on `Field` for leftover bytes.
            // Also, an empty chunk here would be indistinguishable from EOF.
            Poll::Ready(Some(Ok(b))) => Poll::Ready(Ok(buf.write(&b)?)),
            Poll::Ready(Some(Err(e))) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
        }
    }
}
174
175impl<T, B, E> Stream for Field<T>
177where
178 T: Stream<Item = Result<B, E>> + Unpin,
179 B: Into<Bytes>,
180 E: Into<Box<dyn StdError + Send + Sync>>,
181{
182 type Item = Result<Bytes>;
183
184 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
185 trace!("polling {} {}", self.index, self.state.is_some());
186
187 let Some(state) = self.state.clone() else {
188 return Poll::Ready(None);
189 };
190
191 let is_file = self.filename.is_some();
192 let mut state = state
193 .try_lock()
194 .map_err(|e| Error::TryLockError(e.to_string()))?;
195
196 match Pin::new(&mut *state).poll_next(cx)? {
197 Poll::Pending => Poll::Pending,
198 Poll::Ready(res) => match res {
199 None => {
200 if let Some(waker) = state.waker_mut().take() {
201 waker.wake();
202 }
203 trace!("polled {}", self.index);
204 drop(self.state.take());
205 Poll::Ready(None)
206 }
207 Some(buf) => {
208 let l = buf.len();
209
210 if is_file {
211 if let Some(max) = state.limits.checked_file_size(self.length + l) {
212 return Poll::Ready(Some(Err(Error::FileTooLarge(max))));
213 }
214 } else if let Some(max) = state.limits.checked_field_size(self.length + l) {
215 return Poll::Ready(Some(Err(Error::FieldTooLarge(max))));
216 }
217
218 self.length += l;
219 trace!("polled bytes {}/{}", buf.len(), self.length);
220 Poll::Ready(Some(Ok(buf)))
221 }
222 },
223 }
224 }
225}
226
227impl<T, B, E> Stream for FormData<T>
229where
230 T: Stream<Item = Result<B, E>> + Unpin,
231 B: Into<Bytes>,
232 E: Into<Box<dyn StdError + Send + Sync>>,
233{
234 type Item = Result<Field<T>>;
235
236 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
237 let mut state = self
238 .state
239 .try_lock()
240 .map_err(|e| Error::TryLockError(e.to_string()))?;
241
242 if state.waker().is_some() {
243 return Poll::Pending;
244 }
245
246 match Pin::new(&mut *state).poll_next(cx)? {
247 Poll::Pending => Poll::Pending,
248 Poll::Ready(res) => match res {
249 None => {
250 trace!("parse eof");
251 Poll::Ready(None)
252 }
253 Some(buf) => {
254 trace!("parse part");
255
256 if let Some(max) = state.limits.checked_parts(state.total + 1) {
258 return Poll::Ready(Some(Err(Error::PartsTooMany(max))));
259 }
260
261 let Ok(mut headers) = parse_part_headers(&buf) else {
263 return Poll::Ready(Some(Err(Error::InvalidHeader)));
264 };
265
266 let Some((name, filename)) = headers
268 .remove(CONTENT_DISPOSITION)
269 .as_ref()
270 .map(HeaderValue::as_bytes)
271 .map(parse_content_disposition)
272 .and_then(Result::ok)
273 else {
274 return Poll::Ready(Some(Err(Error::InvalidContentDisposition)));
275 };
276
277 if let Some(max) = state.limits.checked_field_name_size(name.len()) {
279 return Poll::Ready(Some(Err(Error::FieldNameTooLong(max))));
280 }
281
282 if filename.is_some() {
283 if let Some(max) = state.limits.checked_files(state.files + 1) {
285 return Poll::Ready(Some(Err(Error::FilesTooMany(max))));
286 }
287 state.files += 1;
288 } else {
289 if let Some(max) = state.limits.checked_fields(state.fields + 1) {
291 return Poll::Ready(Some(Err(Error::FieldsTooMany(max))));
292 }
293 state.fields += 1;
294 }
295
296 let mut field = Field::empty();
298
299 field.name = name;
300 field.filename = filename;
301 field.index = state.index();
302 field.content_type = parse_content_type(headers.remove(CONTENT_TYPE).as_ref());
303 field.state_mut().replace(self.state());
304
305 if !headers.is_empty() {
306 field.headers_mut().replace(headers);
307 }
308
309 state.waker_mut().replace(cx.waker().clone());
311
312 Poll::Ready(Some(Ok(field)))
313 }
314 },
315 }
316 }
317}