Skip to main content

http_multipart/
field.rs

1use core::{cmp, pin::Pin};
2
3use bytes::{Bytes, BytesMut};
4use futures_core::stream::Stream;
5use http::header::HeaderMap;
6use memchr::memmem;
7
8use super::{
9    content_disposition::ContentDisposition,
10    error::{MultipartError, PayloadError},
11    Multipart,
12};
13
/// A single field of a multipart payload.
///
/// The field holds a pinned mutable borrow of its parent [`Multipart`] and reads the
/// field body through it (see `try_next`).
pub struct Field<'a, S> {
    // decoding state of the field body (fixed length or delimiter-terminated stream).
    decoder: FieldDecoder,
    // used to look up the name / filename values in the field headers.
    cp: ContentDisposition,
    // parent multipart. its buffer, boundary, headers and stream are accessed through this.
    multipart: Pin<&'a mut Multipart<S>>,
}
19
20impl<S> Drop for Field<'_, S> {
21    fn drop(&mut self) {
22        self.multipart.as_mut().project().headers.clear();
23    }
24}
25
26impl<'a, S> Field<'a, S> {
27    pub(super) fn new(length: Option<u64>, cp: ContentDisposition, multipart: Pin<&'a mut Multipart<S>>) -> Self {
28        let typ = match length {
29            Some(len) => FieldDecoder::Fixed(len),
30            None => FieldDecoder::StreamBegin,
31        };
32        Self {
33            decoder: typ,
34            cp,
35            multipart,
36        }
37    }
38}
39
/// Decoding state of a field body.
#[derive(Default)]
pub(super) enum FieldDecoder {
    /// Body length is known up front. Holds the number of bytes still to be yielded;
    /// `Fixed(0)` is treated as end of field.
    Fixed(u64),
    /// Body length is unknown and no partial delimiter is pending. The end of the body is
    /// found by searching for the field delimiter.
    #[default]
    StreamBegin,
    /// A partial delimiter (or delimiter followed by a boundary prefix) has been seen and
    /// more bytes are needed to decide whether it is the real boundary.
    StreamDelimiter,
    /// The delimiter followed by the full boundary has been seen. The field is at eof.
    StreamEnd,
}
48
// Header accessors and the streaming body reader of a field.
impl<S, T, E> Field<'_, S>
where
    S: Stream<Item = Result<T, E>>,
    T: AsRef<[u8]> + 'static,
    E: Into<PayloadError>,
{
    /// The field name found in the [http::header::CONTENT_DISPOSITION] header.
    ///
    /// Returns `None` when no name is found or when it is not valid utf-8.
    pub fn name(&self) -> Option<&str> {
        self.cp
            .name_from_headers(self.headers())
            .and_then(|s| std::str::from_utf8(s).ok())
    }

    /// The file name found in the [http::header::CONTENT_DISPOSITION] header.
    ///
    /// Returns `None` when no file name is found or when it is not valid utf-8.
    pub fn file_name(&self) -> Option<&str> {
        self.cp
            .filename_from_headers(self.headers())
            .and_then(|s| std::str::from_utf8(s).ok())
    }

    /// The header map stored on the parent multipart.
    pub fn headers(&self) -> &HeaderMap {
        &self.multipart.headers
    }

    /// Read the next chunk of the field body.
    ///
    /// Returns `Ok(Some(bytes))` for every chunk of body data and `Ok(None)` once the
    /// decoder has reached end of field (`Fixed(0)` or `StreamEnd`), at which point the
    /// parent's `pending_field` flag is reset.
    ///
    /// # Errors
    ///
    /// Errors from reading the underlying stream and from the delimiter search are
    /// propagated.
    pub async fn try_next(&mut self) -> Result<Option<Bytes>, MultipartError> {
        loop {
            let multipart = self.multipart.as_mut().project();
            let buf = multipart.buf;

            // check multipart buffer first and drain it if possible.
            if !buf.is_empty() {
                match self.decoder {
                    FieldDecoder::Fixed(0) | FieldDecoder::StreamEnd => {
                        // next step would match on decoder and handle field eof
                    }
                    FieldDecoder::Fixed(ref mut len) => {
                        // yield at most the remaining field length from the buffer.
                        let at = cmp::min(*len, buf.len() as u64);
                        *len -= at;
                        let chunk = buf.split_to(at as usize).freeze();
                        return Ok(Some(chunk));
                    }
                    FieldDecoder::StreamBegin | FieldDecoder::StreamDelimiter => {
                        // None means nothing in the buffer can be yielded yet. fall through
                        // to either the eof check or a stream read below.
                        if let Some(at) = self.decoder.try_find_split_idx(buf, multipart.boundary)? {
                            return Ok(Some(buf.split_to(at).freeze()));
                        }
                    }
                }
            }

            match &mut self.decoder {
                // the previous try_find_split_idx call may have moved the decoder to stream end.
                // handling field eof down here catches every possible eof state.
                FieldDecoder::Fixed(0) | FieldDecoder::StreamEnd => {
                    *multipart.pending_field = false;
                    return Ok(None);
                }
                decoder => {
                    // nothing could be yielded from the multipart buffer. read more from stream.
                    let item = self.multipart.as_mut().try_read_stream().await?;

                    // re-project after the await. the earlier projection can not be held across it.
                    let multipart = self.multipart.as_mut().project();
                    let buf = multipart.buf;

                    // try to deal with the read bytes in place to reduce memory footprint.
                    // fall back to multipart buffer memory when streaming parser can't be
                    // determined
                    match decoder {
                        // len == 0 is already handled in the outer match.
                        FieldDecoder::Fixed(len) => {
                            let chunk = item.as_ref();
                            let at = cmp::min(*len, chunk.len() as u64);
                            *len -= at;
                            // bytes beyond the field length are stored in the multipart buffer.
                            let bytes = split_bytes(item, at as usize, buf);
                            return Ok(Some(bytes));
                        }
                        FieldDecoder::StreamBegin => {
                            if let Some(at) = self.decoder.try_find_split_idx(&item, multipart.boundary)? {
                                let bytes = split_bytes(item, at, buf);
                                return Ok(Some(bytes));
                            }
                        }
                        // partial delimiter bytes are already in buf; combine with new data
                        // before searching so that a split "\r\n--" is not missed.
                        FieldDecoder::StreamDelimiter => {}
                        FieldDecoder::StreamEnd => unreachable!("outter match covered branch already"),
                    };

                    // nothing was yielded from the item. keep it in the buffer and let the
                    // next loop iteration search the combined bytes.
                    buf.extend_from_slice(item.as_ref());
                }
            }
        }
    }
}
142
143impl FieldDecoder {
144    pub(super) fn try_find_split_idx<T>(&mut self, item: &T, boundary: &[u8]) -> Result<Option<usize>, MultipartError>
145    where
146        T: AsRef<[u8]>,
147    {
148        let item = item.as_ref();
149
150        match memmem::find(item, super::FIELD_DELIMITER) {
151            Some(idx) => {
152                let start = idx + super::FIELD_DELIMITER.len();
153                let length = cmp::min(item.len() - start, boundary.len());
154                let slice = &item[start..start + length];
155
156                // not boundary — yield up to the end of the checked bytes.
157                if !boundary.starts_with(slice) {
158                    return Ok(Some(start + length));
159                }
160
161                // boundary prefix matched but full name not yet visible.
162                *self = if boundary.len() > slice.len() {
163                    FieldDecoder::StreamDelimiter
164                } else {
165                    FieldDecoder::StreamEnd
166                };
167
168                // caller would split a byte buffer based on returned idx
169                // split at 0 is wasteful
170                Ok((idx > 0).then_some(idx))
171            }
172            None => {
173                // No "\r\n--" in this item. check whether the tail is a partial prefix
174                // of "\r\n--" so those bytes can be kept in the buffer for the next
175                // read to complete the match.
176                Ok(match potential_boundary_tail(item) {
177                    Some(keep) => {
178                        *self = FieldDecoder::StreamDelimiter;
179                        (keep < item.len()).then_some(item.len() - keep)
180                    }
181                    None => {
182                        *self = FieldDecoder::StreamBegin;
183                        Some(item.len())
184                    }
185                })
186            }
187        }
188    }
189}
190
/// Length of the partial delimiter `item` ends with, if any.
///
/// Returns `Some(1)`, `Some(2)` or `Some(3)` when `item` ends with `"\r"`, `"\r\n"` or
/// `"\r\n-"` respectively, and `None` otherwise (including for an empty `item`).
fn potential_boundary_tail(item: &[u8]) -> Option<usize> {
    const DELIMITER_HEAD: &[u8] = b"\r\n-";

    // the candidates end in different bytes so at most one of them can match.
    (1..=DELIMITER_HEAD.len()).find(|&keep| item.ends_with(&DELIMITER_HEAD[..keep]))
}
199
200// split chunked item bytes. determined streamable bytes are returned.
201// the rest part is extended onto multipart's internal buffer.
202fn split_bytes<T>(item: T, at: usize, buf: &mut BytesMut) -> Bytes
203where
204    T: AsRef<[u8]> + 'static,
205{
206    match try_downcast_to_bytes(item) {
207        Ok(mut item) => {
208            if item.len() == at {
209                return item;
210            }
211            let bytes = item.split_to(at);
212            buf.extend_from_slice(item.as_ref());
213            bytes
214        }
215        Err(item) => {
216            let chunk = item.as_ref();
217            let bytes = Bytes::copy_from_slice(&chunk[..at]);
218            buf.extend_from_slice(&chunk[at..]);
219            bytes
220        }
221    }
222}
223
224fn try_downcast_to_bytes<T: 'static>(item: T) -> Result<Bytes, T> {
225    use std::any::Any;
226
227    let item = &mut Some(item);
228    match (item as &mut dyn Any).downcast_mut::<Option<Bytes>>() {
229        Some(bytes) => Ok(bytes.take().unwrap()),
230        None => Err(item.take().unwrap()),
231    }
232}
233
#[cfg(test)]
mod test {
    use super::*;

    // `Bytes` is recovered as is while any other type is rejected.
    #[test]
    fn downcast_bytes() {
        assert!(try_downcast_to_bytes(Bytes::new()).is_ok());
        assert!(try_downcast_to_bytes(Vec::<u8>::new()).is_err());
    }
}
245}