mpart_async/
client.rs

1use bytes::{Bytes, BytesMut};
2use futures_core::Stream;
3use log::debug;
4use rand::{distributions::Alphanumeric, thread_rng, Rng};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use std::{collections::VecDeque, convert::Infallible};
8
9/// The main `MultipartRequest` struct for sending Multipart submissions to servers
10pub struct MultipartRequest<S> {
11    boundary: String,
12    items: VecDeque<MultipartItems<S>>,
13    state: Option<State<S>>,
14    written: usize,
15}
16
17enum State<S> {
18    WritingField(MultipartField),
19    WritingStream(MultipartStream<S>),
20    WritingStreamHeader(MultipartStream<S>),
21    Finished,
22}
23
24/// The enum for multipart items which is either a field or a stream
25pub enum MultipartItems<S> {
26    /// MultipartField variant
27    Field(MultipartField),
28    /// MultipartStream variant
29    Stream(MultipartStream<S>),
30}
31
/// A stream which is part of a `MultipartRequest` and used to stream out bytes
pub struct MultipartStream<S> {
    /// Form field name placed in the `Content-Disposition` header.
    name: String,
    /// File name placed in the `Content-Disposition` header.
    filename: String,
    /// Value written verbatim into this part's `Content-Type` header.
    content_type: String,
    /// Source of this part's body bytes.
    stream: S,
}
39
/// A MultipartField which is part of a `MultipartRequest` and used to add a standard text field
pub struct MultipartField {
    /// Form field name placed in the `Content-Disposition` header.
    name: String,
    /// Field contents, written as the body of the part.
    value: String,
}
45
46impl<S> MultipartStream<S> {
47    /// Construct a new MultipartStream providing name, filename & content_type
48    pub fn new<I: Into<String>>(name: I, filename: I, content_type: I, stream: S) -> Self {
49        MultipartStream {
50            name: name.into(),
51            filename: filename.into(),
52            content_type: content_type.into(),
53            stream,
54        }
55    }
56
57    fn write_header(&self, boundary: &str) -> Bytes {
58        let mut buf = BytesMut::new();
59
60        buf.extend_from_slice(b"--");
61        buf.extend_from_slice(boundary.as_bytes());
62        buf.extend_from_slice(b"\r\n");
63
64        buf.extend_from_slice(b"Content-Disposition: form-data; name=\"");
65        buf.extend_from_slice(self.name.as_bytes());
66        buf.extend_from_slice(b"\"; filename=\"");
67        buf.extend_from_slice(self.filename.as_bytes());
68        buf.extend_from_slice(b"\"\r\n");
69        buf.extend_from_slice(b"Content-Type: ");
70        buf.extend_from_slice(self.content_type.as_bytes());
71        buf.extend_from_slice(b"\r\n");
72
73        buf.extend_from_slice(b"\r\n");
74
75        buf.freeze()
76    }
77}
78
79impl MultipartField {
80    /// Construct a new MultipartField given a name and value
81    pub fn new<I: Into<String>>(name: I, value: I) -> Self {
82        MultipartField {
83            name: name.into(),
84            value: value.into(),
85        }
86    }
87
88    fn get_bytes(&self, boundary: &str) -> Bytes {
89        let mut buf = BytesMut::new();
90
91        buf.extend_from_slice(b"--");
92        buf.extend_from_slice(boundary.as_bytes());
93        buf.extend_from_slice(b"\r\n");
94
95        buf.extend_from_slice(b"Content-Disposition: form-data; name=\"");
96        buf.extend_from_slice(self.name.as_bytes());
97        buf.extend_from_slice(b"\"\r\n");
98
99        buf.extend_from_slice(b"\r\n");
100
101        buf.extend_from_slice(self.value.as_bytes());
102
103        buf.extend_from_slice(b"\r\n");
104
105        buf.freeze()
106    }
107}
108
109impl<E, S> MultipartRequest<S>
110where
111    S: Stream<Item = Result<Bytes, E>> + Unpin,
112{
113    /// Construct a new MultipartRequest with a given Boundary
114    ///
115    /// If you want a boundary generated automatically, then you can use `MultipartRequest::default()`
116    pub fn new<I: Into<String>>(boundary: I) -> Self {
117        let items = VecDeque::new();
118
119        let state = None;
120
121        MultipartRequest {
122            boundary: boundary.into(),
123            items,
124            state,
125            written: 0,
126        }
127    }
128
129    fn next_item(&mut self) -> State<S> {
130        match self.items.pop_front() {
131            Some(MultipartItems::Field(new_field)) => State::WritingField(new_field),
132            Some(MultipartItems::Stream(new_stream)) => State::WritingStreamHeader(new_stream),
133            None => State::Finished,
134        }
135    }
136
137    /// Add a raw Stream to the Multipart request
138    ///
139    /// The Stream should return items of `Result<Bytes, Error>`
140    pub fn add_stream<I: Into<String>>(
141        &mut self,
142        name: I,
143        filename: I,
144        content_type: I,
145        stream: S,
146    ) {
147        let stream = MultipartStream::new(name, filename, content_type, stream);
148
149        if self.state.is_some() {
150            self.items.push_back(MultipartItems::Stream(stream));
151        } else {
152            self.state = Some(State::WritingStreamHeader(stream));
153        }
154    }
155
156    /// Add a Field to the Multipart request
157    pub fn add_field<I: Into<String>>(&mut self, name: I, value: I) {
158        let field = MultipartField::new(name, value);
159
160        if self.state.is_some() {
161            self.items.push_back(MultipartItems::Field(field));
162        } else {
163            self.state = Some(State::WritingField(field));
164        }
165    }
166
167    /// Gets the boundary for the MultipartRequest
168    ///
169    /// This is useful for supplying the `Content-Type` header
170    pub fn get_boundary(&self) -> &str {
171        &self.boundary
172    }
173
174    fn write_ending(&self) -> Bytes {
175        let mut buf = BytesMut::new();
176
177        buf.extend_from_slice(b"--");
178        buf.extend_from_slice(self.boundary.as_bytes());
179
180        buf.extend_from_slice(b"--\r\n");
181
182        buf.freeze()
183    }
184}
185
186#[cfg(feature = "filestream")]
187use crate::filestream::FileStream;
188#[cfg(feature = "filestream")]
189use std::path::PathBuf;
190
#[cfg(feature = "filestream")]
impl MultipartRequest<FileStream> {
    /// Add a FileStream to a MultipartRequest given a path to a file
    ///
    /// The part's filename is the final component of `path` (lossily converted to UTF-8).
    /// The Content Type is guessed from the path (e.g. `.jpg` will be `image/jpeg`), and
    /// falls back to an octet-stream type when no guess is available.
    ///
    /// # Panics
    ///
    /// Panics with "Should be a valid file" when `path` has no final file-name
    /// component (for example, when it ends in `..`).
    pub fn add_file<I: Into<String>, P: Into<PathBuf>>(&mut self, name: I, path: P) {
        let path: PathBuf = path.into();
        let name: String = name.into();

        let filename = path
            .file_name()
            .map(|component| component.to_string_lossy().into_owned())
            .expect("Should be a valid file");

        let content_type = mime_guess::MimeGuess::from_path(&path)
            .first_or_octet_stream()
            .to_string();

        self.add_stream(name, filename, content_type, FileStream::new(path));
    }
}
214
215impl<E, S> Default for MultipartRequest<S>
216where
217    S: Stream<Item = Result<Bytes, E>> + Unpin,
218{
219    fn default() -> Self {
220        let mut rng = thread_rng();
221
222        let boundary: String = (&mut rng)
223            .sample_iter(Alphanumeric)
224            .take(60)
225            .map(char::from)
226            .collect();
227
228        let items = VecDeque::new();
229
230        let state = None;
231
232        MultipartRequest {
233            boundary,
234            items,
235            state,
236            written: 0,
237        }
238    }
239}
240
241impl<E, S: Stream> Stream for MultipartRequest<S>
242where
243    S: Stream<Item = Result<Bytes, E>> + Unpin,
244{
245    type Item = Result<Bytes, E>;
246
247    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
248        debug!("Poll hit");
249
250        let self_ref = self.get_mut();
251
252        let mut bytes = None;
253
254        let mut new_state = None;
255
256        let mut waiting = false;
257
258        if let Some(state) = self_ref.state.take() {
259            match state {
260                State::WritingStreamHeader(stream) => {
261                    debug!("Writing Stream Header for:{}", &stream.filename);
262                    bytes = Some(stream.write_header(&self_ref.boundary));
263
264                    new_state = Some(State::WritingStream(stream));
265                }
266                State::WritingStream(mut stream) => {
267                    debug!("Writing Stream Body for:{}", &stream.filename);
268
269                    match Pin::new(&mut stream.stream).poll_next(cx) {
270                        Poll::Pending => {
271                            waiting = true;
272                            new_state = Some(State::WritingStream(stream));
273                        }
274                        Poll::Ready(Some(Ok(some_bytes))) => {
275                            bytes = Some(some_bytes);
276                            new_state = Some(State::WritingStream(stream));
277                        }
278                        Poll::Ready(None) => {
279                            let mut buf = BytesMut::new();
280                            /*
281                                This is a special case that we want to include \r\n and then the next item
282                            */
283                            buf.extend_from_slice(b"\r\n");
284
285                            match self_ref.next_item() {
286                                State::WritingStreamHeader(stream) => {
287                                    debug!("Writing new Stream Header");
288                                    buf.extend_from_slice(&stream.write_header(&self_ref.boundary));
289                                    new_state = Some(State::WritingStream(stream));
290                                }
291                                State::Finished => {
292                                    debug!("Writing new Stream Finished");
293                                    buf.extend_from_slice(&self_ref.write_ending());
294                                }
295                                State::WritingField(field) => {
296                                    debug!("Writing new Stream Field");
297                                    buf.extend_from_slice(&field.get_bytes(&self_ref.boundary));
298                                    new_state = Some(self_ref.next_item());
299                                }
300                                _ => (),
301                            }
302
303                            bytes = Some(buf.freeze())
304                        }
305                        an_error @ Poll::Ready(Some(Err(_))) => return an_error,
306                    }
307                }
308                State::Finished => {
309                    debug!("Writing Stream Finished");
310                    bytes = Some(self_ref.write_ending());
311                }
312                State::WritingField(field) => {
313                    debug!("Writing Field: {}", &field.name);
314                    bytes = Some(field.get_bytes(&self_ref.boundary));
315                    new_state = Some(self_ref.next_item());
316                }
317            }
318        }
319
320        if let Some(state) = new_state {
321            self_ref.state = Some(state);
322        }
323
324        if waiting {
325            return Poll::Pending;
326        }
327
328        if let Some(ref bytes) = bytes {
329            debug!("Bytes: {}", bytes.len());
330            self_ref.written += bytes.len();
331        } else {
332            debug!(
333                "No bytes to write, finished stream, total bytes:{}",
334                self_ref.written
335            );
336        }
337
338        Poll::Ready(bytes.map(|bytes| Ok(bytes)))
339    }
340}
341
342/// A Simple In-Memory Stream that can be used to store bytes
343#[derive(Clone)]
344pub struct ByteStream {
345    bytes: Option<Bytes>,
346}
347
348impl ByteStream {
349    /// Create a new ByteStream based upon the byte slice (note: this will copy from the slice)
350    pub fn new(bytes: &[u8]) -> Self {
351        let mut buf = BytesMut::new();
352
353        buf.extend_from_slice(bytes);
354
355        ByteStream {
356            bytes: Some(buf.freeze()),
357        }
358    }
359}
360
361impl Stream for ByteStream {
362    type Item = Result<Bytes, Infallible>;
363
364    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
365        Poll::Ready(self.as_mut().bytes.take().map(Ok))
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    /// Polls `req` to completion and concatenates every chunk it yields.
    async fn collect_body(req: MultipartRequest<ByteStream>) -> BytesMut {
        req.fold(BytesMut::new(), |mut buf, result| async {
            if let Ok(bytes) = result {
                buf.extend_from_slice(&bytes);
            }

            buf
        })
        .await
    }

    #[test]
    fn sets_boundary() {
        let req = MultipartRequest::<ByteStream>::new("AaB03x");
        assert_eq!(req.get_boundary(), "AaB03x");
    }

    #[test]
    fn writes_field_header() {
        let expected: &[u8] = b"--AaB03x\r\n\
                Content-Disposition: form-data; name=\"field_name\"\r\n\
                \r\n\
                field_value\r\n";

        let field = MultipartField::new("field_name", "field_value");

        assert_eq!(&field.get_bytes("AaB03x")[..], expected);
    }

    #[test]
    fn writes_stream_header() {
        let expected: &[u8] = b"--AaB03x\r\n\
                Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n\
                Content-Type: text/plain\r\n\
                \r\n";

        let part = MultipartStream::new("file", "test.txt", "text/plain", ByteStream::new(b""));

        assert_eq!(&part.write_header("AaB03x")[..], expected);
    }

    #[tokio::test]
    async fn writes_fields() {
        let expected: &[u8] = b"--AaB03x\r\n\
                Content-Disposition: form-data; name=\"name1\"\r\n\
                \r\n\
                value1\r\n\
                --AaB03x\r\n\
                Content-Disposition: form-data; name=\"name2\"\r\n\
                \r\n\
                value2\r\n\
                --AaB03x--\r\n";

        let mut req: MultipartRequest<ByteStream> = MultipartRequest::new("AaB03x");
        req.add_field("name1", "value1");
        req.add_field("name2", "value2");

        let body = collect_body(req).await;

        assert_eq!(&body[..], expected);
    }

    #[tokio::test]
    async fn writes_streams() {
        let expected: &[u8] = b"--AaB03x\r\n\
                Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n\
                Content-Type: text/plain\r\n\
                \r\n\
                Lorem Ipsum\n\r\n\
                --AaB03x--\r\n";

        let mut req: MultipartRequest<ByteStream> = MultipartRequest::new("AaB03x");
        req.add_stream(
            "file",
            "test.txt",
            "text/plain",
            ByteStream::new(b"Lorem Ipsum\n"),
        );

        let body = collect_body(req).await;

        assert_eq!(&body[..], expected);
    }

    #[tokio::test]
    async fn writes_streams_and_fields() {
        let expected: &[u8] = b"--AaB03x\r\n\
                Content-Disposition: form-data; name=\"file\"; filename=\"text.txt\"\r\n\
                Content-Type: text/plain\r\n\
                \r\n\
                Lorem Ipsum\n\r\n\
                --AaB03x\r\n\
                Content-Disposition: form-data; name=\"name1\"\r\n\
                \r\n\
                value1\r\n\
                --AaB03x\r\n\
                Content-Disposition: form-data; name=\"name2\"\r\n\
                \r\n\
                value2\r\n\
                --AaB03x--\r\n";

        let mut req: MultipartRequest<ByteStream> = MultipartRequest::new("AaB03x");
        req.add_stream(
            "file",
            "text.txt",
            "text/plain",
            ByteStream::new(b"Lorem Ipsum\n"),
        );
        req.add_field("name1", "value1");
        req.add_field("name2", "value2");

        let body = collect_body(req).await;

        assert_eq!(&body[..], expected);
    }
}