graphql_starter/graphql/
extract.rs1use std::{io::ErrorKind, marker::PhantomData};
4
5use async_graphql::{futures_util::TryStreamExt, http::MultipartOptions, ParseRequestError};
6use axum::{
7 extract::{FromRef, FromRequest, Request},
8 http::{self, Method},
9 response::IntoResponse,
10};
11use tokio_util::compat::TokioAsyncReadCompatExt;
12
13pub struct GraphQLRequest<R = rejection::GraphQLRejection>(pub async_graphql::Request, PhantomData<R>);
15
16impl<R> GraphQLRequest<R> {
17 #[must_use]
19 pub fn into_inner(self) -> async_graphql::Request {
20 self.0
21 }
22}
23
24pub mod rejection {
26 use async_graphql::ParseRequestError;
27 use axum::{
28 http::StatusCode,
29 response::{IntoResponse, Response},
30 };
31
32 use crate::error::{ApiError, Error, GenericErrorCode};
33
34 pub struct GraphQLRejection(pub ParseRequestError);
36
37 impl IntoResponse for GraphQLRejection {
38 fn into_response(self) -> Response {
39 match self.0 {
40 ParseRequestError::PayloadTooLarge => {
41 tracing::warn!("[413 Payload Too Large] Received a GraphQL request with a payload too large");
42 ApiError::new(StatusCode::PAYLOAD_TOO_LARGE, "Payload too large").into_response()
43 }
44 bad_request => ApiError::from_err(Error::new(GenericErrorCode::BadRequest).with_source(bad_request))
45 .into_response(),
46 }
47 }
48 }
49
50 impl From<ParseRequestError> for GraphQLRejection {
51 fn from(err: ParseRequestError) -> Self {
52 GraphQLRejection(err)
53 }
54 }
55}
56
57impl<S, R> FromRequest<S> for GraphQLRequest<R>
58where
59 S: Send + Sync,
60 MultipartOptions: FromRef<S>,
61 R: IntoResponse + From<ParseRequestError>,
62{
63 type Rejection = R;
64
65 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
66 Ok(GraphQLRequest(
67 GraphQLBatchRequest::<R>::from_request(req, state)
68 .await?
69 .0
70 .into_single()?,
71 PhantomData,
72 ))
73 }
74}
75
76pub struct GraphQLBatchRequest<R = rejection::GraphQLRejection>(pub async_graphql::BatchRequest, PhantomData<R>);
78
79impl<R> GraphQLBatchRequest<R> {
80 #[must_use]
82 pub fn into_inner(self) -> async_graphql::BatchRequest {
83 self.0
84 }
85}
86
87impl<S, R> FromRequest<S> for GraphQLBatchRequest<R>
88where
89 S: Send + Sync,
90 R: IntoResponse + From<ParseRequestError>,
91 MultipartOptions: FromRef<S>,
92{
93 type Rejection = R;
94
95 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
96 if req.method() == Method::GET {
97 let uri = req.uri();
98 let res = async_graphql::http::parse_query_string(uri.query().unwrap_or_default()).map_err(|err| {
99 ParseRequestError::Io(std::io::Error::new(
100 ErrorKind::Other,
101 format!("failed to parse graphql request from uri query: {}", err),
102 ))
103 });
104 Ok(Self(async_graphql::BatchRequest::Single(res?), PhantomData))
105 } else {
106 let content_type = req
107 .headers()
108 .get(http::header::CONTENT_TYPE)
109 .and_then(|value| value.to_str().ok())
110 .map(ToString::to_string);
111 let body_stream = req
112 .into_body()
113 .into_data_stream()
114 .map_err(|err| std::io::Error::new(ErrorKind::Other, err.to_string()));
115 let body_reader = tokio_util::io::StreamReader::new(body_stream).compat();
116 Ok(Self(
117 async_graphql::http::receive_batch_body(content_type, body_reader, FromRef::from_ref(state)).await?,
118 PhantomData,
119 ))
120 }
121 }
122}