rocket_multipart/
writer.rs

1use std::{
2    io::{self, Cursor},
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use rocket::{
8    futures::Stream,
9    http::{ContentType, Header, HeaderMap, Status},
10    response::{Responder, Result},
11    tokio::io::{AsyncRead, ReadBuf},
12    Request, Response,
13};
14
15/// A single section to be returned in a stream
16pub struct MultipartSection<'r> {
17    headers: HeaderMap<'static>,
18    content: Pin<Box<dyn AsyncRead + Send + 'r>>,
19}
20
21impl<'r> MultipartSection<'r> {
22    /// Construct a new MultipartSection from an async reader.
23    ///
24    /// If the readers is already in a `Box`, use [`Self::from_box`]
25    pub fn new<T: AsyncRead + Send + 'r>(reader: T) -> Self {
26        Self {
27            headers: HeaderMap::new(),
28            content: Box::pin(reader),
29        }
30    }
31
32    /// Construct a new MultipartSection from a Boxed async reader.
33    ///
34    /// Useful to avoid double boxing a reader.
35    pub fn from_box(reader: Box<dyn AsyncRead + Send + 'r>) -> Self {
36        Self {
37            headers: HeaderMap::new(),
38            content: Box::into_pin(reader),
39        }
40    }
41
42    /// Construct a new MultipartSection from a byte slice.
43    pub fn from_slice(slice: &'r [u8]) -> Self {
44        Self {
45            headers: HeaderMap::new(),
46            content: Box::pin(Cursor::new(slice)),
47        }
48    }
49
50    /// Serialize a JSON object into a MultipartSection
51    ///
52    /// Only available on `json` feature
53    #[cfg(feature = "json")]
54    pub fn from_json<T: ?Sized + serde::Serialize>(obj: &T) -> serde_json::Result<Self> {
55        let slice = serde_json::to_vec(obj)?;
56        Ok(Self {
57            headers: HeaderMap::new(),
58            content: Box::pin(Cursor::new(slice)),
59        })
60    }
61
62    /// Add a header to this section. If this section already has a header with
63    /// the same name, this method adds an additional value.
64    pub fn add_header(mut self, header: impl Into<Header<'static>>) -> Self {
65        self.headers.add(header);
66        self
67    }
68
69    /// Replaces a header for this section. If this section already has a header
70    /// with the same name, this methods replaces all values with the new value.
71    pub fn replace_header(mut self, header: impl Into<Header<'static>>) -> Self {
72        self.headers.replace(header);
73        self
74    }
75
76    fn encode_headers(&self, boundary: &str) -> String {
77        let mut s = format!("\r\n--{boundary}\r\n");
78        for h in self.headers.iter() {
79            s.push_str(h.name.as_str());
80            s.push_str(": ");
81            s.push_str(h.value());
82            s.push_str("\r\n");
83        }
84        s.push_str("\r\n");
85        s
86    }
87}
88
/// A stream of sections to be returned as a streamed `multipart/*` response.
///
/// The sub type defaults to `mixed` and can be changed with
/// [`MultipartStream::with_subtype`].
pub struct MultipartStream<T> {
    /// Marker written between sections and advertised in the `Content-Type`.
    boundary: String,
    /// Source of the sections to send.
    stream: T,
    /// Sub type used in the `multipart/<sub_type>` content type.
    sub_type: &'static str,
}
95
96impl<T> MultipartStream<T> {
97    /// Construct a stream, using the specified string as a boundary marker
98    /// between stream items.
99    pub fn new(boundary: impl Into<String>, stream: T) -> Self {
100        Self {
101            boundary: boundary.into(),
102            stream,
103            sub_type: "mixed",
104        }
105    }
106
107    /// Construct a stream, generating a random 15 character (alpha-numeric)
108    /// boundary marker
109    ///
110    /// Only available on (default) `rand` feature
111    #[cfg(feature = "rand")]
112    pub fn new_random(stream: T) -> Self {
113        use rand::{distributions::Alphanumeric, Rng};
114
115        Self {
116            boundary: rand::thread_rng()
117                .sample_iter(Alphanumeric)
118                .map(|v| v as char)
119                .take(15)
120                .collect(),
121            stream,
122            sub_type: "mixed",
123        }
124    }
125
126    /// Change the ContentType sub type from the default `mixed`
127    pub fn with_subtype(mut self, sub_type: &'static str) -> Self {
128        self.sub_type = sub_type;
129        self
130    }
131}
132
133impl<'r, 'o: 'r, T: Stream<Item = MultipartSection<'o>> + Send + 'o> Responder<'r, 'o>
134    for MultipartStream<T>
135{
136    fn respond_to(self, _r: &'r Request<'_>) -> Result<'o> {
137        Response::build()
138            .status(Status::Ok)
139            .header(
140                ContentType::new("multipart", self.sub_type)
141                    .with_params(("boundary", self.boundary.clone())),
142            )
143            .streamed_body(MultipartStreamInner(
144                self.boundary,
145                self.stream,
146                StreamState::Waiting,
147            ))
148            .ok()
149    }
150}
151
152struct MultipartStreamInner<'r, T>(String, T, StreamState<'r>);
153
154impl<'r, T> MultipartStreamInner<'r, T> {
155    fn inner(self: Pin<&mut Self>) -> (&str, Pin<&mut T>, &mut StreamState<'r>) {
156        // SAFETY: We are projecting `String` and `StreamState` to simple borrows (they implement unpin, so this is fine)
157        // We project `T` to `Pin<&mut T>`, since we don't know (or care) if it implement unpin
158        let this = unsafe { self.get_unchecked_mut() };
159        (
160            &this.0,
161            unsafe { Pin::new_unchecked(&mut this.1) },
162            &mut this.2,
163        )
164    }
165}
166
167enum StreamState<'r> {
168    Waiting,
169    Header(Cursor<Vec<u8>>, Pin<Box<dyn AsyncRead + Send + 'r>>),
170    Raw(Pin<Box<dyn AsyncRead + Send + 'r>>),
171    Footer(Cursor<Vec<u8>>),
172}
173
174impl<'r, T: Stream<Item = MultipartSection<'r>> + Send + 'r> AsyncRead
175    for MultipartStreamInner<'r, T>
176{
177    fn poll_read(
178        self: Pin<&mut Self>,
179        cx: &mut Context<'_>,
180        buf: &mut ReadBuf<'_>,
181    ) -> Poll<io::Result<()>> {
182        let (boundary, mut stream, state) = self.inner();
183        loop {
184            match state {
185                StreamState::Waiting => match stream.as_mut().poll_next(cx) {
186                    Poll::Ready(Some(v)) => {
187                        *state = StreamState::Header(
188                            Cursor::new(v.encode_headers(boundary).into_bytes()),
189                            v.content,
190                        );
191                    }
192                    Poll::Ready(None) => {
193                        *state = StreamState::Footer(Cursor::new(
194                            format!("\r\n--{boundary}--\r\n").into_bytes(),
195                        ));
196                    }
197                    Poll::Pending => return Poll::Pending,
198                },
199                StreamState::Header(r, _) => {
200                    let cur = buf.filled().len();
201                    match Pin::new(r).poll_read(cx, buf) {
202                        Poll::Ready(Ok(())) => (),
203                        v => return v,
204                    }
205                    if cur == buf.filled().len() {
206                        // EOF, move on
207                        if let StreamState::Header(_, next) =
208                            std::mem::replace(state, StreamState::Waiting)
209                        {
210                            *state = StreamState::Raw(next);
211                        } else {
212                            unreachable!()
213                        }
214                    } else {
215                        return Poll::Ready(Ok(()));
216                    }
217                }
218                StreamState::Raw(r) => {
219                    let cur = buf.filled().len();
220                    match r.as_mut().poll_read(cx, buf) {
221                        Poll::Ready(Ok(())) => (),
222                        v => return v,
223                    }
224                    if cur == buf.filled().len() {
225                        // EOF, move on
226                        *state = StreamState::Waiting;
227                    } else {
228                        return Poll::Ready(Ok(()));
229                    }
230                }
231                StreamState::Footer(r) => {
232                    let cur = buf.filled().len();
233                    match Pin::new(r).poll_read(cx, buf) {
234                        Poll::Ready(Ok(())) => (),
235                        v => return v,
236                    }
237                    if cur == buf.filled().len() {
238                        // EOF, move on
239                        return Poll::Ready(Ok(()));
240                    } else {
241                        return Poll::Ready(Ok(()));
242                    }
243                }
244            }
245        }
246    }
247}