async_graphql_axum_wasi/
extract.rs

1use std::{io::ErrorKind, marker::PhantomData};
2
3use async_graphql::{futures_util::TryStreamExt, http::MultipartOptions, ParseRequestError};
4use axum::{
5    extract::{BodyStream, FromRequest},
6    http::{self, Method, Request},
7    response::IntoResponse,
8    BoxError,
9};
10use bytes::Bytes;
11use tokio_util::compat::TokioAsyncReadCompatExt;
12
13/// Extractor for GraphQL request.
14pub struct GraphQLRequest<R = rejection::GraphQLRejection>(
15    pub async_graphql::Request,
16    PhantomData<R>,
17);
18
19impl<R> GraphQLRequest<R> {
20    /// Unwraps the value to `async_graphql::Request`.
21    #[must_use]
22    pub fn into_inner(self) -> async_graphql::Request {
23        self.0
24    }
25}
26
27/// Rejection response types.
28pub mod rejection {
29    use async_graphql::ParseRequestError;
30    use axum::{
31        body::{boxed, Body, BoxBody},
32        http,
33        http::StatusCode,
34        response::IntoResponse,
35    };
36
37    /// Rejection used for [`GraphQLRequest`](GraphQLRequest).
38    pub struct GraphQLRejection(pub ParseRequestError);
39
40    impl IntoResponse for GraphQLRejection {
41        fn into_response(self) -> http::Response<BoxBody> {
42            match self.0 {
43                ParseRequestError::PayloadTooLarge => http::Response::builder()
44                    .status(StatusCode::PAYLOAD_TOO_LARGE)
45                    .body(boxed(Body::empty()))
46                    .unwrap(),
47                bad_request => http::Response::builder()
48                    .status(StatusCode::BAD_REQUEST)
49                    .body(boxed(Body::from(format!("{:?}", bad_request))))
50                    .unwrap(),
51            }
52        }
53    }
54
55    impl From<ParseRequestError> for GraphQLRejection {
56        fn from(err: ParseRequestError) -> Self {
57            GraphQLRejection(err)
58        }
59    }
60}
61
62#[async_trait::async_trait]
63impl<S, B, R> FromRequest<S, B> for GraphQLRequest<R>
64where
65    B: http_body::Body + Unpin + Send + Sync + 'static,
66    B::Data: Into<Bytes>,
67    B::Error: Into<BoxError>,
68    S: Send + Sync,
69    R: IntoResponse + From<ParseRequestError>,
70{
71    type Rejection = R;
72
73    async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
74        Ok(GraphQLRequest(
75            GraphQLBatchRequest::<R>::from_request(req, state)
76                .await?
77                .0
78                .into_single()?,
79            PhantomData,
80        ))
81    }
82}
83
84/// Extractor for GraphQL batch request.
85pub struct GraphQLBatchRequest<R = rejection::GraphQLRejection>(
86    pub async_graphql::BatchRequest,
87    PhantomData<R>,
88);
89
90impl<R> GraphQLBatchRequest<R> {
91    /// Unwraps the value to `async_graphql::BatchRequest`.
92    #[must_use]
93    pub fn into_inner(self) -> async_graphql::BatchRequest {
94        self.0
95    }
96}
97
98#[async_trait::async_trait]
99impl<S, B, R> FromRequest<S, B> for GraphQLBatchRequest<R>
100where
101    B: http_body::Body + Unpin + Send + Sync + 'static,
102    B::Data: Into<Bytes>,
103    B::Error: Into<BoxError>,
104    S: Send + Sync,
105    R: IntoResponse + From<ParseRequestError>,
106{
107    type Rejection = R;
108
109    async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
110        if let (&Method::GET, uri) = (req.method(), req.uri()) {
111            let res = async_graphql::http::parse_query_string(uri.query().unwrap_or_default())
112                .map_err(|err| {
113                    ParseRequestError::Io(std::io::Error::new(
114                        ErrorKind::Other,
115                        format!("failed to parse graphql request from uri query: {}", err),
116                    ))
117                });
118            Ok(Self(async_graphql::BatchRequest::Single(res?), PhantomData))
119        } else {
120            let content_type = req
121                .headers()
122                .get(http::header::CONTENT_TYPE)
123                .and_then(|value| value.to_str().ok())
124                .map(ToString::to_string);
125            let body_stream = BodyStream::from_request(req, state)
126                .await
127                .map_err(|_| {
128                    ParseRequestError::Io(std::io::Error::new(
129                        ErrorKind::Other,
130                        "body has been taken by another extractor".to_string(),
131                    ))
132                })?
133                .map_err(|err| std::io::Error::new(ErrorKind::Other, err.to_string()));
134            let body_reader = tokio_util::io::StreamReader::new(body_stream).compat();
135            Ok(Self(
136                async_graphql::http::receive_batch_body(
137                    content_type,
138                    body_reader,
139                    MultipartOptions::default(),
140                )
141                .await?,
142                PhantomData,
143            ))
144        }
145    }
146}