graphql_starter/graphql/
extract.rs1use std::{io::ErrorKind, marker::PhantomData};
4
5use async_graphql::{futures_util::TryStreamExt, http::MultipartOptions, ParseRequestError};
6use axum::{
7 extract::{FromRef, FromRequest, Request},
8 http::{self, Method},
9 response::IntoResponse,
10};
11use tokio_util::compat::TokioAsyncReadCompatExt;
12
13pub struct GraphQLRequest<R = rejection::GraphQLRejection>(pub async_graphql::Request, PhantomData<R>);
15
16impl<R> GraphQLRequest<R> {
17 #[must_use]
19 pub fn into_inner(self) -> async_graphql::Request {
20 self.0
21 }
22}
23
24pub mod rejection {
26 use async_graphql::ParseRequestError;
27 use axum::{
28 http::StatusCode,
29 response::{IntoResponse, Response},
30 };
31
32 use crate::error::ApiError;
33
34 pub struct GraphQLRejection(pub ParseRequestError);
36
37 impl IntoResponse for GraphQLRejection {
38 fn into_response(self) -> Response {
39 match self.0 {
40 ParseRequestError::PayloadTooLarge => {
41 tracing::warn!("[413 Payload Too Large] Received a GraphQL request with a payload too large");
42 ApiError::new(StatusCode::PAYLOAD_TOO_LARGE, "Payload too large").into_response()
43 }
44 bad_request => {
45 let msg = bad_request.to_string();
46 tracing::warn!("[400 Bad Request] {msg}");
47 ApiError::new(StatusCode::BAD_REQUEST, msg).into_response()
48 }
49 }
50 }
51 }
52
53 impl From<ParseRequestError> for GraphQLRejection {
54 fn from(err: ParseRequestError) -> Self {
55 GraphQLRejection(err)
56 }
57 }
58}
59
60impl<S, R> FromRequest<S> for GraphQLRequest<R>
61where
62 S: Send + Sync,
63 MultipartOptions: FromRef<S>,
64 R: IntoResponse + From<ParseRequestError>,
65{
66 type Rejection = R;
67
68 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
69 Ok(GraphQLRequest(
70 GraphQLBatchRequest::<R>::from_request(req, state)
71 .await?
72 .0
73 .into_single()?,
74 PhantomData,
75 ))
76 }
77}
78
79pub struct GraphQLBatchRequest<R = rejection::GraphQLRejection>(pub async_graphql::BatchRequest, PhantomData<R>);
81
82impl<R> GraphQLBatchRequest<R> {
83 #[must_use]
85 pub fn into_inner(self) -> async_graphql::BatchRequest {
86 self.0
87 }
88}
89
90impl<S, R> FromRequest<S> for GraphQLBatchRequest<R>
91where
92 S: Send + Sync,
93 R: IntoResponse + From<ParseRequestError>,
94 MultipartOptions: FromRef<S>,
95{
96 type Rejection = R;
97
98 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
99 if req.method() == Method::GET {
100 let uri = req.uri();
101 let res = async_graphql::http::parse_query_string(uri.query().unwrap_or_default()).map_err(|err| {
102 ParseRequestError::Io(std::io::Error::new(
103 ErrorKind::Other,
104 format!("failed to parse graphql request from uri query: {}", err),
105 ))
106 });
107 Ok(Self(async_graphql::BatchRequest::Single(res?), PhantomData))
108 } else {
109 let content_type = req
110 .headers()
111 .get(http::header::CONTENT_TYPE)
112 .and_then(|value| value.to_str().ok())
113 .map(ToString::to_string);
114 let body_stream = req
115 .into_body()
116 .into_data_stream()
117 .map_err(|err| std::io::Error::new(ErrorKind::Other, err.to_string()));
118 let body_reader = tokio_util::io::StreamReader::new(body_stream).compat();
119 Ok(Self(
120 async_graphql::http::receive_batch_body(content_type, body_reader, FromRef::from_ref(state)).await?,
121 PhantomData,
122 ))
123 }
124 }
125}