// async_graphql_viz/extract.rs
use std::collections::HashMap;
2
3use async_graphql::{http::MultipartOptions, ParseRequestError};
4
5use viz_core::{http, types::Multipart, Context, Error, Extract, Result};
6use viz_utils::{
7 futures::{future::BoxFuture, TryStreamExt},
8 serde::json,
9};
10
/// Extractor wrapping a single [`async_graphql::Request`] pulled from an HTTP request.
pub struct GraphQLRequest(pub async_graphql::Request);
13
14impl GraphQLRequest {
15 #[must_use]
17 pub fn into_inner(self) -> async_graphql::Request {
18 self.0
19 }
20}
21
22pub mod rejection {
24 use async_graphql::ParseRequestError;
25 use viz_core::{http, Response};
26
27 pub struct GraphQLRejection(pub ParseRequestError);
29
30 impl From<GraphQLRejection> for Response {
31 fn from(gr: GraphQLRejection) -> Self {
32 match gr.0 {
33 ParseRequestError::PayloadTooLarge => http::StatusCode::PAYLOAD_TOO_LARGE.into(),
34 bad_request => (http::StatusCode::BAD_REQUEST, format!("{:?}", bad_request)).into(),
35 }
36 }
37 }
38
39 impl From<ParseRequestError> for GraphQLRejection {
40 fn from(err: ParseRequestError) -> Self {
41 GraphQLRejection(err)
42 }
43 }
44}
45
46impl Extract for GraphQLRequest {
47 type Error = rejection::GraphQLRejection;
48
49 fn extract(cx: &mut Context) -> BoxFuture<'_, Result<Self, Self::Error>> {
50 Box::pin(async move {
51 Ok(GraphQLRequest(
52 GraphQLBatchRequest::extract(cx).await?.0.into_single()?,
53 ))
54 })
55 }
56}
57
/// Extractor wrapping an [`async_graphql::BatchRequest`] (a single operation or a batch).
pub struct GraphQLBatchRequest(pub async_graphql::BatchRequest);
60
61impl GraphQLBatchRequest {
62 #[must_use]
64 pub fn into_inner(self) -> async_graphql::BatchRequest {
65 self.0
66 }
67}
68
69impl Extract for GraphQLBatchRequest {
70 type Error = rejection::GraphQLRejection;
71
72 fn extract(cx: &mut Context) -> BoxFuture<'_, Result<Self, Self::Error>> {
73 Box::pin(async move {
74 if http::Method::GET == cx.method() {
75 Ok(Self(async_graphql::BatchRequest::Single(
76 cx.query()
77 .map_err(|e| ParseRequestError::InvalidRequest(Box::from(e)))?,
78 )))
79 } else {
80 if let Ok(multipart) = cx.multipart() {
81 if let Ok(mut state) = multipart.state().lock() {
82 let opts = MultipartOptions::default();
83 let mut limits = state.limits_mut();
84 limits.file_size = opts.max_file_size;
85 limits.files = opts.max_num_files;
86 }
87
88 Ok(Self(receive_batch_multipart(multipart).await.map_err(
89 |e| ParseRequestError::InvalidRequest(Box::from(e)),
90 )?))
91 } else {
92 Ok(Self(cx.json().await.map_err(|e| {
93 ParseRequestError::InvalidRequest(Box::from(e))
94 })?))
95 }
96 }
97 })
98 }
99}
100
/// Reads a `multipart/form-data` body laid out per the GraphQL multipart
/// request spec: an `operations` part (the JSON (batch) request), a `map`
/// part (file-variable paths keyed by part name), and one part per upload.
async fn receive_batch_multipart(mut multipart: Multipart) -> Result<async_graphql::BatchRequest> {
    let mut request = None;
    let mut map = None;
    let mut files = Vec::new();

    while let Some(mut field) = multipart.try_next().await? {
        // Parts without an explicit content type default to JSON.
        let content_type = field
            .content_type
            .to_owned()
            .unwrap_or(mime::APPLICATION_JSON);

        let name = field.name.clone();

        match name.as_str() {
            "operations" => {
                // The GraphQL request (or batch of requests) itself.
                let body = field.bytes().await?;
                request = Some(json::from_slice(&body)?)
            }
            "map" => {
                // Mapping from file part name -> list of variable paths.
                let map_bytes = field.bytes().await?;

                match (content_type.type_(), content_type.subtype()) {
                    // With the "cbor" feature, octet-stream maps are decoded
                    // as CBOR; everything else falls through to JSON.
                    #[cfg(feature = "cbor")]
                    (mime::OCTET_STREAM, _) | (mime::APPLICATION, mime::OCTET_STREAM) => {
                        map = Some(
                            serde_cbor::from_slice::<HashMap<String, Vec<String>>>(&map_bytes)
                                .map_err(|e| ParseRequestError::InvalidFilesMap(Box::new(e)))?,
                        );
                    }
                    _ => {
                        map = Some(
                            json::from_slice::<HashMap<String, Vec<String>>>(&map_bytes)
                                .map_err(|e| ParseRequestError::InvalidFilesMap(Box::new(e)))?,
                        );
                    }
                }
            }
            _ => {
                // Any other named part with a filename is an uploaded file:
                // spool it to an anonymous temp file for later substitution.
                if !name.is_empty() {
                    if let Some(filename) = field.filename.to_owned() {
                        let mut file = tempfile::tempfile().map_err(ParseRequestError::Io)?;
                        field.copy_to_file(&mut file).await?;
                        // NOTE(review): a file part without a content type is
                        // recorded as JSON via the default above — confirm
                        // that is intended for binary uploads.
                        files.push((name, filename, Some(content_type.to_string()), file));
                    }
                }
            }
        }
    }

    // Both `operations` and `map` parts are mandatory.
    let mut request = request.ok_or(ParseRequestError::MissingOperatorsPart)?;
    let map = map.as_mut().ok_or(ParseRequestError::MissingMapPart)?;

    // Attach each received file to every variable path listed for its part
    // name; consumed entries are removed from the map as we go.
    for (name, filename, content_type, content) in files {
        if let Some(var_paths) = map.remove(&name) {
            let upload = async_graphql::UploadValue {
                filename,
                content_type,
                content,
            };

            for var_path in var_paths {
                match &mut request {
                    async_graphql::BatchRequest::Single(request) => {
                        request.set_upload(&var_path, upload.try_clone()?);
                    }
                    async_graphql::BatchRequest::Batch(requests) => {
                        // Batched variable paths are "<index>.<path>"; paths
                        // with a bad index or no dot are silently skipped.
                        let mut s = var_path.splitn(2, '.');
                        let idx = s.next().and_then(|idx| idx.parse::<usize>().ok());
                        let path = s.next();

                        if let (Some(idx), Some(path)) = (idx, path) {
                            if let Some(request) = requests.get_mut(idx) {
                                request.set_upload(path, upload.try_clone()?);
                            }
                        }
                    }
                }
            }
        }
    }

    // Leftover map entries reference file parts that never arrived.
    if !map.is_empty() {
        return Err(Error::from(async_graphql::ParseRequestError::MissingFiles));
    }

    Ok(request)
}