form_data/
async.rs

1use std::{
2    error::Error as StdError,
3    fs::File,
4    io::Write,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use bytes::{Bytes, BytesMut};
10use futures_util::{
11    io::{self, AsyncRead, AsyncWrite, AsyncWriteExt},
12    stream::{Stream, TryStreamExt},
13};
14use http::{
15    header::{CONTENT_DISPOSITION, CONTENT_TYPE},
16    HeaderValue,
17};
18use tracing::trace;
19
20use crate::{
21    utils::{parse_content_disposition, parse_content_type, parse_part_headers},
22    Error, Field, Flag, FormData, Result, State,
23};
24
25impl<T, B, E> Stream for State<T>
26where
27    T: Stream<Item = Result<B, E>> + Unpin,
28    B: Into<Bytes>,
29    E: Into<Box<dyn StdError + Send + Sync>>,
30{
31    type Item = Result<Bytes>;
32
33    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
34        loop {
35            if self.is_readable {
36                // part
37                trace!("attempting to decode a part");
38
39                // field
40                if let Some(data) = self.decode() {
41                    trace!("part decoded from buffer");
42                    return Poll::Ready(Some(Ok(data)));
43                }
44
45                // field stream is ended
46                if Flag::Next == self.flag {
47                    return Poll::Ready(None);
48                }
49
50                // whole stream is ended
51                if Flag::Eof == self.flag {
52                    self.length -= self.buffer.len() as u64;
53                    self.buffer.clear();
54                    self.eof = true;
55                    return Poll::Ready(None);
56                }
57
58                self.is_readable = false;
59            }
60
61            trace!("polling data from stream");
62
63            if self.eof {
64                self.is_readable = true;
65                continue;
66            }
67
68            self.buffer.reserve(1);
69            let bytect = match Pin::new(self.io_mut()).poll_next(cx) {
70                Poll::Pending => {
71                    return Poll::Pending;
72                }
73                Poll::Ready(Some(Ok(b))) => {
74                    let b = b.into();
75                    let l = b.len() as u64;
76
77                    if let Some(max) = self.limits.checked_stream_size(self.length + l) {
78                        return Poll::Ready(Some(Err(Error::PayloadTooLarge(max))));
79                    }
80
81                    self.buffer.extend_from_slice(&b);
82                    self.length += l;
83                    l
84                }
85                Poll::Ready(Some(Err(e))) => {
86                    return Poll::Ready(Some(Err(Error::BoxError(e.into()))))
87                }
88                Poll::Ready(None) => 0,
89            };
90
91            if bytect == 0 {
92                self.eof = true;
93            }
94
95            self.is_readable = true;
96        }
97    }
98}
99
100impl<T, B, E> Field<T>
101where
102    T: Stream<Item = Result<B, E>> + Unpin,
103    B: Into<Bytes>,
104    E: Into<Box<dyn StdError + Send + Sync>>,
105{
106    /// Reads field data to bytes.
107    pub async fn bytes(&mut self) -> Result<Bytes> {
108        let mut bytes = BytesMut::new();
109        while let Some(buf) = self.try_next().await? {
110            bytes.extend_from_slice(&buf);
111        }
112        Ok(bytes.freeze())
113    }
114
115    /// Copys large buffer to `AsyncRead`, hyper can support large buffer,
116    /// 8KB <= buffer <= 512KB, so if we want to handle large buffer.
117    /// `Form::set_max_buf_size(512 * 1024);`
118    /// 3~4x performance improvement over the 8KB limitation of `AsyncRead`.
119    pub async fn copy_to<W>(&mut self, writer: &mut W) -> Result<u64>
120    where
121        W: AsyncWrite + Send + Unpin + 'static,
122    {
123        let mut n = 0;
124        while let Some(buf) = self.try_next().await? {
125            writer.write_all(&buf).await?;
126            n += buf.len();
127        }
128        writer.flush().await?;
129        Ok(n as u64)
130    }
131
132    /// Copys large buffer to File, hyper can support large buffer,
133    /// 8KB <= buffer <= 512KB, so if we want to handle large buffer.
134    /// `Form::set_max_buf_size(512 * 1024);`
135    /// 4x+ performance improvement over the 8KB limitation of `AsyncRead`.
136    pub async fn copy_to_file(&mut self, file: &mut File) -> Result<u64> {
137        let mut n = 0;
138        while let Some(buf) = self.try_next().await? {
139            n += file.write(&buf)?;
140        }
141        file.flush()?;
142        Ok(n as u64)
143    }
144
145    /// Ignores current field data, pass it.
146    pub async fn ignore(&mut self) -> Result<()> {
147        while let Some(buf) = self.try_next().await? {
148            drop(buf);
149        }
150        Ok(())
151    }
152}
153
154/// Reads payload data from part, then puts them to anywhere
155impl<T, B, E> AsyncRead for Field<T>
156where
157    T: Stream<Item = Result<B, E>> + Unpin,
158    B: Into<Bytes>,
159    E: Into<Box<dyn StdError + Send + Sync>>,
160{
161    fn poll_read(
162        self: Pin<&mut Self>,
163        cx: &mut Context<'_>,
164        mut buf: &mut [u8],
165    ) -> Poll<io::Result<usize>> {
166        match self.poll_next(cx) {
167            Poll::Pending => Poll::Pending,
168            Poll::Ready(None) => Poll::Ready(Ok(0)),
169            Poll::Ready(Some(Ok(b))) => Poll::Ready(Ok(buf.write(&b)?)),
170            Poll::Ready(Some(Err(e))) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
171        }
172    }
173}
174
175/// Reads payload data from part, then yields them
176impl<T, B, E> Stream for Field<T>
177where
178    T: Stream<Item = Result<B, E>> + Unpin,
179    B: Into<Bytes>,
180    E: Into<Box<dyn StdError + Send + Sync>>,
181{
182    type Item = Result<Bytes>;
183
184    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
185        trace!("polling {} {}", self.index, self.state.is_some());
186
187        let Some(state) = self.state.clone() else {
188            return Poll::Ready(None);
189        };
190
191        let is_file = self.filename.is_some();
192        let mut state = state
193            .try_lock()
194            .map_err(|e| Error::TryLockError(e.to_string()))?;
195
196        match Pin::new(&mut *state).poll_next(cx)? {
197            Poll::Pending => Poll::Pending,
198            Poll::Ready(res) => match res {
199                None => {
200                    if let Some(waker) = state.waker_mut().take() {
201                        waker.wake();
202                    }
203                    trace!("polled {}", self.index);
204                    drop(self.state.take());
205                    Poll::Ready(None)
206                }
207                Some(buf) => {
208                    let l = buf.len();
209
210                    if is_file {
211                        if let Some(max) = state.limits.checked_file_size(self.length + l) {
212                            return Poll::Ready(Some(Err(Error::FileTooLarge(max))));
213                        }
214                    } else if let Some(max) = state.limits.checked_field_size(self.length + l) {
215                        return Poll::Ready(Some(Err(Error::FieldTooLarge(max))));
216                    }
217
218                    self.length += l;
219                    trace!("polled bytes {}/{}", buf.len(), self.length);
220                    Poll::Ready(Some(Ok(buf)))
221                }
222            },
223        }
224    }
225}
226
227/// Reads form-data from request payload body, then yields `Field`
228impl<T, B, E> Stream for FormData<T>
229where
230    T: Stream<Item = Result<B, E>> + Unpin,
231    B: Into<Bytes>,
232    E: Into<Box<dyn StdError + Send + Sync>>,
233{
234    type Item = Result<Field<T>>;
235
236    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
237        let mut state = self
238            .state
239            .try_lock()
240            .map_err(|e| Error::TryLockError(e.to_string()))?;
241
242        if state.waker().is_some() {
243            return Poll::Pending;
244        }
245
246        match Pin::new(&mut *state).poll_next(cx)? {
247            Poll::Pending => Poll::Pending,
248            Poll::Ready(res) => match res {
249                None => {
250                    trace!("parse eof");
251                    Poll::Ready(None)
252                }
253                Some(buf) => {
254                    trace!("parse part");
255
256                    // too many parts
257                    if let Some(max) = state.limits.checked_parts(state.total + 1) {
258                        return Poll::Ready(Some(Err(Error::PartsTooMany(max))));
259                    }
260
261                    // invalid part header
262                    let Ok(mut headers) = parse_part_headers(&buf) else {
263                        return Poll::Ready(Some(Err(Error::InvalidHeader)));
264                    };
265
266                    // invalid content disposition
267                    let Some((name, filename)) = headers
268                        .remove(CONTENT_DISPOSITION)
269                        .as_ref()
270                        .map(HeaderValue::as_bytes)
271                        .map(parse_content_disposition)
272                        .and_then(Result::ok)
273                    else {
274                        return Poll::Ready(Some(Err(Error::InvalidContentDisposition)));
275                    };
276
277                    // field name is too long
278                    if let Some(max) = state.limits.checked_field_name_size(name.len()) {
279                        return Poll::Ready(Some(Err(Error::FieldNameTooLong(max))));
280                    }
281
282                    if filename.is_some() {
283                        // files too many
284                        if let Some(max) = state.limits.checked_files(state.files + 1) {
285                            return Poll::Ready(Some(Err(Error::FilesTooMany(max))));
286                        }
287                        state.files += 1;
288                    } else {
289                        // fields too many
290                        if let Some(max) = state.limits.checked_fields(state.fields + 1) {
291                            return Poll::Ready(Some(Err(Error::FieldsTooMany(max))));
292                        }
293                        state.fields += 1;
294                    }
295
296                    // yields `Field`
297                    let mut field = Field::empty();
298
299                    field.name = name;
300                    field.filename = filename;
301                    field.index = state.index();
302                    field.content_type = parse_content_type(headers.remove(CONTENT_TYPE).as_ref());
303                    field.state_mut().replace(self.state());
304
305                    if !headers.is_empty() {
306                        field.headers_mut().replace(headers);
307                    }
308
309                    // clone waker, if field is polled data, wake it.
310                    state.waker_mut().replace(cx.waker().clone());
311
312                    Poll::Ready(Some(Ok(field)))
313                }
314            },
315        }
316    }
317}