async_graphql_viz/
extract.rs

1use std::collections::HashMap;
2
3use async_graphql::{http::MultipartOptions, ParseRequestError};
4
5use viz_core::{http, types::Multipart, Context, Error, Extract, Result};
6use viz_utils::{
7    futures::{future::BoxFuture, TryStreamExt},
8    serde::json,
9};
10
11/// Extractor for GraphQL request.
12pub struct GraphQLRequest(pub async_graphql::Request);
13
14impl GraphQLRequest {
15    /// Unwraps the value to `async_graphql::Request`.
16    #[must_use]
17    pub fn into_inner(self) -> async_graphql::Request {
18        self.0
19    }
20}
21
22/// Rejection response types.
23pub mod rejection {
24    use async_graphql::ParseRequestError;
25    use viz_core::{http, Response};
26
27    /// Rejection used for [`GraphQLRequest`](GraphQLRequest).
28    pub struct GraphQLRejection(pub ParseRequestError);
29
30    impl From<GraphQLRejection> for Response {
31        fn from(gr: GraphQLRejection) -> Self {
32            match gr.0 {
33                ParseRequestError::PayloadTooLarge => http::StatusCode::PAYLOAD_TOO_LARGE.into(),
34                bad_request => (http::StatusCode::BAD_REQUEST, format!("{:?}", bad_request)).into(),
35            }
36        }
37    }
38
39    impl From<ParseRequestError> for GraphQLRejection {
40        fn from(err: ParseRequestError) -> Self {
41            GraphQLRejection(err)
42        }
43    }
44}
45
46impl Extract for GraphQLRequest {
47    type Error = rejection::GraphQLRejection;
48
49    fn extract(cx: &mut Context) -> BoxFuture<'_, Result<Self, Self::Error>> {
50        Box::pin(async move {
51            Ok(GraphQLRequest(
52                GraphQLBatchRequest::extract(cx).await?.0.into_single()?,
53            ))
54        })
55    }
56}
57
58/// Extractor for GraphQL batch request.
59pub struct GraphQLBatchRequest(pub async_graphql::BatchRequest);
60
61impl GraphQLBatchRequest {
62    /// Unwraps the value to `async_graphql::BatchRequest`.
63    #[must_use]
64    pub fn into_inner(self) -> async_graphql::BatchRequest {
65        self.0
66    }
67}
68
69impl Extract for GraphQLBatchRequest {
70    type Error = rejection::GraphQLRejection;
71
72    fn extract(cx: &mut Context) -> BoxFuture<'_, Result<Self, Self::Error>> {
73        Box::pin(async move {
74            if http::Method::GET == cx.method() {
75                Ok(Self(async_graphql::BatchRequest::Single(
76                    cx.query()
77                        .map_err(|e| ParseRequestError::InvalidRequest(Box::from(e)))?,
78                )))
79            } else {
80                if let Ok(multipart) = cx.multipart() {
81                    if let Ok(mut state) = multipart.state().lock() {
82                        let opts = MultipartOptions::default();
83                        let mut limits = state.limits_mut();
84                        limits.file_size = opts.max_file_size;
85                        limits.files = opts.max_num_files;
86                    }
87
88                    Ok(Self(receive_batch_multipart(multipart).await.map_err(
89                        |e| ParseRequestError::InvalidRequest(Box::from(e)),
90                    )?))
91                } else {
92                    Ok(Self(cx.json().await.map_err(|e| {
93                        ParseRequestError::InvalidRequest(Box::from(e))
94                    })?))
95                }
96            }
97        })
98    }
99}
100
101async fn receive_batch_multipart(mut multipart: Multipart) -> Result<async_graphql::BatchRequest> {
102    let mut request = None;
103    let mut map = None;
104    let mut files = Vec::new();
105
106    while let Some(mut field) = multipart.try_next().await? {
107        // in multipart, each field / file can actually have a own Content-Type.
108        // We use this to determine the encoding of the graphql query
109        let content_type = field
110            .content_type
111            .to_owned()
112            .unwrap_or(mime::APPLICATION_JSON);
113
114        let name = field.name.clone();
115
116        match name.as_str() {
117            "operations" => {
118                let body = field.bytes().await?;
119                request = Some(json::from_slice(&body)?)
120            }
121            "map" => {
122                let map_bytes = field.bytes().await?;
123
124                match (content_type.type_(), content_type.subtype()) {
125                    // cbor is in application/octet-stream.
126                    // TODO: wait for mime to add application/cbor and match against that too
127                    // Note: we actually differ here from the inoffical spec for this:
128                    // (https://github.com/jaydenseric/graphql-multipart-request-spec#multipart-form-field-structure)
129                    // It says: "map: A JSON encoded map of where files occurred in the operations.
130                    // For each file, the key is the file multipart form field name and the value is an array of operations paths."
131                    // However, I think, that since we accept CBOR as operation, which is valid, we should also accept it
132                    // as the mapping for the files.
133                    #[cfg(feature = "cbor")]
134                    (mime::OCTET_STREAM, _) | (mime::APPLICATION, mime::OCTET_STREAM) => {
135                        map = Some(
136                            serde_cbor::from_slice::<HashMap<String, Vec<String>>>(&map_bytes)
137                                .map_err(|e| ParseRequestError::InvalidFilesMap(Box::new(e)))?,
138                        );
139                    }
140                    // default to json
141                    _ => {
142                        map = Some(
143                            json::from_slice::<HashMap<String, Vec<String>>>(&map_bytes)
144                                .map_err(|e| ParseRequestError::InvalidFilesMap(Box::new(e)))?,
145                        );
146                    }
147                }
148            }
149            _ => {
150                if !name.is_empty() {
151                    if let Some(filename) = field.filename.to_owned() {
152                        let mut file = tempfile::tempfile().map_err(ParseRequestError::Io)?;
153                        field.copy_to_file(&mut file).await?;
154                        files.push((name, filename, Some(content_type.to_string()), file));
155                    }
156                }
157            }
158        }
159    }
160
161    let mut request = request.ok_or(ParseRequestError::MissingOperatorsPart)?;
162    let map = map.as_mut().ok_or(ParseRequestError::MissingMapPart)?;
163
164    for (name, filename, content_type, content) in files {
165        if let Some(var_paths) = map.remove(&name) {
166            let upload = async_graphql::UploadValue {
167                filename,
168                content_type,
169                content,
170            };
171
172            for var_path in var_paths {
173                match &mut request {
174                    async_graphql::BatchRequest::Single(request) => {
175                        request.set_upload(&var_path, upload.try_clone()?);
176                    }
177                    async_graphql::BatchRequest::Batch(requests) => {
178                        let mut s = var_path.splitn(2, '.');
179                        let idx = s.next().and_then(|idx| idx.parse::<usize>().ok());
180                        let path = s.next();
181
182                        if let (Some(idx), Some(path)) = (idx, path) {
183                            if let Some(request) = requests.get_mut(idx) {
184                                request.set_upload(path, upload.try_clone()?);
185                            }
186                        }
187                    }
188                }
189            }
190        }
191    }
192
193    if !map.is_empty() {
194        return Err(Error::from(async_graphql::ParseRequestError::MissingFiles));
195    }
196
197    Ok(request)
198}