1use std::{
2 io::{self, Cursor},
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use rocket::{
8 futures::Stream,
9 http::{ContentType, Header, HeaderMap, Status},
10 response::{Responder, Result},
11 tokio::io::{AsyncRead, ReadBuf},
12 Request, Response,
13};
14
/// One part of a multipart response body: a set of part headers followed by
/// the bytes produced by `content`.
pub struct MultipartSection<'r> {
    /// Headers written for this part, after its boundary line and before its
    /// content.
    headers: HeaderMap<'static>,
    /// Reader supplying the part's body bytes; it is read until it reports
    /// end-of-stream (an empty read).
    content: Pin<Box<dyn AsyncRead + Send + 'r>>,
}
20
21impl<'r> MultipartSection<'r> {
22 pub fn new<T: AsyncRead + Send + 'r>(reader: T) -> Self {
26 Self {
27 headers: HeaderMap::new(),
28 content: Box::pin(reader),
29 }
30 }
31
32 pub fn from_box(reader: Box<dyn AsyncRead + Send + 'r>) -> Self {
36 Self {
37 headers: HeaderMap::new(),
38 content: Box::into_pin(reader),
39 }
40 }
41
42 pub fn from_slice(slice: &'r [u8]) -> Self {
44 Self {
45 headers: HeaderMap::new(),
46 content: Box::pin(Cursor::new(slice)),
47 }
48 }
49
50 #[cfg(feature = "json")]
54 pub fn from_json<T: ?Sized + serde::Serialize>(obj: &T) -> serde_json::Result<Self> {
55 let slice = serde_json::to_vec(obj)?;
56 Ok(Self {
57 headers: HeaderMap::new(),
58 content: Box::pin(Cursor::new(slice)),
59 })
60 }
61
62 pub fn add_header(mut self, header: impl Into<Header<'static>>) -> Self {
65 self.headers.add(header);
66 self
67 }
68
69 pub fn replace_header(mut self, header: impl Into<Header<'static>>) -> Self {
72 self.headers.replace(header);
73 self
74 }
75
76 fn encode_headers(&self, boundary: &str) -> String {
77 let mut s = format!("\r\n--{boundary}\r\n");
78 for h in self.headers.iter() {
79 s.push_str(h.name.as_str());
80 s.push_str(": ");
81 s.push_str(h.value());
82 s.push_str("\r\n");
83 }
84 s.push_str("\r\n");
85 s
86 }
87}
88
/// A [`Responder`] that streams a sequence of [`MultipartSection`]s as a
/// `multipart/*` response body.
pub struct MultipartStream<T> {
    /// Delimiter placed before each part and in the closing `--boundary--`
    /// line; it is also sent as the `boundary` parameter of `Content-Type`.
    boundary: String,
    /// Source of the parts to emit, polled one section at a time while the
    /// body is read.
    stream: T,
    /// Media subtype used in the `Content-Type` header (`"mixed"` unless
    /// changed with `with_subtype`).
    sub_type: &'static str,
}
95
96impl<T> MultipartStream<T> {
97 pub fn new(boundary: impl Into<String>, stream: T) -> Self {
100 Self {
101 boundary: boundary.into(),
102 stream,
103 sub_type: "mixed",
104 }
105 }
106
107 #[cfg(feature = "rand")]
112 pub fn new_random(stream: T) -> Self {
113 use rand::{distributions::Alphanumeric, Rng};
114
115 Self {
116 boundary: rand::thread_rng()
117 .sample_iter(Alphanumeric)
118 .map(|v| v as char)
119 .take(15)
120 .collect(),
121 stream,
122 sub_type: "mixed",
123 }
124 }
125
126 pub fn with_subtype(mut self, sub_type: &'static str) -> Self {
128 self.sub_type = sub_type;
129 self
130 }
131}
132
133impl<'r, 'o: 'r, T: Stream<Item = MultipartSection<'o>> + Send + 'o> Responder<'r, 'o>
134 for MultipartStream<T>
135{
136 fn respond_to(self, _r: &'r Request<'_>) -> Result<'o> {
137 Response::build()
138 .status(Status::Ok)
139 .header(
140 ContentType::new("multipart", self.sub_type)
141 .with_params(("boundary", self.boundary.clone())),
142 )
143 .streamed_body(MultipartStreamInner(
144 self.boundary,
145 self.stream,
146 StreamState::Waiting,
147 ))
148 .ok()
149 }
150}
151
/// Body reader behind a `MultipartStream` response.
///
/// Fields, in order: the boundary string, the stream of sections, and the
/// current position in the multipart encoding. Only the stream (`.1`) is
/// accessed through a pinned reference; see `inner`.
struct MultipartStreamInner<'r, T>(String, T, StreamState<'r>);
153
154impl<'r, T> MultipartStreamInner<'r, T> {
155 fn inner(self: Pin<&mut Self>) -> (&str, Pin<&mut T>, &mut StreamState<'r>) {
156 let this = unsafe { self.get_unchecked_mut() };
159 (
160 &this.0,
161 unsafe { Pin::new_unchecked(&mut this.1) },
162 &mut this.2,
163 )
164 }
165}
166
/// Position of `MultipartStreamInner` within the multipart encoding.
enum StreamState<'r> {
    /// Between parts: the next section must be pulled from the stream.
    Waiting,
    /// Writing a part's boundary line and headers (the encoded bytes). The
    /// part's content reader is held here and read once the header bytes
    /// run out.
    Header(Cursor<Vec<u8>>, Pin<Box<dyn AsyncRead + Send + 'r>>),
    /// Copying the current part's content until its reader reports
    /// end-of-stream.
    Raw(Pin<Box<dyn AsyncRead + Send + 'r>>),
    /// Writing the closing `--boundary--` delimiter after the section stream
    /// has ended.
    Footer(Cursor<Vec<u8>>),
}
173
174impl<'r, T: Stream<Item = MultipartSection<'r>> + Send + 'r> AsyncRead
175 for MultipartStreamInner<'r, T>
176{
177 fn poll_read(
178 self: Pin<&mut Self>,
179 cx: &mut Context<'_>,
180 buf: &mut ReadBuf<'_>,
181 ) -> Poll<io::Result<()>> {
182 let (boundary, mut stream, state) = self.inner();
183 loop {
184 match state {
185 StreamState::Waiting => match stream.as_mut().poll_next(cx) {
186 Poll::Ready(Some(v)) => {
187 *state = StreamState::Header(
188 Cursor::new(v.encode_headers(boundary).into_bytes()),
189 v.content,
190 );
191 }
192 Poll::Ready(None) => {
193 *state = StreamState::Footer(Cursor::new(
194 format!("\r\n--{boundary}--\r\n").into_bytes(),
195 ));
196 }
197 Poll::Pending => return Poll::Pending,
198 },
199 StreamState::Header(r, _) => {
200 let cur = buf.filled().len();
201 match Pin::new(r).poll_read(cx, buf) {
202 Poll::Ready(Ok(())) => (),
203 v => return v,
204 }
205 if cur == buf.filled().len() {
206 if let StreamState::Header(_, next) =
208 std::mem::replace(state, StreamState::Waiting)
209 {
210 *state = StreamState::Raw(next);
211 } else {
212 unreachable!()
213 }
214 } else {
215 return Poll::Ready(Ok(()));
216 }
217 }
218 StreamState::Raw(r) => {
219 let cur = buf.filled().len();
220 match r.as_mut().poll_read(cx, buf) {
221 Poll::Ready(Ok(())) => (),
222 v => return v,
223 }
224 if cur == buf.filled().len() {
225 *state = StreamState::Waiting;
227 } else {
228 return Poll::Ready(Ok(()));
229 }
230 }
231 StreamState::Footer(r) => {
232 let cur = buf.filled().len();
233 match Pin::new(r).poll_read(cx, buf) {
234 Poll::Ready(Ok(())) => (),
235 v => return v,
236 }
237 if cur == buf.filled().len() {
238 return Poll::Ready(Ok(()));
240 } else {
241 return Poll::Ready(Ok(()));
242 }
243 }
244 }
245 }
246 }
247}