use crate::http::GQLRequest;
use crate::{ParseRequestError, Request};
use bytes::Bytes;
use futures::io::AsyncRead;
use futures::stream::{self, Stream};
use multer::{Constraints, Multipart, SizeLimit};
use std::collections::HashMap;
use std::io::{self, Seek, SeekFrom, Write};
use std::pin::Pin;
use std::task::Poll;
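
/// Options for limiting the size and number of uploaded files in a multipart request.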
#[derive(Default, Clone)]
#[non_exhaustive]
pub struct MultipartOptions {
    /// Maximum size of a single uploaded file, in bytes.
    pub max_file_size: Option<usize>,
    /// Maximum number of files allowed in a single request.
    pub max_num_files: Option<usize>,
}
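
/// Receive a `multipart/form-data` GraphQL request and convert it into a [`Request`],
/// spooling each uploaded file into a temporary file on disk.
///
/// A minimal usage sketch (illustrative only; `body` and `boundary` are assumed to be
/// supplied by the surrounding HTTP framework):
///
/// ```ignore
/// let request = receive_multipart(body, boundary, MultipartOptions::default()).await?;
/// ```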
pub async fn receive_multipart(
    body: impl AsyncRead + Send + 'static,
    boundary: impl Into<String>,
    opts: MultipartOptions,
) -> Result<Request, ParseRequestError> {
    let mut multipart = Multipart::new_with_constraints(
        reader_stream(body),
        boundary,
        Constraints::new().size_limit({
            let mut limit = SizeLimit::new();
            // Cap the whole stream at `max_file_size * max_num_files` bytes when both limits are set.
            if let (Some(max_file_size), Some(max_num_files)) =
                (opts.max_file_size, opts.max_num_files)
            {
                limit = limit.whole_stream((max_file_size * max_num_files) as u64);
            }
            // Cap each individual field at `max_file_size` bytes.
            if let Some(max_file_size) = opts.max_file_size {
                limit = limit.per_field(max_file_size as u64);
            }
            limit
        }),
    );
    let mut request = None;
    let mut map = None;
    let mut files = Vec::new();

    while let Some(mut field) = multipart.next_field().await? {
        match field.name() {
            // The `operations` part carries the GraphQL request itself as JSON.
            Some("operations") => {
                let request_str = field.text().await?;
                request = Some(
                    serde_json::from_str::<GQLRequest>(&request_str)
                        .map_err(ParseRequestError::InvalidRequest)?
                        .into(),
                );
            }
            // The `map` part maps each file part name to the variable paths it fills.
            Some("map") => {
                let map_str = field.text().await?;
                map = Some(
                    serde_json::from_str::<HashMap<String, Vec<String>>>(&map_str)
                        .map_err(ParseRequestError::InvalidFilesMap)?,
                );
            }
            // Any other named part with a filename is an uploaded file; spool it to a temp file.
            _ => {
                if let Some(name) = field.name().map(ToString::to_string) {
                    if let Some(filename) = field.file_name().map(ToString::to_string) {
                        let content_type = field.content_type().map(|mime| mime.to_string());
                        let mut file = tempfile::tempfile().map_err(ParseRequestError::Io)?;
                        while let Some(chunk) = field.chunk().await? {
                            file.write_all(&chunk).map_err(ParseRequestError::Io)?;
                        }
                        file.seek(SeekFrom::Start(0))?;
                        files.push((name, filename, content_type, file));
                    }
                }
            }
        }
    }
    let mut request: Request = request.ok_or(ParseRequestError::MissingOperatorsPart)?;
    let map = map.as_mut().ok_or(ParseRequestError::MissingMapPart)?;

    // Attach each uploaded file to every variable path that references it.
    for (name, filename, content_type, file) in files {
        if let Some(var_paths) = map.remove(&name) {
            for var_path in var_paths {
                request.set_upload(
                    &var_path,
                    filename.clone(),
                    content_type.clone(),
                    file.try_clone().map_err(ParseRequestError::Io)?,
                );
            }
        }
    }

    // Entries left in the map refer to files that were declared but never sent.
    if !map.is_empty() {
        return Err(ParseRequestError::MissingFiles);
    }

    Ok(request)
}
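
/// Convert an `AsyncRead` into a `Stream` of `Bytes` chunks, as expected by `multer`.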
fn reader_stream(
    reader: impl AsyncRead + Send + 'static,
) -> impl Stream<Item = io::Result<Bytes>> + Send + 'static {
    let mut buf = [0u8; 2048];
    let mut reader = Box::pin(reader);

    stream::poll_fn(move |cx| {
        Poll::Ready(
            // A read of zero bytes signals end of stream; otherwise yield the bytes just read.
            match futures::ready!(Pin::new(&mut reader).poll_read(cx, &mut buf)?) {
                0 => None,
                size => Some(Ok(Bytes::copy_from_slice(&buf[..size]))),
            },
        )
    })
}