async_graphql_axum/
extract.rs1use std::{io::ErrorKind, marker::PhantomData};
2
3use async_graphql::{futures_util::TryStreamExt, http::MultipartOptions, ParseRequestError};
4use axum::{
5 extract::{FromRequest, Request},
6 http::{self, Method},
7 response::IntoResponse,
8};
9use tokio_util::compat::TokioAsyncReadCompatExt;
10
11pub struct GraphQLRequest<R = rejection::GraphQLRejection>(
13 pub async_graphql::Request,
14 PhantomData<R>,
15);
16
17impl<R> GraphQLRequest<R> {
18 #[must_use]
20 pub fn into_inner(self) -> async_graphql::Request {
21 self.0
22 }
23}
24
25pub mod rejection {
27 use async_graphql::ParseRequestError;
28 use axum::{
29 body::Body,
30 http,
31 http::StatusCode,
32 response::{IntoResponse, Response},
33 };
34
35 pub struct GraphQLRejection(pub ParseRequestError);
37
38 impl IntoResponse for GraphQLRejection {
39 fn into_response(self) -> Response {
40 match self.0 {
41 ParseRequestError::PayloadTooLarge => http::Response::builder()
42 .status(StatusCode::PAYLOAD_TOO_LARGE)
43 .body(Body::empty())
44 .unwrap(),
45 bad_request => http::Response::builder()
46 .status(StatusCode::BAD_REQUEST)
47 .body(Body::from(format!("{:?}", bad_request)))
48 .unwrap(),
49 }
50 }
51 }
52
53 impl From<ParseRequestError> for GraphQLRejection {
54 fn from(err: ParseRequestError) -> Self {
55 GraphQLRejection(err)
56 }
57 }
58}
59
60impl<S, R> FromRequest<S> for GraphQLRequest<R>
61where
62 S: Send + Sync,
63 R: IntoResponse + From<ParseRequestError>,
64{
65 type Rejection = R;
66
67 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
68 Ok(GraphQLRequest(
69 GraphQLBatchRequest::<R>::from_request(req, state)
70 .await?
71 .0
72 .into_single()?,
73 PhantomData,
74 ))
75 }
76}
77
78pub struct GraphQLBatchRequest<R = rejection::GraphQLRejection>(
80 pub async_graphql::BatchRequest,
81 PhantomData<R>,
82);
83
84impl<R> GraphQLBatchRequest<R> {
85 #[must_use]
87 pub fn into_inner(self) -> async_graphql::BatchRequest {
88 self.0
89 }
90}
91
92impl<S, R> FromRequest<S> for GraphQLBatchRequest<R>
93where
94 S: Send + Sync,
95 R: IntoResponse + From<ParseRequestError>,
96{
97 type Rejection = R;
98
99 async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
100 if req.method() == Method::GET {
101 let uri = req.uri();
102 let res = async_graphql::http::parse_query_string(uri.query().unwrap_or_default())
103 .map_err(|err| {
104 ParseRequestError::Io(std::io::Error::new(
105 ErrorKind::Other,
106 format!("failed to parse graphql request from uri query: {}", err),
107 ))
108 });
109 Ok(Self(async_graphql::BatchRequest::Single(res?), PhantomData))
110 } else {
111 let content_type = req
112 .headers()
113 .get(http::header::CONTENT_TYPE)
114 .and_then(|value| value.to_str().ok())
115 .map(ToString::to_string);
116 let body_stream = req
117 .into_body()
118 .into_data_stream()
119 .map_err(|err| std::io::Error::new(ErrorKind::Other, err.to_string()));
120 let body_reader = tokio_util::io::StreamReader::new(body_stream).compat();
121 Ok(Self(
122 async_graphql::http::receive_batch_body(
123 content_type,
124 body_reader,
125 MultipartOptions::default(),
126 )
127 .await?,
128 PhantomData,
129 ))
130 }
131 }
132}