graphql_starter/graphql/
extract.rs

//! Manual implementation of the GraphQL request extractors, so that the multipart options can be customized (they are taken from the router state), as [recommended](https://github.com/async-graphql/async-graphql/issues/1220).
2
3use std::{io::ErrorKind, marker::PhantomData};
4
5use async_graphql::{futures_util::TryStreamExt, http::MultipartOptions, ParseRequestError};
6use axum::{
7    extract::{FromRef, FromRequest, Request},
8    http::{self, Method},
9    response::IntoResponse,
10};
11use tokio_util::compat::TokioAsyncReadCompatExt;
12
13/// Extractor for GraphQL request.
14pub struct GraphQLRequest<R = rejection::GraphQLRejection>(pub async_graphql::Request, PhantomData<R>);
15
16impl<R> GraphQLRequest<R> {
17    /// Unwraps the value to `async_graphql::Request`.
18    #[must_use]
19    pub fn into_inner(self) -> async_graphql::Request {
20        self.0
21    }
22}
23
24/// Rejection response types.
25pub mod rejection {
26    use async_graphql::ParseRequestError;
27    use axum::{
28        http::StatusCode,
29        response::{IntoResponse, Response},
30    };
31
32    use crate::error::ApiError;
33
34    /// Rejection used for [`GraphQLRequest`](super::GraphQLRequest).
35    pub struct GraphQLRejection(pub ParseRequestError);
36
37    impl IntoResponse for GraphQLRejection {
38        fn into_response(self) -> Response {
39            match self.0 {
40                ParseRequestError::PayloadTooLarge => {
41                    tracing::warn!("[413 Payload Too Large] Received a GraphQL request with a payload too large");
42                    ApiError::new(StatusCode::PAYLOAD_TOO_LARGE, "Payload too large").into_response()
43                }
44                bad_request => {
45                    let msg = bad_request.to_string();
46                    tracing::warn!("[400 Bad Request] {msg}");
47                    ApiError::new(StatusCode::BAD_REQUEST, msg).into_response()
48                }
49            }
50        }
51    }
52
53    impl From<ParseRequestError> for GraphQLRejection {
54        fn from(err: ParseRequestError) -> Self {
55            GraphQLRejection(err)
56        }
57    }
58}
59
60impl<S, R> FromRequest<S> for GraphQLRequest<R>
61where
62    S: Send + Sync,
63    MultipartOptions: FromRef<S>,
64    R: IntoResponse + From<ParseRequestError>,
65{
66    type Rejection = R;
67
68    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
69        Ok(GraphQLRequest(
70            GraphQLBatchRequest::<R>::from_request(req, state)
71                .await?
72                .0
73                .into_single()?,
74            PhantomData,
75        ))
76    }
77}
78
79/// Extractor for GraphQL batch request.
80pub struct GraphQLBatchRequest<R = rejection::GraphQLRejection>(pub async_graphql::BatchRequest, PhantomData<R>);
81
82impl<R> GraphQLBatchRequest<R> {
83    /// Unwraps the value to `async_graphql::BatchRequest`.
84    #[must_use]
85    pub fn into_inner(self) -> async_graphql::BatchRequest {
86        self.0
87    }
88}
89
90impl<S, R> FromRequest<S> for GraphQLBatchRequest<R>
91where
92    S: Send + Sync,
93    R: IntoResponse + From<ParseRequestError>,
94    MultipartOptions: FromRef<S>,
95{
96    type Rejection = R;
97
98    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
99        if req.method() == Method::GET {
100            let uri = req.uri();
101            let res = async_graphql::http::parse_query_string(uri.query().unwrap_or_default()).map_err(|err| {
102                ParseRequestError::Io(std::io::Error::new(
103                    ErrorKind::Other,
104                    format!("failed to parse graphql request from uri query: {}", err),
105                ))
106            });
107            Ok(Self(async_graphql::BatchRequest::Single(res?), PhantomData))
108        } else {
109            let content_type = req
110                .headers()
111                .get(http::header::CONTENT_TYPE)
112                .and_then(|value| value.to_str().ok())
113                .map(ToString::to_string);
114            let body_stream = req
115                .into_body()
116                .into_data_stream()
117                .map_err(|err| std::io::Error::new(ErrorKind::Other, err.to_string()));
118            let body_reader = tokio_util::io::StreamReader::new(body_stream).compat();
119            Ok(Self(
120                async_graphql::http::receive_batch_body(content_type, body_reader, FromRef::from_ref(state)).await?,
121                PhantomData,
122            ))
123        }
124    }
125}