graphql_starter/graphql/extract.rs

1//! Manually implementing GraphQLBatchRequest to customize multipart options, as [recommended](https://github.com/async-graphql/async-graphql/issues/1220).
2
3use std::{io::ErrorKind, marker::PhantomData};
4
5use async_graphql::{futures_util::TryStreamExt, http::MultipartOptions, ParseRequestError};
6use axum::{
7    extract::{FromRef, FromRequest, Request},
8    http::{self, Method},
9    response::IntoResponse,
10};
11use tokio_util::compat::TokioAsyncReadCompatExt;
12
13/// Extractor for GraphQL request.
14pub struct GraphQLRequest<R = rejection::GraphQLRejection>(pub async_graphql::Request, PhantomData<R>);
15
16impl<R> GraphQLRequest<R> {
17    /// Unwraps the value to `async_graphql::Request`.
18    #[must_use]
19    pub fn into_inner(self) -> async_graphql::Request {
20        self.0
21    }
22}
23
24/// Rejection response types.
25pub mod rejection {
26    use async_graphql::ParseRequestError;
27    use axum::{
28        http::StatusCode,
29        response::{IntoResponse, Response},
30    };
31
32    use crate::error::{ApiError, Error, GenericErrorCode};
33
34    /// Rejection used for [`GraphQLRequest`](super::GraphQLRequest).
35    pub struct GraphQLRejection(pub ParseRequestError);
36
37    impl IntoResponse for GraphQLRejection {
38        fn into_response(self) -> Response {
39            match self.0 {
40                ParseRequestError::PayloadTooLarge => {
41                    tracing::warn!("[413 Payload Too Large] Received a GraphQL request with a payload too large");
42                    ApiError::new(StatusCode::PAYLOAD_TOO_LARGE, "Payload too large").into_response()
43                }
44                bad_request => ApiError::from_err(Error::new(GenericErrorCode::BadRequest).with_source(bad_request))
45                    .into_response(),
46            }
47        }
48    }
49
50    impl From<ParseRequestError> for GraphQLRejection {
51        fn from(err: ParseRequestError) -> Self {
52            GraphQLRejection(err)
53        }
54    }
55}
56
57impl<S, R> FromRequest<S> for GraphQLRequest<R>
58where
59    S: Send + Sync,
60    MultipartOptions: FromRef<S>,
61    R: IntoResponse + From<ParseRequestError>,
62{
63    type Rejection = R;
64
65    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
66        Ok(GraphQLRequest(
67            GraphQLBatchRequest::<R>::from_request(req, state)
68                .await?
69                .0
70                .into_single()?,
71            PhantomData,
72        ))
73    }
74}
75
76/// Extractor for GraphQL batch request.
77pub struct GraphQLBatchRequest<R = rejection::GraphQLRejection>(pub async_graphql::BatchRequest, PhantomData<R>);
78
79impl<R> GraphQLBatchRequest<R> {
80    /// Unwraps the value to `async_graphql::BatchRequest`.
81    #[must_use]
82    pub fn into_inner(self) -> async_graphql::BatchRequest {
83        self.0
84    }
85}
86
87impl<S, R> FromRequest<S> for GraphQLBatchRequest<R>
88where
89    S: Send + Sync,
90    R: IntoResponse + From<ParseRequestError>,
91    MultipartOptions: FromRef<S>,
92{
93    type Rejection = R;
94
95    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
96        if req.method() == Method::GET {
97            let uri = req.uri();
98            let res = async_graphql::http::parse_query_string(uri.query().unwrap_or_default()).map_err(|err| {
99                ParseRequestError::Io(std::io::Error::new(
100                    ErrorKind::Other,
101                    format!("failed to parse graphql request from uri query: {}", err),
102                ))
103            });
104            Ok(Self(async_graphql::BatchRequest::Single(res?), PhantomData))
105        } else {
106            let content_type = req
107                .headers()
108                .get(http::header::CONTENT_TYPE)
109                .and_then(|value| value.to_str().ok())
110                .map(ToString::to_string);
111            let body_stream = req
112                .into_body()
113                .into_data_stream()
114                .map_err(|err| std::io::Error::new(ErrorKind::Other, err.to_string()));
115            let body_reader = tokio_util::io::StreamReader::new(body_stream).compat();
116            Ok(Self(
117                async_graphql::http::receive_batch_body(content_type, body_reader, FromRef::from_ref(state)).await?,
118                PhantomData,
119            ))
120        }
121    }
122}