async_graphql/http/
mod.rs

1//! A helper module that supports HTTP
2
3#[cfg(feature = "graphiql")]
4mod graphiql_source;
5mod multipart;
6mod multipart_subscribe;
7mod websocket;
8
9use futures_util::io::{AsyncRead, AsyncReadExt};
10#[cfg(feature = "graphiql")]
11pub use graphiql_source::{Credentials, GraphiQLSource};
12pub use multipart::MultipartOptions;
13pub use multipart_subscribe::{create_multipart_mixed_stream, is_accept_multipart_mixed};
14use serde::Deserialize;
15pub use websocket::{
16    ALL_WEBSOCKET_PROTOCOLS, ClientMessage, DefaultOnConnInitType, DefaultOnPingType,
17    Protocols as WebSocketProtocols, WebSocket, WsMessage, default_on_connection_init,
18    default_on_ping,
19};
20
21use crate::{BatchRequest, ParseRequestError, Request};
22
23/// Parse a GraphQL request from a query string.
24pub fn parse_query_string(input: &str) -> Result<Request, ParseRequestError> {
25    #[derive(Deserialize)]
26    struct RequestSerde {
27        #[serde(default)]
28        pub query: String,
29        pub operation_name: Option<String>,
30        pub variables: Option<String>,
31        pub extensions: Option<String>,
32    }
33
34    let request: RequestSerde = serde_urlencoded::from_str(input).map_err(std::io::Error::other)?;
35    let variables = request
36        .variables
37        .map(|data| serde_json::from_str(&data))
38        .transpose()
39        .map_err(|err| std::io::Error::other(format!("invalid variables: {}", err)))?
40        .unwrap_or_default();
41    let extensions = request
42        .extensions
43        .map(|data| serde_json::from_str(&data))
44        .transpose()
45        .map_err(|err| std::io::Error::other(format!("invalid extensions: {}", err)))?
46        .unwrap_or_default();
47
48    Ok(Request {
49        operation_name: request.operation_name,
50        variables,
51        extensions,
52        ..Request::new(request.query)
53    })
54}
55
56/// Receive a GraphQL request from a content type and body.
57pub async fn receive_body(
58    content_type: Option<impl AsRef<str>>,
59    body: impl AsyncRead + Send,
60    opts: MultipartOptions,
61) -> Result<Request, ParseRequestError> {
62    receive_batch_body(content_type, body, opts)
63        .await?
64        .into_single()
65}
66
67/// Receive a GraphQL request from a content type and body.
68pub async fn receive_batch_body(
69    content_type: Option<impl AsRef<str>>,
70    body: impl AsyncRead + Send,
71    opts: MultipartOptions,
72) -> Result<BatchRequest, ParseRequestError> {
73    // if no content-type header is set, we default to json
74    let content_type = content_type
75        .as_ref()
76        .map(AsRef::as_ref)
77        .unwrap_or("application/graphql-response+json");
78
79    let content_type: mime::Mime = content_type.parse()?;
80
81    match (content_type.type_(), content_type.subtype()) {
82        // try to use multipart
83        (mime::MULTIPART, _) => {
84            if let Some(boundary) = content_type.get_param("boundary") {
85                multipart::receive_batch_multipart(body, boundary.to_string(), opts).await
86            } else {
87                Err(ParseRequestError::InvalidMultipart(
88                    multer::Error::NoBoundary,
89                ))
90            }
91        }
92        // application/json (currently)
93        _ => receive_batch_body_no_multipart(&content_type, body).await,
94    }
95}
96
97/// Receives a GraphQL query which is json but NOT multipart
98/// This method is only to avoid recursive calls with [``receive_batch_body``]
99/// and [``multipart::receive_batch_multipart``]
100pub(super) async fn receive_batch_body_no_multipart(
101    content_type: &mime::Mime,
102    body: impl AsyncRead + Send,
103) -> Result<BatchRequest, ParseRequestError> {
104    assert_ne!(content_type.type_(), mime::MULTIPART, "received multipart");
105    receive_batch_json(body).await
106}
107
108/// Receive a GraphQL request from a body as JSON.
109pub async fn receive_json(body: impl AsyncRead) -> Result<Request, ParseRequestError> {
110    receive_batch_json(body).await?.into_single()
111}
112
113/// Receive a GraphQL batch request from a body as JSON.
114pub async fn receive_batch_json(body: impl AsyncRead) -> Result<BatchRequest, ParseRequestError> {
115    let mut data = Vec::new();
116    futures_util::pin_mut!(body);
117    body.read_to_end(&mut data)
118        .await
119        .map_err(ParseRequestError::Io)?;
120    serde_json::from_slice::<BatchRequest>(&data)
121        .map_err(|e| ParseRequestError::InvalidRequest(Box::new(e)))
122}
123
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use async_graphql_value::Extensions;

    use super::*;
    use crate::{Variables, value};

    #[test]
    fn test_parse_query_string() {
        // Only `variables` and `extensions` are present, so the query falls back to "".
        let persisted = parse_query_string("variables=%7B%7D&extensions=%7B%22persistedQuery%22%3A%7B%22sha256Hash%22%3A%22cde5de0a350a19c59f8ddcd9646e5f260b2a7d5649ff6be8e63e9462934542c3%22%2C%22version%22%3A1%7D%7D").unwrap();
        let expected_extensions = Extensions(HashMap::from([(
            "persistedQuery".to_string(),
            value!({
                "sha256Hash": "cde5de0a350a19c59f8ddcd9646e5f260b2a7d5649ff6be8e63e9462934542c3",
                "version": 1,
            }),
        )]));
        assert_eq!(persisted.query.as_str(), "");
        assert_eq!(persisted.variables, Variables::default());
        assert_eq!(persisted.extensions, expected_extensions);

        // Both `query` and the JSON-encoded `variables` are decoded.
        let with_query = parse_query_string("query={a}&variables=%7B%22a%22%3A10%7D").unwrap();
        assert_eq!(with_query.query.as_str(), "{a}");
        assert_eq!(
            with_query.variables,
            Variables::from_value(value!({ "a" : 10 }))
        );
    }
}